thehammadishaq commited on
Commit
3924835
·
verified ·
1 Parent(s): 4fef962

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -35
app.py CHANGED
@@ -1,7 +1,6 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse, StreamingResponse
4
- from fastapi.staticfiles import StaticFiles
5
  from slowapi import Limiter
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
@@ -17,33 +16,22 @@ import cv2
17
  import io
18
  import uuid
19
  from datetime import datetime, timedelta
20
- from pathlib import Path
21
- import os
22
 
23
  # Configuration
24
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
25
- HEATMAP_EXPIRY = 300 # 5 minutes in seconds
26
  PORT = 7860
27
- HEATMAP_DIR = "/tmp/heatmaps" # Changed to writable /tmp directory
28
 
29
- # Initialize FastAPI
30
  app = FastAPI(
31
  title="ChexNet Medical Imaging API",
32
  description="API for chest X-ray analysis with Grad-CAM visualization",
33
- version="4.1.0"
34
  )
35
 
36
  # Rate limiter setup
37
  limiter = Limiter(key_func=get_remote_address)
38
  app.state.limiter = limiter
39
 
40
- # Create heatmap directory (in /tmp which is writable)
41
- heatmap_dir = Path(HEATMAP_DIR)
42
- heatmap_dir.mkdir(parents=True, exist_ok=True)
43
-
44
- # Mount static files from /tmp
45
- app.mount("/static/heatmap", StaticFiles(directory=HEATMAP_DIR), name="heatmaps")
46
-
47
  # CORS configuration
48
  app.add_middleware(
49
  CORSMiddleware,
@@ -94,16 +82,6 @@ except Exception as e:
94
  print(f"❌ Model loading failed: {e}")
95
  raise
96
 
97
- async def cleanup_old_heatmaps():
98
- now = datetime.now()
99
- for file in heatmap_dir.glob("*.png"):
100
- file_time = datetime.fromtimestamp(file.stat().st_mtime)
101
- if (now - file_time) > timedelta(seconds=HEATMAP_EXPIRY):
102
- try:
103
- file.unlink()
104
- except Exception as e:
105
- print(f"Error deleting {file.name}: {e}")
106
-
107
  def generate_gradcam(img):
108
  img_array = img_to_array(img)
109
  img_array = np.expand_dims(img_array, axis=0)
@@ -149,7 +127,6 @@ def preprocess_image(file_bytes):
149
  @limiter.limit("5/minute")
150
  async def analyze_image(
151
  request: Request,
152
- background_tasks: BackgroundTasks,
153
  file: UploadFile = File(...)
154
  ):
155
  if not file.content_type.startswith('image/'):
@@ -171,16 +148,15 @@ async def analyze_image(
171
 
172
  heatmap = generate_gradcam(img)
173
 
174
- session_id = str(uuid.uuid4())
175
- heatmap_path = heatmap_dir / f"{session_id}.png"
176
- heatmap.save(heatmap_path)
177
-
178
- background_tasks.add_task(cleanup_old_heatmaps)
179
 
180
  return {
181
- "session_id": session_id,
182
  "predictions": decoded[0],
183
- "heatmap_url": f"https://{request.url.hostname}/static/heatmap/{session_id}.png"
 
184
  }
185
  except Exception as e:
186
  raise HTTPException(500, f"Analysis failed: {str(e)}")
@@ -189,8 +165,7 @@ async def analyze_image(
189
  async def health_check():
190
  return {
191
  "status": "healthy",
192
- "timestamp": datetime.now().isoformat(),
193
- "heatmap_files": len(list(heatmap_dir.glob("*.png")))
194
  }
195
 
196
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
 
4
  from slowapi import Limiter
5
  from slowapi.util import get_remote_address
6
  import tensorflow as tf
 
16
  import io
17
  import uuid
18
  from datetime import datetime, timedelta
19
+ import base64
 
20
 
21
  # Configuration
22
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
 
23
  PORT = 7860
 
24
 
 
25
  app = FastAPI(
26
  title="ChexNet Medical Imaging API",
27
  description="API for chest X-ray analysis with Grad-CAM visualization",
28
+ version="5.0.0"
29
  )
30
 
31
  # Rate limiter setup
32
  limiter = Limiter(key_func=get_remote_address)
33
  app.state.limiter = limiter
34
 
 
 
 
 
 
 
 
35
  # CORS configuration
36
  app.add_middleware(
37
  CORSMiddleware,
 
82
  print(f"❌ Model loading failed: {e}")
83
  raise
84
 
 
 
 
 
 
 
 
 
 
 
85
  def generate_gradcam(img):
86
  img_array = img_to_array(img)
87
  img_array = np.expand_dims(img_array, axis=0)
 
127
  @limiter.limit("5/minute")
128
  async def analyze_image(
129
  request: Request,
 
130
  file: UploadFile = File(...)
131
  ):
132
  if not file.content_type.startswith('image/'):
 
148
 
149
  heatmap = generate_gradcam(img)
150
 
151
+ # Convert heatmap to base64 instead of saving to file
152
+ img_byte_arr = io.BytesIO()
153
+ heatmap.save(img_byte_arr, format='PNG')
154
+ heatmap_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
 
155
 
156
  return {
 
157
  "predictions": decoded[0],
158
+ "heatmap_image": heatmap_base64,
159
+ "heatmap_format": "base64 encoded PNG"
160
  }
161
  except Exception as e:
162
  raise HTTPException(500, f"Analysis failed: {str(e)}")
 
165
  async def health_check():
166
  return {
167
  "status": "healthy",
168
+ "timestamp": datetime.now().isoformat()
 
169
  }
170
 
171
  if __name__ == "__main__":