Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import time | |
| import torch | |
| import numpy as np | |
| import threading | |
| from PIL import Image | |
| from datetime import datetime | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.staticfiles import StaticFiles | |
| import torchvision.transforms as transforms | |
| from torch.nn import functional as F | |
# -----------------------------
# 1. Environment and configuration
# -----------------------------
# Point every library cache at /tmp (presumably the writable location on the
# hosting Space — TODO confirm). These assignments are placed before the
# `pipline` import on purpose, in case that module (or what it imports) reads
# them at import time.
os.environ["TORCH_HOME"] = "/tmp/torch"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
from pipline import Transformer_Regression, extract_regions_Last, compute_ratios  # noqa: E402

# Segmentation checkpoint loaded into the model below.
MODEL_PATH = "TrainAll_Maghrabi84_50iteration_SWIN.pth.tar"
# Result images are written here and served under /files.
OUTPUT_DIR = "/tmp/outputs"
# Public origin used to build the result URLs returned by the API. Set the
# BASE_URL environment variable when deploying under another domain instead
# of editing this file; a trailing slash in the override is tolerated.
BASE_URL = os.environ.get("BASE_URL", "https://stroke-ia-detect-glocom.hf.space").rstrip("/")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# -----------------------------
# 2. Model initialisation
# -----------------------------
# Architecture hyper-parameters passed to Transformer_Regression. They
# presumably have to match the training configuration for the checkpoint's
# state dict to load (TODO confirm). `image_shape` is also the side length
# the preprocessing transform resizes to.
image_shape, dim_patch, scale = 384, 4, 1

DeepLab = Transformer_Regression(
    image_dim=image_shape,
    dim_patch=dim_patch,
    num_classes=3,  # number of segmentation labels produced per pixel
    scale=scale,
    feat_dim=128,
)
DeepLab.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only host.
DeepLab.load_state_dict(torch.load(MODEL_PATH, map_location=device))
DeepLab.eval()  # inference only
# -----------------------------
# 3. Preprocessing
# -----------------------------
# Network input pipeline for a PIL image: resize to the square model
# resolution, convert to a float CHW tensor in [0, 1], then map every channel
# to [-1, 1] via (x - 0.5) / 0.5.
_resize_step = transforms.Resize((image_shape, image_shape))
_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(0.5, 0.5)
tfms = transforms.Compose([_resize_step, _tensor_step, _normalize_step])
# -----------------------------
# 4. Inference
# -----------------------------
def _refine_crop_logits(Model, yseg_pred, Regions_crop):
    """Overwrite the crop window of ``yseg_pred`` in place with auxiliary-head logits.

    The cropped RGB image is run through the model again. Its ``'seg_aux_1'``
    output, resized to the crop's pixel size, replaces the first-pass logits
    inside ``Regions_crop['cord']``, read as (row start, row end, col start,
    col end) of the NCHW logits tensor.
    """
    crop_img = Regions_crop['image']
    crop_width, crop_height = crop_img.size  # PIL reports (width, height)
    crop_scores = Model(tfms(crop_img).unsqueeze(0).to(device))
    crop_logits = F.interpolate(crop_scores['seg_aux_1'],
                                size=(crop_height, crop_width),
                                mode='bilinear', align_corners=True)
    cord = Regions_crop['cord']
    yseg_pred[:, :, cord[0]:cord[1], cord[2]:cord[3]] = crop_logits


def _denormalized_image(image_tensor, size):
    """Upsample a normalised CHW tensor to ``size`` (H, W) and return it as HxWxC uint8.

    Inverts the ``Normalize(0.5, 0.5)`` step of ``tfms`` (maps [-1, 1] back to
    [0, 255]).
    """
    upsampled = F.interpolate(image_tensor.to(device).unsqueeze(0), size=size,
                              mode='bilinear', align_corners=True)
    pixels = (upsampled[0] * 0.5 + 0.5).permute(1, 2, 0).cpu().numpy()
    return np.uint8(pixels * 255)


