KYAGABA committed on
Commit
e072d50
·
verified ·
1 Parent(s): 3ccfada

test_on_gpu

Browse files
Files changed (1) hide show
  1. app.py +57 -44
app.py CHANGED
@@ -67,6 +67,9 @@ try:
67
  model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
68
  state_dict = torch.load(model_file, map_location=device)
69
  combined_model.load_state_dict(state_dict)
 
 
 
70
  combined_model.eval()
71
  except Exception as e:
72
  raise SystemExit(f"Error loading models: {e}")
@@ -81,48 +84,58 @@ class_names = ["acute", "normal", "chronic", "lacunar"]
81
 
82
  @app.post("/predict/")
83
  async def predict(files: list[UploadFile]):
84
- print(f"Received {len(files)} files")
85
- n_frames = 16
86
- images = []
87
-
88
- for file in files:
89
- ext = file.filename.split('.')[-1].lower()
90
- try:
91
- if ext in ['dcm', 'ima']:
92
- dicom_img = dicom_to_png(await file.read())
93
- images.append(dicom_img.convert("RGB"))
94
- elif ext in ['png', 'jpeg', 'jpg']:
95
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
96
- images.append(img)
97
- else:
98
- raise HTTPException(status_code=400, detail="Unsupported file type.")
99
- except Exception as e:
100
- raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
101
-
102
- if not images:
103
- return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
104
-
105
- if len(images) >= n_frames:
106
- images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
107
- else:
108
- images_sampled = images + [images[-1]] * (n_frames - len(images))
109
-
110
- image_tensors = [image_transform(img) for img in images_sampled]
111
- images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
112
-
113
- with torch.no_grad():
114
- class_outputs, generated_report, _ = combined_model(images_tensor)
115
- predicted_class = torch.argmax(class_outputs, dim=1).item()
116
- predicted_class_name = class_names[predicted_class]
117
-
118
- gc.collect()
119
- if torch.cuda.is_available():
120
- torch.cuda.empty_cache()
121
 
122
- return {
123
- "predicted_class": predicted_class_name,
124
- "generated_report": generated_report[0] if generated_report else "No report generated."
125
- }
126
-
127
- if __name__ == "__main__":
128
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
68
  state_dict = torch.load(model_file, map_location=device)
69
  combined_model.load_state_dict(state_dict)
70
+
71
+ # Move model to device
72
+ combined_model = combined_model.to(device)
73
  combined_model.eval()
74
  except Exception as e:
75
  raise SystemExit(f"Error loading models: {e}")
 
84
 
85
@app.post("/predict/")
async def predict(files: list[UploadFile]):
    """Classify an uploaded image series and generate a report.

    Accepts DICOM (.dcm/.ima) or standard image (.png/.jpeg/.jpg) uploads,
    resamples or pads them to a fixed 16-frame stack, and runs the combined
    multimodal model on the assembled tensor.

    Returns:
        dict with "predicted_class" and "generated_report" on success.
        JSONResponse(400) when no valid images were provided.
    Raises:
        HTTPException: 400 for an unsupported file extension, 500 for a
            failure while decoding an individual file.
    """
    try:
        print(f"Received {len(files)} files")
        n_frames = 16  # model expects a fixed-length frame stack
        images = []

        for file in files:
            ext = file.filename.split('.')[-1].lower()
            try:
                if ext in ['dcm', 'ima']:
                    dicom_img = dicom_to_png(await file.read())
                    images.append(dicom_img.convert("RGB"))
                elif ext in ['png', 'jpeg', 'jpg']:
                    img = Image.open(io.BytesIO(await file.read())).convert("RGB")
                    images.append(img)
                else:
                    raise HTTPException(status_code=400, detail="Unsupported file type.")
            except HTTPException:
                # Preserve the intended status code (e.g. the 400 above);
                # without this, the generic handler below rewraps it as 500.
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")

        if not images:
            return JSONResponse(content={"error": "No valid images provided."}, status_code=400)

        # Uniformly sample down to n_frames, or pad by repeating the last frame.
        if len(images) >= n_frames:
            images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
        else:
            images_sampled = images + [images[-1]] * (n_frames - len(images))

        image_tensors = [image_transform(img) for img in images_sampled]
        # (frames, C, H, W) -> (C, frames, H, W), then add a batch dimension.
        images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0)

        # Ensure tensor is on the same device as model
        images_tensor = images_tensor.to(device)

        with torch.no_grad():
            class_outputs, generated_report, _ = combined_model(images_tensor)
            predicted_class = torch.argmax(class_outputs, dim=1).item()
            predicted_class_name = class_names[predicted_class]

        # Free transient memory between requests.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return {
            "predicted_class": predicted_class_name,
            "generated_report": generated_report[0] if generated_report else "No report generated."
        }

    except HTTPException:
        # Bug fix: HTTPException subclasses Exception, so the blanket handler
        # below was converting deliberate 4xx responses into generic 500s.
        # Re-raise so FastAPI emits the original status code and detail.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error during prediction: {str(e)}\n{error_details}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Error processing request: {str(e)}"}
        )