import os
import torch
from flask import Flask, render_template, request, redirect, url_for
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
from werkzeug.utils import secure_filename

app = Flask(__name__, static_folder="static", template_folder="templates")

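# Load the image preprocessor and the ViT backbone with a 4-class head.
# Note: ViTFeatureExtractor is deprecated in recent transformers releases
# in favor of ViTImageProcessor, which takes the same arguments.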
feature_extractor = ViTFeatureExtractor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k', num_labels=4
)

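# Path to custom trained weights. VIT_MODEL_PATH is an assumed, optional
# environment override so the app is not tied to one machine; the default
# is the original hard-coded location.
model_path = os.environ.get(
    'VIT_MODEL_PATH',
    r'C:\Users\Dhinesh kumar A\Downloads\vit-main\vit-main\vit_multiple_sclerosis_final.pth'
)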
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Class mapping
class_mapping = {
    0: 'Control-Axial',
    1: 'Control-Sagittal',
    2: 'MS-Axial',
    3: 'MS-Sagittal'
}

# Upload folder
UPLOAD_FOLDER = os.path.join(app.static_folder, "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/predict', methods=['GET', 'POST'])
def upload_and_predict():
    if request.method == 'POST':
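        # Reject requests with no file part or an empty filename
        file = request.files.get('file')
        if file is None or file.filename == '':
            return redirect(request.url)

        # secure_filename() can return '' for names made only of unsafe
        # characters, which would make filepath point at the folder itself
        filename = secure_filename(file.filename)
        if not filename:
            return redirect(request.url)
        filepath = os.path.join(UPLOAD_FOLDER, filename)
        file.save(filepath)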

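        # Preprocess image; discard uploads that Pillow cannot decode
        # (PIL.UnidentifiedImageError is a subclass of OSError)
        try:
            with Image.open(filepath) as img:
                image = img.convert('RGB')
        except OSError:
            os.remove(filepath)
            return redirect(request.url)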
        inputs = feature_extractor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']

        # Inference
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits
            predicted_class_idx = logits.argmax(-1).item()
            predicted_class = class_mapping[predicted_class_idx]

        # Serve uploaded image via /static/uploads
        image_url = url_for('static', filename=f'uploads/{filename}')

        return render_template(
            'result.html',
            prediction=predicted_class,
            uploaded_image=image_url
        )

    return render_template('upload.html')


if __name__ == '__main__':
    # debug=True enables the reloader and interactive debugger;
    # use it for local development only
    app.run(debug=True)