# detect-glocom / main.py — Glaucoma detection API (Hugging Face Space)
# (page metadata preserved as a comment: commit "Update main.py", f823b59, by Stroke-ia)
import io
import os
import threading
import time
from datetime import datetime

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.staticfiles import StaticFiles
from PIL import Image
from torch.nn import functional as F
# -----------------------------
# 1. Environment and config
# -----------------------------
# Redirect every cache/config directory into /tmp — presumably the only
# writable location in the deployment container (TODO confirm).
os.environ["TORCH_HOME"] = "/tmp/torch"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
# Imported only after the env vars are set so any import-time caching lands in /tmp.
from pipline import Transformer_Regression, extract_regions_Last, compute_ratios
MODEL_PATH = "TrainAll_Maghrabi84_50iteration_SWIN.pth.tar"
OUTPUT_DIR = "/tmp/outputs"  # generated result images are served from here
BASE_URL = "https://stroke-ia-detect-glocom.hf.space" # ⚠️ adapt to your own domain
os.makedirs(OUTPUT_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# -----------------------------
# 2. Model initialisation
# -----------------------------
image_shape = 384  # square input side length fed to the network
dim_patch = 4      # patch size for the transformer backbone
scale = 1
# num_classes=3 — presumably background / optic disc / optic cup, given the
# downstream ratio computation; confirm against pipline.py.
DeepLab = Transformer_Regression(
image_dim=image_shape, dim_patch=dim_patch, num_classes=3, scale=scale, feat_dim=128
)
DeepLab.to(device)
# NOTE(review): torch.load unpickles arbitrary objects — acceptable only because
# MODEL_PATH is a bundled, trusted checkpoint; never load untrusted files here.
DeepLab.load_state_dict(torch.load(MODEL_PATH, map_location=device))
DeepLab.eval()
# -----------------------------
# 3. Preprocessing
# -----------------------------
# Resize to the model input size, convert to tensor, then normalise to
# [-1, 1] (mean 0.5, std 0.5 — undone later via `x * 0.5 + 0.5` for display).
tfms = transforms.Compose([
transforms.Resize((image_shape, image_shape)),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
])
# -----------------------------
# 4. Inference
# -----------------------------
def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2,
                                            vcdr_threshold=0.6):
    """Segment optic disc/cup on one fundus image and assess glaucoma risk.

    Runs a coarse full-image pass, optionally refines the detected disc
    region with the model's auxiliary head, draws contours on a display
    image, and computes the vertical cup-to-disc ratio (vCDR).

    Args:
        Model: segmentation network whose forward returns a dict with
            'seg' logits (and 'seg_aux_1' when the refinement head is used).
        batch_sampler: dict with 'image' (normalised CHW tensor) and
            'image_original' (full-resolution RGB image/array).
        num_head: when 2 (default), a second forward pass on the cropped
            disc region refines the segmentation inside that crop.
        vcdr_threshold: vCDR value at or above which the diagnosis string
            flags a possible glaucoma risk. Default 0.6 preserves the
            previously hard-coded cutoff (backward compatible).

    Returns:
        tuple: (contour-annotated RGB uint8 image, vCDR value,
        diagnosis string, crop-info dict from extract_regions_Last).
    """
    Model.eval()
    with torch.no_grad():
        train_batch_tfms = batch_sampler['image'].to(device)
        ytrue_seg = batch_sampler['image_original']
        # Coarse pass on the whole image, upsampled back to the original size.
        scores = Model(train_batch_tfms.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'],
                                  size=(ytrue_seg.shape[0], ytrue_seg.shape[1]),
                                  mode='bilinear', align_corners=True)
        # Crop the region around the predicted disc for refinement.
        Regions_crop = extract_regions_Last(np.array(batch_sampler['image_original']),
                                            yseg_pred.argmax(1).long()[0].cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
        if num_head == 2:
            # Refinement pass: re-run the model on the crop and paste the
            # auxiliary-head prediction back into the full-size logit map.
            scores = Model((tfms(Regions_crop['image']).unsqueeze(0)).to(device))
            yseg_pred_crop = F.interpolate(scores['seg_aux_1'],
                                           size=(Regions_crop['image'].size[1],
                                                 Regions_crop['image'].size[0]),
                                           mode='bilinear', align_corners=True)
            yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
                      Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop
        yseg_pred = torch.softmax(yseg_pred, dim=1)
        yseg_pred = yseg_pred.argmax(1).long().cpu().numpy()
        ratios = compute_ratios(yseg_pred[0])
        # Rebuild a display image from the normalised tensor (undo mean/std 0.5).
        p_img = batch_sampler['image'].to(device).unsqueeze(0)
        p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
                              mode='bilinear', align_corners=True)
        image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).cpu().numpy()
        image_orig = np.uint8(image_orig * 255)
        image_cont = image_orig.copy()
        # Contours: green where label > 1, red where label > 0 — presumably
        # cup and disc respectively; confirm label semantics with pipline.py.
        _, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, 0)
        conts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)
        _, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, 0)
        conts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)
        if ratios.vcdr < vcdr_threshold:
            glaucoma = "None"
        else:
            glaucoma = "May be there is a risk of Glaucoma"
    return image_cont, ratios.vcdr, glaucoma, Regions_crop
