# dhineshkmar's picture
# Upload 4 files
# e79cc54 verified
import os
import torch
from flask import Flask, render_template, request, redirect, url_for
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
from werkzeug.utils import secure_filename
# Flask application object; static assets and templates are looked up in the
# "static" and "templates" folders respectively.
app = Flask(
    __name__,
    template_folder="templates",
    static_folder="static",
)
# Load feature extractor and model.
# Both are created from the same pretrained checkpoint; the classification
# head is sized for 4 labels to match the fine-tuned weights loaded below.
_PRETRAINED_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(_PRETRAINED_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(
    _PRETRAINED_CHECKPOINT, num_labels=4
)

# Path to custom trained weights.
# The default is the original developer's location; set the VIT_MODEL_PATH
# environment variable to run the app on any other machine without editing
# this file.
model_path = os.environ.get(
    'VIT_MODEL_PATH',
    r'C:\Users\Dhinesh kumar A\Downloads\vit-main\vit-main\vit_multiple_sclerosis_final.pth',
)

# NOTE(review): torch.load unpickles the file, so model_path must only ever
# point at a checkpoint from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Switch to inference mode (the app only ever predicts, never trains).
model.eval()
# Class mapping: model output index -> human-readable label.
# The position of each label in the tuple is its class index.
class_mapping = dict(enumerate((
    'Control-Axial',
    'Control-Sagittal',
    'MS-Axial',
    'MS-Sagittal',
)))
# Upload folder: uploaded images are written inside the app's static folder
# so they can be served back on the result page.
UPLOAD_FOLDER = os.path.join(
    app.static_folder,
    "uploads",
)
# Create the directory up front; exist_ok keeps restarts from failing when it
# is already there.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
@app.route('/')
def index():
    """Render the landing page template."""
    landing_page = render_template('index.html')
    return landing_page
@app.route('/predict', methods=['GET', 'POST'])
def upload_and_predict():
    """Show the upload form (GET) or classify an uploaded image (POST).

    On POST the submitted file is saved under ``UPLOAD_FOLDER``, run through
    the ViT model, and the result page is rendered with the predicted class
    label and the URL of the saved image.

    Returns:
        The rendered ``upload.html`` for GET, the rendered ``result.html``
        for a successful POST, a redirect back to this URL when no usable
        file was submitted, or a plain-text 400 response when the upload
        cannot be decoded as an image.
    """
    if request.method != 'POST':
        return render_template('upload.html')

    file = request.files.get('file')
    if file is None or not file.filename:
        return redirect(request.url)

    # secure_filename() can strip every character (e.g. a name made only of
    # dots or path separators). An empty result would make filepath point at
    # the upload directory itself, so treat it like a missing file.
    filename = secure_filename(file.filename)
    if not filename:
        return redirect(request.url)

    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)

    # Preprocess image.
    # The context manager closes the underlying file handle; convert()
    # returns an independent in-memory copy that stays usable afterwards.
    try:
        with Image.open(filepath) as uploaded:
            image = uploaded.convert('RGB')
    except OSError:
        # Not a decodable image (Pillow's UnidentifiedImageError is an
        # OSError). Do not leave arbitrary uploads in the publicly served
        # static folder; cleanup is best-effort so it cannot mask the 400.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return 'Uploaded file is not a valid image.', 400

    inputs = feature_extractor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values']

    # Inference (no gradients needed for prediction).
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = class_mapping[predicted_class_idx]

    # Serve uploaded image via /static/uploads
    image_url = url_for('static', filename=f'uploads/{filename}')
    return render_template(
        'result.html',
        prediction=predicted_class,
        uploaded_image=image_url,
    )
if __name__ == '__main__':
    # Debug mode turns on the interactive debugger, which lets anyone who can
    # reach the server execute code. Keep it off unless explicitly requested
    # through the FLASK_DEBUG environment variable (e.g. FLASK_DEBUG=1).
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )
    app.run(debug=debug_enabled)