vscode commited on
Commit
e88f4f2
·
verified ·
1 Parent(s): 0f16270

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -29
app.py CHANGED
@@ -1,16 +1,23 @@
1
  import os
 
2
  import torch
3
  from flask import Flask, render_template, request
4
  from PIL import Image
5
  import numpy as np
6
  import cv2
 
 
 
 
 
7
 
8
- from gradcam import GradCAM, model, classes
9
  from torchvision import transforms
10
 
11
  app = Flask(__name__)
12
  UPLOAD_FOLDER = "static/uploads"
13
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
 
14
 
15
  transform = transforms.Compose([
16
  transforms.Resize((224, 224)),
@@ -19,12 +26,27 @@ transform = transforms.Compose([
19
  [0.229, 0.224, 0.225])
20
  ])
21
 
 
 
 
22
  @app.route('/')
23
  def index():
24
  return render_template('index.html')
25
 
26
  @app.route('/predict', methods=['POST'])
27
  def predict():
 
 
 
 
 
 
 
 
 
 
 
 
28
  if 'image' not in request.files:
29
  return "No image uploaded", 400
30
 
@@ -32,37 +54,85 @@ def predict():
32
  if file.filename == '':
33
  return "No selected image", 400
34
 
35
- img_path = os.path.join(UPLOAD_FOLDER, file.filename)
 
 
 
 
 
36
  file.save(img_path)
37
 
38
- image = Image.open(img_path).convert("RGB")
39
- input_tensor = transform(image).unsqueeze(0).to(next(model.parameters()).device)
40
-
41
- # Predict
42
- with torch.no_grad():
43
- output = model(input_tensor)
44
- pred_idx = torch.argmax(output, dim=1).item()
45
- confidence = torch.softmax(output, dim=1)[0][pred_idx].item()
46
-
47
- # Grad-CAM
48
- gradcam = GradCAM(model, model.features.denseblock4)
49
- cam = gradcam.generate(input_tensor, class_idx=pred_idx)
50
-
51
- # Prepare overlay
52
- image_np = np.array(image.resize((224, 224)))
53
- heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
54
- overlay = cv2.addWeighted(image_np, 0.6, heatmap, 0.4, 0)
55
- cam_path = os.path.join(UPLOAD_FOLDER, "cam_" + file.filename)
56
- cv2.imwrite(cam_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
57
-
58
- return render_template(
59
- 'result.html',
60
- prediction=classes[pred_idx],
61
- confidence=f"{confidence * 100:.2f}%",
62
- uploaded_image=file.filename,
63
- cam_image="cam_" + file.filename
64
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  if __name__ == '__main__':
67
  port = int(os.environ.get("PORT", 7860))
 
68
  app.run(host="0.0.0.0", port=port, debug=True)
 
1
  import os
2
+ import traceback
3
  import torch
4
  from flask import Flask, render_template, request
5
  from PIL import Image
6
  import numpy as np
7
  import cv2
8
+ from werkzeug.utils import secure_filename
9
+
10
+ # Import from gradcam (uses safe import that won't crash on missing model)
11
+ # gradcam.py must export: GradCAM, model, classes, get_model
12
+ from gradcam import GradCAM, model, classes, get_model
13
 
 
14
  from torchvision import transforms
15
 
16
  app = Flask(__name__)
17
  UPLOAD_FOLDER = "static/uploads"
18
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
19
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
20
+ ALLOWED_EXT = {"png", "jpg", "jpeg", "bmp"}
21
 
22
  transform = transforms.Compose([
23
  transforms.Resize((224, 224)),
 
26
  [0.229, 0.224, 0.225])
27
  ])
28
 
29
+ def allowed_file(filename):
30
+ return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXT
31
+
32
  @app.route('/')
33
  def index():
34
  return render_template('index.html')
35
 
36
  @app.route('/predict', methods=['POST'])
37
  def predict():
38
+ # Ensure model is loaded (try lazy load if module-level model was None)
39
+ global model
40
+ if model is None:
41
+ model = get_model(reload=True)
42
+ if model is None:
43
+ # Friendly error page — you can make a nicer HTML template if you want
44
+ err = (
45
+ "Model is not available. Please upload a valid `model.pth` to the Space "
46
+ "or check the application logs for details."
47
+ )
48
+ return render_template('error.html', error_message=err), 500
49
+
50
  if 'image' not in request.files:
51
  return "No image uploaded", 400
52
 
 
54
  if file.filename == '':
55
  return "No selected image", 400
56
 
57
+ if not allowed_file(file.filename):
58
+ return "Unsupported file type", 400
59
+
60
+ # sanitize filename and save
61
+ filename = secure_filename(file.filename)
62
+ img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
63
  file.save(img_path)
64
 
65
+ try:
66
+ # Load image and preprocess
67
+ image = Image.open(img_path).convert("RGB")
68
+ input_tensor = transform(image).unsqueeze(0)
69
+
70
+ # Move input to the same device as model
71
+ try:
72
+ model_device = next(model.parameters()).device
73
+ except Exception:
74
+ model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
+
76
+ input_tensor = input_tensor.to(model_device)
77
+
78
+ # Predict
79
+ with torch.no_grad():
80
+ output = model(input_tensor)
81
+ pred_idx = int(torch.argmax(output, dim=1).item())
82
+ confidence = float(torch.softmax(output, dim=1)[0][pred_idx].item())
83
+
84
+ # Grad-CAM: choose a sensible target layer
85
+ # For DenseNet, a typical target is model.features.denseblock4 or the final features element
86
+ try:
87
+ # prefer denseblock4 if present
88
+ target_layer = getattr(model.features, "denseblock4", None)
89
+ if target_layer is None:
90
+ # fallback to last features module
91
+ target_layer = model.features[-1]
92
+ except Exception:
93
+ target_layer = model.features
94
+
95
+ gradcam = GradCAM(model, target_layer=target_layer)
96
+
97
+ # Note: GradCAM.generate returns (cam_resized, probs, pred_idx) in the robust gradcam.py
98
+ cam_map, probs, returned_idx = gradcam.generate(input_tensor, class_idx=pred_idx)
99
+ # cam_map is a numpy array normalized 0..1 with shape (H, W)
100
+
101
+ # Prepare overlay image
102
+ # Resize original image to 224x224 and convert to numpy RGB
103
+ orig_np = np.array(image.resize((224, 224))).astype(np.uint8)
104
+
105
+ # Convert cam_map (0..1) to heatmap (0..255) then to colored map
106
+ heatmap = np.uint8(255 * cam_map)
107
+ heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
108
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
109
+
110
+ # Blend overlay (weights can be tuned)
111
+ overlay = (0.6 * orig_np.astype(np.float32) + 0.4 * heatmap_color.astype(np.float32))
112
+ overlay = np.clip(overlay, 0, 255).astype(np.uint8)
113
+
114
+ # Save overlay using a distinct filename
115
+ cam_filename = f"cam_{filename}"
116
+ cam_path = os.path.join(app.config['UPLOAD_FOLDER'], cam_filename)
117
+ # cv2.imwrite expects BGR, convert overlay RGB->BGR
118
+ cv2.imwrite(cam_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
119
+
120
+ return render_template(
121
+ 'result.html',
122
+ prediction=classes[pred_idx] if pred_idx < len(classes) else str(pred_idx),
123
+ confidence=f"{confidence * 100:.2f}%",
124
+ uploaded_image=filename,
125
+ cam_image=cam_filename
126
+ )
127
+
128
+ except Exception as e:
129
+ # Log trace for debugging
130
+ tb = traceback.format_exc()
131
+ print("Error during prediction:", e)
132
+ print(tb)
133
+ return render_template('error.html', error_message=str(e)), 500
134
 
135
  if __name__ == '__main__':
136
  port = int(os.environ.get("PORT", 7860))
137
+ # In production debug should be False
138
  app.run(host="0.0.0.0", port=port, debug=True)