test_on_gpu
Browse files
app.py
CHANGED
|
@@ -67,6 +67,9 @@ try:
|
|
| 67 |
model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
|
| 68 |
state_dict = torch.load(model_file, map_location=device)
|
| 69 |
combined_model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
| 70 |
combined_model.eval()
|
| 71 |
except Exception as e:
|
| 72 |
raise SystemExit(f"Error loading models: {e}")
|
|
@@ -81,48 +84,58 @@ class_names = ["acute", "normal", "chronic", "lacunar"]
|
|
| 81 |
|
| 82 |
@app.post("/predict/")
|
| 83 |
async def predict(files: list[UploadFile]):
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
with torch.no_grad():
|
| 114 |
-
class_outputs, generated_report, _ = combined_model(images_tensor)
|
| 115 |
-
predicted_class = torch.argmax(class_outputs, dim=1).item()
|
| 116 |
-
predicted_class_name = class_names[predicted_class]
|
| 117 |
-
|
| 118 |
-
gc.collect()
|
| 119 |
-
if torch.cuda.is_available():
|
| 120 |
-
torch.cuda.empty_cache()
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
|
| 68 |
state_dict = torch.load(model_file, map_location=device)
|
| 69 |
combined_model.load_state_dict(state_dict)
|
| 70 |
+
|
| 71 |
+
# Move model to device
|
| 72 |
+
combined_model = combined_model.to(device)
|
| 73 |
combined_model.eval()
|
| 74 |
except Exception as e:
|
| 75 |
raise SystemExit(f"Error loading models: {e}")
|
|
|
|
| 84 |
|
| 85 |
@app.post("/predict/")
|
| 86 |
async def predict(files: list[UploadFile]):
|
| 87 |
+
try:
|
| 88 |
+
print(f"Received {len(files)} files")
|
| 89 |
+
n_frames = 16
|
| 90 |
+
images = []
|
| 91 |
+
|
| 92 |
+
for file in files:
|
| 93 |
+
ext = file.filename.split('.')[-1].lower()
|
| 94 |
+
try:
|
| 95 |
+
if ext in ['dcm', 'ima']:
|
| 96 |
+
dicom_img = dicom_to_png(await file.read())
|
| 97 |
+
images.append(dicom_img.convert("RGB"))
|
| 98 |
+
elif ext in ['png', 'jpeg', 'jpg']:
|
| 99 |
+
img = Image.open(io.BytesIO(await file.read())).convert("RGB")
|
| 100 |
+
images.append(img)
|
| 101 |
+
else:
|
| 102 |
+
raise HTTPException(status_code=400, detail="Unsupported file type.")
|
| 103 |
+
except Exception as e:
|
| 104 |
+
raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
|
| 105 |
+
|
| 106 |
+
if not images:
|
| 107 |
+
return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
|
| 108 |
+
|
| 109 |
+
if len(images) >= n_frames:
|
| 110 |
+
images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
|
| 111 |
+
else:
|
| 112 |
+
images_sampled = images + [images[-1]] * (n_frames - len(images))
|
| 113 |
+
|
| 114 |
+
image_tensors = [image_transform(img) for img in images_sampled]
|
| 115 |
+
images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
# Ensure tensor is on the same device as model
|
| 118 |
+
images_tensor = images_tensor.to(device)
|
| 119 |
+
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
class_outputs, generated_report, _ = combined_model(images_tensor)
|
| 122 |
+
predicted_class = torch.argmax(class_outputs, dim=1).item()
|
| 123 |
+
predicted_class_name = class_names[predicted_class]
|
| 124 |
+
|
| 125 |
+
gc.collect()
|
| 126 |
+
if torch.cuda.is_available():
|
| 127 |
+
torch.cuda.empty_cache()
|
| 128 |
+
|
| 129 |
+
return {
|
| 130 |
+
"predicted_class": predicted_class_name,
|
| 131 |
+
"generated_report": generated_report[0] if generated_report else "No report generated."
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
import traceback
|
| 136 |
+
error_details = traceback.format_exc()
|
| 137 |
+
print(f"Error during prediction: {str(e)}\n{error_details}")
|
| 138 |
+
return JSONResponse(
|
| 139 |
+
status_code=500,
|
| 140 |
+
content={"error": f"Error processing request: {str(e)}"}
|
| 141 |
+
)
|