def _draw_label_contours(image, label_map):
    """Draw the external contours of an (H, W) label map onto ``image`` in place.

    First, green (0, 255, 0) around pixels with label > 1. Then (0, 0, 255),
    which is blue in an RGB image, around pixels with label > 0. These are
    presumably the cup and the whole disc respectively (TODO confirm against
    the training labels).
    """
    labels = np.uint8(label_map)
    for above, colour in ((1, (0, 255, 0)), (0, (0, 0, 255))):
        # THRESH_BINARY: label > above -> 2, otherwise 0; findContours treats
        # any non-zero pixel as foreground.
        _, mask = cv2.threshold(labels, above, 2, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image, contours, -1, colour, 2)


def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2, vcdr_threshold=0.6):
    """Segment one image, compute its ratios and render a contour overlay.

    Args:
        Model: segmentation network whose output is a dict holding a
            ``'seg'`` logits tensor, plus ``'seg_aux_1'`` when ``num_head == 2``.
        batch_sampler: dict with ``'image'``, the preprocessed CHW tensor
            (presumably produced by ``tfms``), and ``'image_original'``, the
            original image as an array whose first two dims are (H, W).
        num_head: when 2, the region found by the first pass is re-segmented
            with the auxiliary head and pasted back; any other value skips
            that refinement.
        vcdr_threshold: ``vcdr`` values at or above this are reported as a
            possible glaucoma risk.

    Returns:
        Tuple ``(overlay, vcdr, diagnosis, Regions_crop)``:

        - ``overlay``: HxWxC uint8 image at the original resolution, with
          contours drawn on it.
        - ``vcdr``: ``compute_ratios(...).vcdr``, presumably the vertical
          cup-to-disc ratio.
        - ``diagnosis``: the diagnosis text.
        - ``Regions_crop``: the dict from ``extract_regions_Last``, with its
          ``'image'`` converted to a PIL RGB image.
    """
    Model.eval()
    with torch.no_grad():
        net_input = batch_sampler['image'].to(device)
        original = batch_sampler['image_original']
        height, width = original.shape[0], original.shape[1]

        # First pass: full-image logits resized to the original resolution.
        scores = Model(net_input.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'], size=(height, width),
                                  mode='bilinear', align_corners=True)
        Regions_crop = extract_regions_Last(np.array(original),
                                            yseg_pred.argmax(1).long()[0].cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')

        if num_head == 2:
            _refine_crop_logits(Model, yseg_pred, Regions_crop)

        # Per-pixel labels, shape (1, H, W).
        label_map = torch.softmax(yseg_pred, dim=1).argmax(1).long().cpu().numpy()
        ratios = compute_ratios(label_map[0])

        # NOTE(review): the overlay is drawn on the resized network input
        # upsampled back to (H, W), not on image_original, so it is blurrier
        # than the upload.
        overlay = _denormalized_image(batch_sampler['image'],
                                      (label_map.shape[1], label_map.shape[2]))
        _draw_label_contours(overlay, label_map[0])

    if ratios.vcdr < vcdr_threshold:
        glaucoma = "None"
    else:
        glaucoma = "May be there is a risk of Glaucoma"
    return overlay, ratios.vcdr, glaucoma, Regions_crop
# -----------------------------
# 5. FastAPI app
# -----------------------------
app = FastAPI(title="Glaucoma Detection API")
# Expose every file written to OUTPUT_DIR at /files/<name>; the result URLs
# returned by the prediction endpoint point into this mount.
_output_files = StaticFiles(directory=OUTPUT_DIR)
app.mount(path="/files", app=_output_files, name="files")
# NOTE(review): path inferred from the function name; confirm it matches what clients call.
@app.post("/predict")
async def predict(image_file: UploadFile = File(...)):
    """Analyse an uploaded fundus image for glaucoma risk.

    The upload is decoded in memory and passed to
    ``Final_Compute_regression_results_Sample``. Two PNGs are written to
    OUTPUT_DIR: the full contour overlay and the crop window reported by the
    pipeline.

    Args:
        image_file: uploaded image in any format PIL can open.

    Returns:
        Dict containing:

        - the ratio, rounded to 3 decimals;
        - the diagnosis text;
        - the public URLs of both PNGs;
        - a status message.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 if a result image cannot be written.
    """
    import io
    import uuid

    from fastapi import HTTPException

    # Decode from memory rather than saving under the client-supplied filename.
    # That name is untrusted (it may contain "../" path components), can
    # collide between concurrent uploads, and a temp file would be left behind
    # whenever processing raised before it was removed.
    data = await image_file.read()
    try:
        img = np.array(Image.open(io.BytesIO(data)).convert("RGB"))
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from err

    sample_batch = {
        "image_original": img,
        "image": tfms(Image.fromarray(img)),
    }
    # NOTE(review): inference runs synchronously inside this async handler,
    # blocking the event loop (including /files) until it finishes.
    result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2
    )
    cord = cropped['cord']
    cropped_img = result[cord[0]:cord[1], cord[2]:cord[3]]

    # The random suffix keeps names unique when several requests land in the
    # same second, and makes the public /files URLs unguessable. The timestamp
    # prefix keeps the names sortable.
    stem = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex}"
    out_img_name = f"glaucoma_result_{stem}.png"
    out_zoom_name = f"glaucoma_zoom_{stem}.png"
    for name, rgb in ((out_img_name, result), (out_zoom_name, cropped_img)):
        # The overlay is RGB while OpenCV writes BGR; imwrite reports failure
        # by returning False instead of raising.
        if not cv2.imwrite(os.path.join(OUTPUT_DIR, name), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
            raise HTTPException(status_code=500, detail="Could not save the result image.")

    return {
        "ratio": round(float(ratio), 3),
        "diagnosis": diagnosis,
        "overlay_url": f"{BASE_URL}/files/{out_img_name}",
        "zoom_url": f"{BASE_URL}/files/{out_zoom_name}",
        "message": "✅ Glaucoma analysis complete"
    }
# -----------------------------
# 6. Auto-cleanup (every 10 min)
# -----------------------------
def cleanup_old_files(directory, max_age_seconds, now=None):
    """Delete regular files in ``directory`` last modified at least ``max_age_seconds`` ago.

    Sub-directories are neither removed nor recursed into. A file that cannot
    be inspected or removed is reported and skipped, so one bad entry does not
    stop the sweep.

    Args:
        directory: folder to sweep.
        max_age_seconds: minimum age, by modification time, for deletion.
        now: reference time in seconds since the epoch; defaults to ``time.time()``.

    Returns:
        List of removed paths, in ``os.listdir`` order.

    Raises:
        OSError: if ``directory`` itself cannot be listed.
    """
    now = time.time() if now is None else now
    removed = []
    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)
        try:
            if not os.path.isfile(path) or now - os.path.getmtime(path) < max_age_seconds:
                continue
            os.remove(path)
        except OSError as e:
            print(f"[CLEANUP] Error removing {path}: {e}")
            continue
        removed.append(path)
        print(f"[CLEANUP] Removed {path}")
    return removed


def auto_cleanup(interval_minutes=10, max_age_minutes=None):
    """Every ``interval_minutes``, delete files in OUTPUT_DIR older than ``max_age_minutes``.

    Loops forever; intended to run in a daemon thread. ``max_age_minutes``
    defaults to ``interval_minutes``, so every result stays downloadable for
    at least one full interval. Deleting everything on each tick could remove
    an image seconds after its URL was handed to a client.
    """
    if max_age_minutes is None:
        max_age_minutes = interval_minutes
    while True:
        time.sleep(interval_minutes * 60)
        try:
            cleanup_old_files(OUTPUT_DIR, max_age_minutes * 60)
        except OSError as e:
            # e.g. OUTPUT_DIR vanished: report and retry on the next tick
            # instead of letting the exception end the cleanup thread.
            print(f"[CLEANUP] Error listing {OUTPUT_DIR}: {e}")


threading.Thread(target=auto_cleanup, daemon=True).start()