File size: 5,052 Bytes
f1620d5
e88f4f2
f1620d5
 
 
 
 
e88f4f2
 
 
 
 
f1620d5
 
 
 
 
 
e88f4f2
 
f1620d5
 
 
 
 
 
 
 
e88f4f2
 
 
f1620d5
 
 
 
 
 
e88f4f2
 
 
 
 
 
 
 
 
 
 
 
f1620d5
 
 
 
 
 
 
e88f4f2
 
 
 
 
 
f1620d5
 
e88f4f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1620d5
 
e04e17b
e88f4f2
e04e17b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import traceback
import uuid

import torch
from flask import Flask, render_template, request
from PIL import Image
import numpy as np
import cv2
from werkzeug.utils import secure_filename
from torchvision import transforms

# Import from gradcam (uses safe import that won't crash on missing model)
# gradcam.py must export: GradCAM, model, classes, get_model
from gradcam import GradCAM, model, classes, get_model

app = Flask(__name__)

# Anchor the upload directory to Flask's own static folder instead of the
# current working directory, so saved files land where Flask serves /static/
# URLs from regardless of where the process was launched. (The templates
# presumably reference uploads via the static route - confirm.)
UPLOAD_FOLDER = os.path.join(app.static_folder, "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Uploads come from untrusted clients: Werkzeug answers request bodies larger
# than this with HTTP 413 instead of accepting them and writing them to disk.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
# Accepted file extensions (lower-case, without the leading dot).
ALLOWED_EXT = {"png", "jpg", "jpeg", "bmp"}

# Preprocessing applied to every uploaded image before inference: resize to
# 224x224, convert to a CHW float tensor, then normalise each channel with the
# standard ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

def allowed_file(filename):
    """Report whether *filename* ends in one of the extensions in ALLOWED_EXT.

    Only the text after the final dot is inspected, compared case-insensitively.
    A name without any dot is rejected.
    """
    _, separator, extension = filename.rpartition(".")
    if not separator:
        return False
    return extension.lower() in ALLOWED_EXT

@app.route('/')
def index():
    """Render the landing page from the ``index.html`` template."""
    landing_template = 'index.html'
    return render_template(landing_template)

def _unique_upload_name(original_name):
    """Return a collision-free, filesystem-safe name for an uploaded file.

    The extension is taken from *original_name* (callers validate it with
    ``allowed_file`` first). The stem is sanitised with ``secure_filename``,
    and a random token is appended so that concurrent uploads sharing a name
    cannot overwrite each other's image or Grad-CAM overlay.
    """
    stem, _, ext = original_name.rpartition(".")
    # secure_filename() drops non-ASCII characters and may return "" (e.g. for
    # a name written entirely in a non-Latin script); fall back to a fixed stem
    # so the saved file always keeps its extension.
    safe_stem = secure_filename(stem) or "upload"
    return f"{safe_stem}_{uuid.uuid4().hex[:16]}.{ext.lower()}"


def _select_target_layer(net):
    """Pick the layer whose activations Grad-CAM should explain.

    Prefers ``net.features.denseblock4`` (the last dense block of a
    torchvision-style DenseNet). Otherwise it falls back to
    ``net.features[-1]``. If that lookup raises, it returns ``net.features``
    itself, which in turn raises AttributeError when the model has no
    ``features`` attribute.
    """
    try:
        layer = getattr(net.features, "denseblock4", None)
        if layer is None:
            layer = net.features[-1]
    except Exception:
        layer = net.features
    return layer


def _build_overlay(image, cam_map, size=(224, 224), image_weight=0.6, heatmap_weight=0.4):
    """Blend a Grad-CAM map over a PIL *image* and return an RGB uint8 array.

    *image* is resized to *size* (width, height). *cam_map* is expected to
    hold values in [0, 1] (assumption from gradcam.py's contract - TODO
    confirm). As a safeguard, singleton axes are dropped, NaNs are zeroed,
    values are clipped to [0, 1], and the map is resized to the image when
    the shapes differ.
    """
    base = np.array(image.resize(size)).astype(np.uint8)
    cam = np.squeeze(np.asarray(cam_map, dtype=np.float32))
    cam = np.clip(np.nan_to_num(cam), 0.0, 1.0)
    height, width = base.shape[:2]
    if cam.shape != (height, width):
        cam = cv2.resize(cam, (width, height))  # cv2 takes (width, height)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colour maps are BGR

    blended = image_weight * base.astype(np.float32) + heatmap_weight * heatmap.astype(np.float32)
    return np.clip(blended, 0, 255).astype(np.uint8)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and render it together with a Grad-CAM overlay.

    Expects a multipart form field ``image``. Responses:
      * 500 + ``error.html`` if no model can be loaded;
      * 400 plain text for a missing, empty-named or unsupported upload;
      * ``result.html`` with the predicted class, its softmax confidence and
        the names of the saved upload and overlay (``cam_<name>``) files;
      * 500 + ``error.html`` (exception text shown) if processing fails.
    The traceback of any processing failure is logged via ``app.logger``.
    """
    # Ensure model is loaded (try a lazy reload if the module-level model was None).
    global model
    if model is None:
        model = get_model(reload=True)
    if model is None:
        # Friendly error page — you can make a nicer HTML template if you want
        err = (
            "Model is not available. Please upload a valid `model.pth` to the Space "
            "or check the application logs for details."
        )
        return render_template('error.html', error_message=err), 500

    if 'image' not in request.files:
        return "No image uploaded", 400

    file = request.files['image']
    # `not` also rejects a None filename, which allowed_file() would crash on.
    if not file.filename:
        return "No selected image", 400

    if not allowed_file(file.filename):
        return "Unsupported file type", 400

    # Unique, sanitised name: avoids clobbering other users' uploads and keeps
    # the extension that cv2.imwrite() needs to pick an encoder.
    filename = _unique_upload_name(file.filename)
    img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(img_path)

    try:
        # Load and preprocess; the context manager closes the source file.
        with Image.open(img_path) as source:
            image = source.convert("RGB")
        input_tensor = transform(image).unsqueeze(0)

        # Move input to the same device as the model.
        try:
            model_device = next(model.parameters()).device
        except (StopIteration, AttributeError):
            model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_tensor = input_tensor.to(model_device)

        with torch.no_grad():
            output = model(input_tensor)
            pred_idx = int(torch.argmax(output, dim=1).item())
            confidence = float(torch.softmax(output, dim=1)[0, pred_idx].item())

        # NOTE(review): a new GradCAM is built on the shared model for every
        # request; if GradCAM registers hooks on target_layer without removing
        # them, they accumulate across requests - confirm gradcam.py cleans up.
        gradcam = GradCAM(model, target_layer=_select_target_layer(model))
        # GradCAM.generate returns (cam_map, probs, pred_idx); only the map is used.
        cam_map, _probs, _returned_idx = gradcam.generate(input_tensor, class_idx=pred_idx)

        overlay = _build_overlay(image, cam_map)

        cam_filename = f"cam_{filename}"
        cam_path = os.path.join(app.config['UPLOAD_FOLDER'], cam_filename)
        # cv2.imwrite expects BGR and reports failure via its return value
        # rather than raising, so check it explicitly.
        if not cv2.imwrite(cam_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Could not write Grad-CAM overlay to {cam_path}")

        return render_template(
            'result.html',
            prediction=classes[pred_idx] if pred_idx < len(classes) else str(pred_idx),
            confidence=f"{confidence * 100:.2f}%",
            uploaded_image=filename,
            cam_image=cam_filename
        )

    except Exception as e:
        # Top-level request boundary: log the full traceback, show an error page.
        app.logger.exception("Error during prediction")
        return render_template('error.html', error_message=str(e)), 500

if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))
    # The Werkzeug debugger gives an interactive Python console on errors;
    # combined with host 0.0.0.0 that is remote code execution, so it is off
    # unless explicitly requested with FLASK_DEBUG=1/true/yes.
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=port, debug=debug)