File size: 5,052 Bytes
f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e88f4f2 f1620d5 e04e17b e88f4f2 e04e17b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import traceback
import torch
from flask import Flask, render_template, request
from PIL import Image
import numpy as np
import cv2
from werkzeug.utils import secure_filename
# Import from gradcam (uses safe import that won't crash on missing model)
# gradcam.py must export: GradCAM, model, classes, get_model
from gradcam import GradCAM, model, classes, get_model
from torchvision import transforms
# Flask application and upload-folder configuration.
app = Flask(__name__)
UPLOAD_FOLDER = "static/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# File extensions accepted by allowed_file() (compared lower-cased).
ALLOWED_EXT = {"png", "jpg", "jpeg", "bmp"}

# Preprocessing pipeline applied to every upload: resize to the 224x224
# input the model expects, convert to a CHW float tensor, and normalize
# with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
def allowed_file(filename):
    """Return True when *filename* carries an extension listed in ALLOWED_EXT."""
    parts = filename.rsplit(".", 1)
    return len(parts) == 2 and parts[1].lower() in ALLOWED_EXT
@app.route('/')
def index():
    """Serve the landing page with the image-upload form."""
    return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
    """Handle an image upload: classify it and render a Grad-CAM overlay.

    Expects a multipart form field named ``image``. On success renders
    ``result.html`` with the prediction, confidence, and the saved
    original/overlay image filenames. Returns a 400 for invalid uploads
    and a 500 error page when the model is unavailable or inference fails.
    """
    # Ensure model is loaded (try lazy load if module-level model was None).
    global model
    if model is None:
        model = get_model(reload=True)
    if model is None:
        err = (
            "Model is not available. Please upload a valid `model.pth` to the Space "
            "or check the application logs for details."
        )
        return render_template('error.html', error_message=err), 500

    # Validate the upload before touching the filesystem.
    if 'image' not in request.files:
        return "No image uploaded", 400
    file = request.files['image']
    if file.filename == '':
        return "No selected image", 400
    if not allowed_file(file.filename):
        return "Unsupported file type", 400

    # Sanitize the filename; secure_filename() can return an empty string
    # for names made entirely of unsafe characters — reject those.
    filename = secure_filename(file.filename)
    if not filename:
        return "Invalid filename", 400
    img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(img_path)

    try:
        # Load and preprocess, moving the input to the model's device.
        image = Image.open(img_path).convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(_model_device())

        # Plain forward pass for the class prediction (no gradients needed).
        with torch.no_grad():
            output = model(input_tensor)
        pred_idx = int(torch.argmax(output, dim=1).item())
        confidence = float(torch.softmax(output, dim=1)[0][pred_idx].item())

        # Grad-CAM heat-map for the predicted class.
        # GradCAM.generate returns (cam_resized, probs, pred_idx) in the
        # robust gradcam.py; cam_map is a numpy array normalized 0..1, (H, W).
        gradcam = GradCAM(model, target_layer=_pick_target_layer(model))
        cam_map, _probs, _returned_idx = gradcam.generate(input_tensor, class_idx=pred_idx)

        overlay = _blend_overlay(image, cam_map)

        # Save the overlay under a per-upload name so concurrent requests
        # never clobber each other's overlay. (Fixes the previous constant
        # "cam_(unknown)" filename, which every request overwrote.)
        cam_filename = f"cam_{filename}"
        cam_path = os.path.join(app.config['UPLOAD_FOLDER'], cam_filename)
        # cv2.imwrite expects BGR; overlay is RGB.
        cv2.imwrite(cam_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

        return render_template(
            'result.html',
            prediction=classes[pred_idx] if pred_idx < len(classes) else str(pred_idx),
            confidence=f"{confidence * 100:.2f}%",
            uploaded_image=filename,
            cam_image=cam_filename
        )
    except Exception as e:
        # Log the full trace for debugging; show a friendly page to the user.
        tb = traceback.format_exc()
        print("Error during prediction:", e)
        print(tb)
        return render_template('error.html', error_message=str(e)), 500


def _model_device():
    """Return the device the loaded model lives on (CUDA/CPU guess on failure)."""
    try:
        return next(model.parameters()).device
    except Exception:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _pick_target_layer(net):
    """Choose a sensible Grad-CAM target layer for *net*.

    For DenseNet-style models prefer ``features.denseblock4``; otherwise fall
    back to the last element of ``features``, then ``features`` itself.
    """
    try:
        target = getattr(net.features, "denseblock4", None)
        return target if target is not None else net.features[-1]
    except Exception:
        return net.features


def _blend_overlay(image, cam_map):
    """Blend a 0..1 Grad-CAM map over *image* (PIL RGB) as a uint8 RGB array."""
    orig_np = np.array(image.resize((224, 224))).astype(np.uint8)
    # Map the 0..1 cam to a JET color map (OpenCV returns BGR; convert to RGB).
    heatmap = np.uint8(255 * cam_map)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    # Blend weights (0.6 original / 0.4 heat-map) can be tuned.
    overlay = 0.6 * orig_np.astype(np.float32) + 0.4 * heatmap_color.astype(np.float32)
    return np.clip(overlay, 0, 255).astype(np.uint8)
if __name__ == '__main__':
    # Hugging Face Spaces / most PaaS platforms inject PORT; 7860 is the
    # conventional Spaces default for local runs.
    port = int(os.environ.get("PORT", 7860))
    # Only enable debug when explicitly requested via the environment.
    # Never default to debug=True: the Werkzeug debugger allows arbitrary
    # code execution and must not run in production.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=port, debug=debug)
|