# -----------------------------
# 5. FastAPI app
# -----------------------------
app = FastAPI(title="Glaucoma Detection API")
# Serve generated result images so the JSON response can link back to them.
app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
@app.post("/predict/")
async def predict(image_file: UploadFile = File(...)):
    """Segment an uploaded fundus image and return glaucoma-risk results.

    The upload is decoded entirely in memory. The previous version wrote it
    to f"/tmp/{image_file.filename}", which both left a stale file behind on
    any error and let a crafted client filename escape /tmp (path traversal).

    Returns:
        dict with the rounded vCDR, a diagnosis string, URLs to the
        contour-overlay and zoomed-crop images, and a status message.
    """
    # Client filenames are untrusted — never use them to build filesystem paths.
    raw = await image_file.read()
    img = np.array(Image.open(io.BytesIO(raw)).convert("RGB"))
    sample_batch = {
        "image_original": img,
        "image": tfms(Image.fromarray(img)),
    }
    result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2
    )
    # Zoomed view of the detected optic-disc region.
    cropped_img = result[cropped['cord'][0]:cropped['cord'][1],
                         cropped['cord'][2]:cropped['cord'][3]]
    # Microsecond component (%f) keeps same-second concurrent requests from
    # overwriting each other's output files.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    out_img_name = f"glaucoma_result_{timestamp}.png"
    out_zoom_name = f"glaucoma_zoom_{timestamp}.png"
    out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
    out_zoom_path = os.path.join(OUTPUT_DIR, out_zoom_name)
    # OpenCV writes BGR, so convert from the RGB arrays produced above.
    cv2.imwrite(out_img_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
    cv2.imwrite(out_zoom_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR))
    return {
        "ratio": round(float(ratio), 3),
        "diagnosis": diagnosis,
        "overlay_url": f"{BASE_URL}/files/{out_img_name}",
        "zoom_url": f"{BASE_URL}/files/{out_zoom_name}",
        "message": "✅ Glaucoma analysis complete"
    }
# -----------------------------
# 6. Auto-cleanup (every 10 minutes)
# -----------------------------
def auto_cleanup(interval_minutes=10):
    """Periodically purge generated result files from OUTPUT_DIR.

    Loops forever: sleeps for `interval_minutes`, then deletes every regular
    file currently present in the output directory, logging each removal.
    Deletion failures are logged and skipped rather than killing the thread.
    """
    delay_seconds = interval_minutes * 60
    while True:
        time.sleep(delay_seconds)
        for entry in os.listdir(OUTPUT_DIR):
            path = os.path.join(OUTPUT_DIR, entry)
            try:
                if not os.path.isfile(path):
                    continue  # skip subdirectories and other non-files
                os.remove(path)
                print(f"[CLEANUP] Removed {path}")
            except Exception as e:
                print(f"[CLEANUP] Error removing {path}: {e}")


# Daemon thread: dies with the process, never blocks server shutdown.
threading.Thread(target=auto_cleanup, daemon=True).start()