Spaces:
Sleeping
Sleeping
| # app.py | |
| import io | |
| import os | |
| import uuid | |
| import threading | |
| import hashlib | |
| from contextvars import ContextVar | |
| from typing import Optional, Dict, Any, List | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from fastapi import FastAPI, UploadFile, File, Query, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| # ============================================================ | |
| # Config | |
| # ============================================================ | |
# Inference device: the GPU when CUDA is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Side length (pixels) of the square image produced by preprocessing.
MODEL_IMG_SIZE = 518

# Comma-separated CORS origins read from the environment ("*" when unset).
ALLOW_ORIGINS = os.environ.get("ALLOW_ORIGINS", "*").split(",")

# RAD-DINO backbone checkpoint on the Hugging Face Hub.
RAD_BACKBONE_REPO_ID = "microsoft/rad-dino"
RAD_BACKBONE_FILENAME = "backbone_compatible.safetensors"

# Classification-head checkpoints; both paths can be overridden via env vars.
RAD_HEAD_CKPT_PATH = os.environ.get("RAD_HEAD_CKPT_PATH", "rad_dino_chestmnist_head.pt")
DINO_HEAD_CKPT_PATH = os.environ.get("DINO_HEAD_CKPT_PATH", "dino_chestmnist_head.pt")

# Per-channel normalization statistics, shaped [3, 1, 1] so they broadcast
# over a [3, H, W] image tensor.
RAD_MEAN = torch.tensor([0.5307] * 3, dtype=torch.float32).reshape(3, 1, 1)
RAD_STD = torch.tensor([0.2583] * 3, dtype=torch.float32).reshape(3, 1, 1)

# DINOv2 usual / ImageNet normalization.
DINO_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape(3, 1, 1)
DINO_STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape(3, 1, 1)

# Label names used when a head checkpoint does not carry its own list.
DEFAULT_LABEL_NAMES = [
    "atelectasis",
    "cardiomegaly",
    "effusion",
    "infiltration",
    "mass",
    "nodule",
    "pneumonia",
    "pneumothorax",
    "consolidation",
    "edema",
    "emphysema",
    "fibrosis",
    "pleural",
    "hernia",
]

# One entry per served model: which backbone to build, which head checkpoint
# to load, the public model name, and the normalization statistics.
MODEL_CONFIGS = {
    "rad-dino": dict(
        backbone_type="rad-dino",
        head_ckpt_path=RAD_HEAD_CKPT_PATH,
        model_name="rad-dino-chestmnist",
        mean=RAD_MEAN,
        std=RAD_STD,
    ),
    "dino": dict(
        backbone_type="dino",
        head_ckpt_path=DINO_HEAD_CKPT_PATH,
        model_name="dino-chestmnist",
        mean=DINO_MEAN,
        std=DINO_STD,
    ),
}
| # ============================================================ | |
| # Model definitions | |
| # ============================================================ | |
class MedicalHead(nn.Module):
    """Classification head: dropout followed by one linear layer.

    Maps an ``in_dim``-wide embedding (the CLS token in this file) to
    ``num_classes`` raw logits; no activation is applied here.
    """

    def __init__(self, in_dim: int = 768, num_classes: int = 14, dropout: float = 0.1):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys.
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
        """Return the logits for ``cls_token``."""
        regularized = self.drop(cls_token)
        logits = self.fc(regularized)
        return logits
def round_tensor(t: torch.Tensor, decimals: int = 4) -> torch.Tensor:
    """Round every element of ``t`` to ``decimals`` decimal places.

    Implemented as scale -> ``torch.round`` -> unscale.
    """
    scale = 10 ** decimals
    scaled = t * scale
    return torch.round(scaled) / scale
def preprocess_pil(pil_img: Image.Image, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Convert a PIL image into a normalized model input batch.

    The image is converted to RGB, resized (bicubic) to
    ``MODEL_IMG_SIZE`` x ``MODEL_IMG_SIZE``, scaled to [0, 1], reordered to
    channel-first, normalized with ``mean``/``std`` and given a leading
    batch dimension.

    Returns:
        Tensor of shape [1, 3, H, W].
    """
    target_size = (MODEL_IMG_SIZE, MODEL_IMG_SIZE)
    rgb = pil_img.convert("RGB")
    resized = rgb.resize(target_size, Image.BICUBIC)

    pixels = np.array(resized).astype("float32") / 255.0  # [H, W, 3]
    chw = torch.from_numpy(pixels).permute(2, 0, 1)  # [3, H, W]
    normalized = (chw - mean) / std
    return normalized[None]  # [1, 3, H, W]
| # ============================================================ | |
| # Build backbones | |
| # ============================================================ | |
def ensure_local_dinov2_repo():
    """Fail fast when the local DINOv2 checkout at ``./dinov2`` is missing.

    Raises:
        FileNotFoundError: if ``./dinov2`` does not exist; the message tells
            the operator how to clone the repository.
    """
    if os.path.exists("./dinov2"):
        return
    raise FileNotFoundError(
        "No encontré ./dinov2. Clona el repo primero con:\n"
        "git clone https://github.com/facebookresearch/dinov2.git"
    )
def disable_fused_attn(model: nn.Module):
    """Set ``fused_attn = False`` on every block's attention module that has it.

    Attention modules without a ``fused_attn`` attribute are left untouched.
    """
    for block in model.blocks:
        attention = block.attn
        if not hasattr(attention, "fused_attn"):
            continue
        attention.fused_attn = False
def build_dinov2_backbone(pretrained: bool = True) -> nn.Module:
    """Build the DINOv2 ViT-B/14 backbone from the local ``./dinov2`` checkout.

    The model is put in eval mode, moved to ``DEVICE`` and has fused attention
    disabled on every block.

    Args:
        pretrained: forwarded to the ``dinov2_vitb14`` hub entrypoint. The
            default (``True``) keeps the previous behaviour of loading the
            published DINOv2 weights; pass ``False`` when the caller is going
            to load a complete state dict itself.

    Raises:
        FileNotFoundError: if ``./dinov2`` does not exist.
    """
    ensure_local_dinov2_repo()
    # torch.hub.load passes extra keyword arguments on to the entrypoint.
    model = torch.hub.load(
        "./dinov2", "dinov2_vitb14", source="local", pretrained=pretrained
    )
    model.eval().to(DEVICE)
    disable_fused_attn(model)
    return model


def build_rad_dino_backbone() -> nn.Module:
    """Build the ViT-B/14 architecture and load the RAD-DINO weights into it.

    The checkpoint is downloaded from the Hugging Face Hub
    (``RAD_BACKBONE_REPO_ID`` / ``RAD_BACKBONE_FILENAME``) and loaded with
    ``strict=True``, so a missing or unexpected key raises.
    """
    # The strict load below overwrites every entry of the state dict, so the
    # default DINOv2 weights would be loaded only to be thrown away. Building
    # with pretrained=False skips that load.
    model = build_dinov2_backbone(pretrained=False)

    backbone_path = hf_hub_download(
        repo_id=RAD_BACKBONE_REPO_ID,
        filename=RAD_BACKBONE_FILENAME,
    )
    state = load_file(backbone_path)
    model.load_state_dict(state, strict=True)

    # build_dinov2_backbone already applied eval(), the device move and
    # disable_fused_attn(); load_state_dict changes none of those.
    return model
def build_head(head_ckpt_path: str) -> tuple[nn.Module, Dict[str, Any], List[str]]:
    """Load a trained ``MedicalHead`` from a checkpoint file.

    The checkpoint is read as a mapping: ``"head_state_dict"`` is required,
    ``"label_names"`` is optional and falls back to ``DEFAULT_LABEL_NAMES``.
    The number of output classes is the number of label names.

    Returns:
        ``(head, checkpoint, label_names)`` with the head on ``DEVICE`` and in
        eval mode.
    """
    # NOTE(review): torch.load deserializes with pickle — confirm the
    # checkpoint path only ever points at trusted files.
    checkpoint = torch.load(head_ckpt_path, map_location=DEVICE)
    label_names = checkpoint.get("label_names", DEFAULT_LABEL_NAMES)

    head = MedicalHead(in_dim=768, num_classes=len(label_names), dropout=0.1)
    head = head.to(DEVICE)
    head.load_state_dict(checkpoint["head_state_dict"])
    head.eval()
    return head, checkpoint, label_names
def build_model_bundle(model_key: str, cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble everything needed to serve one model.

    Builds the backbone selected by ``cfg["backbone_type"]``, loads the head
    from ``cfg["head_ckpt_path"]`` and returns a dict holding both, their
    metadata, an empty "current" result slot, an empty stored-results dict
    and a lock.

    Raises:
        ValueError: if ``cfg["backbone_type"]`` is neither ``"rad-dino"`` nor
            ``"dino"``.
    """
    backbone_type = cfg["backbone_type"]
    if backbone_type == "rad-dino":
        backbone = build_rad_dino_backbone()
    elif backbone_type == "dino":
        backbone = build_dinov2_backbone()
    else:
        raise ValueError(f"backbone_type desconocido: {backbone_type}")

    head, ckpt, label_names = build_head(cfg["head_ckpt_path"])
    first_attn = backbone.blocks[0].attn

    # Most recent result of the "current" endpoints; None until one is stored.
    empty_current = {
        "hash": None,
        "attention_cls_full": None,
        "logit_lens_full": None,
    }

    return {
        "key": model_key,
        "model_name": cfg["model_name"],
        "backbone_type": backbone_type,
        "backbone": backbone,
        "head": head,
        "head_ckpt": ckpt,
        "label_names": label_names,
        "mean": cfg["mean"],
        "std": cfg["std"],
        "num_layers": len(backbone.blocks),
        "num_heads": getattr(first_attn, "num_heads", None),
        "current": empty_current,
        "results": {},
        "lock": threading.Lock(),
    }
| # ============================================================ | |
| # Hook registration per model | |
| # ============================================================ | |
def register_hooks(bundle: Dict[str, Any]):
    """Attach capture hooks to every block of ``bundle["backbone"]``.

    Two hooks are registered per block:

    * a forward pre-hook on ``blk.attn`` that records the attention input;
    * a forward hook on the block itself that records the block output.

    A hook only records while its context variable holds a list; with the
    default value (``None``) it does nothing. The context variables and the
    hook handles are stored on the bundle under ``"_attn_in_var"``,
    ``"_tok_var"``, ``"_attn_hooks"`` and ``"_tok_hooks"``.

    Raises:
        RuntimeError: if a block has no ``attn`` attribute.
    """
    key = bundle["key"]
    attn_inputs_var: ContextVar[Optional[list]] = ContextVar(
        f"_attn_in_var_{key}", default=None
    )
    block_outputs_var: ContextVar[Optional[list]] = ContextVar(
        f"_tok_var_{key}", default=None
    )

    def capture_attn_input(module, inp):
        sink = attn_inputs_var.get()
        # Record the first positional input (commented as [B, N, D] in the
        # original) only when capture is active and it is a tensor.
        if sink is not None and len(inp) != 0 and torch.is_tensor(inp[0]):
            sink.append(inp[0].detach())

    def capture_block_output(module, inp, out):
        sink = block_outputs_var.get()
        if sink is not None and torch.is_tensor(out):
            sink.append(out.detach())

    attn_handles = []
    block_handles = []
    for block in bundle["backbone"].blocks:
        if not hasattr(block, "attn"):
            raise RuntimeError(f"No encontré blk.attn en backbone {key}")
        attn_handles.append(block.attn.register_forward_pre_hook(capture_attn_input))
        block_handles.append(block.register_forward_hook(capture_block_output))

    bundle.update(
        _attn_in_var=attn_inputs_var,
        _tok_var=block_outputs_var,
        _attn_hooks=attn_handles,
        _tok_hooks=block_handles,
    )
| # ============================================================ | |
| # Build all models | |
| # ============================================================ | |
# Build every configured model at import time, keyed like MODEL_CONFIGS.
MODELS: Dict[str, Dict[str, Any]] = {}
for _model_key, _model_cfg in MODEL_CONFIGS.items():
    MODELS[_model_key] = build_model_bundle(_model_key, _model_cfg)

# Hooks are attached only after all bundles exist.
for _model_bundle in MODELS.values():
    register_hooks(_model_bundle)

# Startup summary, one group of lines per model.
for _model_key, _model_bundle in MODELS.items():
    _head_ckpt = _model_bundle["head_ckpt"]
    print(f"[server] model_key={_model_key}")
    print(f"[server] model_name={_model_bundle['model_name']}")
    print(f"[server] backbone_type={_model_bundle['backbone_type']} device={DEVICE}")
    print(f"[server] head_ckpt={MODEL_CONFIGS[_model_key]['head_ckpt_path']}")
    print(f"[server] num_layers={_model_bundle['num_layers']} num_heads={_model_bundle['num_heads']}")
    print(f"[server] num_classes={len(_model_bundle['label_names'])}")
    if "best_val_auc" in _head_ckpt:
        print(f"[server] checkpoint best_val_auc={_head_ckpt['best_val_auc']:.4f}")
| # ============================================================ | |
| # Inference helpers | |
| # ============================================================ | |
def extract_cls(backbone: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Run ``backbone.forward_features`` and return its ``"x_norm_clstoken"`` entry.

    Raises:
        RuntimeError: if the returned features do not contain that key.
    """
    features = backbone.forward_features(images)
    if "x_norm_clstoken" in features:
        return features["x_norm_clstoken"]
    raise RuntimeError("forward_features no devolvió 'x_norm_clstoken'.")
def compute_logit_lens_from_tokens(tokens_per_layer: List[torch.Tensor], head: nn.Module):
    """Apply ``head`` to the first token of every layer's output.

    Args:
        tokens_per_layer: one tensor per layer, indexed as ``x[:, 0]`` to pick
            the first token (commented as [B, N, D] in the original).
        head: callable mapping the selected token to logits.

    Returns:
        ``(logits_per_layer, probs_per_layer)``: detached CPU tensors stacked
        along a new leading layer axis ([L, B, C] per the original comments).
        Probabilities are the element-wise sigmoid of the logits.
    """
    # Logits stay on their original device so the sigmoid runs there too.
    raw_logits = [head(tokens[:, 0]) for tokens in tokens_per_layer]

    logits_per_layer = torch.stack(
        [logits.detach().cpu() for logits in raw_logits], dim=0
    )
    probs_per_layer = torch.stack(
        [torch.sigmoid(logits).detach().cpu() for logits in raw_logits], dim=0
    )
    return logits_per_layer, probs_per_layer
def compute_cls_attention_from_inputs(backbone: nn.Module, attn_inputs: List[torch.Tensor]):
    """Reconstruct CLS->tokens attention per layer from the attention inputs.

    For every (block, input) pair the block's own ``qkv`` projection is
    re-applied to the captured input and the softmax-normalised scores of
    token 0 (the CLS query) against all keys are returned.

    Only the CLS query row is multiplied with the keys. The full
    [B, H, N, N] attention matrix is never materialised, since softmax is
    applied per query row and row 0 does not depend on the others.

    Args:
        backbone: model exposing ``blocks``; each block's ``attn`` must
            provide ``qkv``, ``num_heads`` and ``scale``.
        attn_inputs: one [B, N, D] tensor per layer. Pairs are formed with
            ``zip``, so extra blocks or extra inputs are ignored.

    Returns:
        List of detached CPU tensors of shape [B, H, N], one per layer.
    """
    cls_attn_per_layer = []
    for blk, x in zip(backbone.blocks, attn_inputs):
        attn = blk.attn
        # Run on whichever device this layer's projection weights live on
        # (assumes qkv is a Linear-like module with a `weight` parameter).
        x = x.to(attn.qkv.weight.device)  # [B, N, D]
        B, N, C = x.shape
        num_heads = attn.num_heads
        head_dim = C // num_heads

        qkv = attn.qkv(x)  # [B, N, 3*C]
        qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q_cls = qkv[0][:, :, :1, :]  # [B, H, 1, Hd] — CLS query only
        k = qkv[1]  # [B, H, N, Hd]

        scores = (q_cls @ k.transpose(-2, -1)) * attn.scale  # [B, H, 1, N]
        cls_attn = scores.softmax(dim=-1)[:, :, 0, :].detach().cpu()  # [B, H, N]
        cls_attn_per_layer.append(cls_attn)
    return cls_attn_per_layer
def _export_attention(bundle: Dict[str, Any], attn_maps: List[torch.Tensor]) -> Dict[str, Any]:
    """Turn per-layer CLS attention tensors into the JSON-ready attention payload."""
    attn_all = []
    attn_patches = []
    for layer_attn in attn_maps:
        # Drop the batch axis and round the whole [H, N] layer at once
        # instead of head by head; the values are the same.
        layer_attn = round_tensor(layer_attn.squeeze(0), 4)
        attn_all.append(layer_attn.tolist())
        attn_patches.append(layer_attn[..., 1:].tolist())  # remove CLS->CLS

    return {
        "model": bundle["model_name"],
        "attention_type": "cls_only",
        "num_layers": len(attn_all),
        "num_heads": len(attn_all[0]),
        "num_tokens_all": len(attn_all[0][0]),
        "num_patch_tokens": len(attn_patches[0][0]),
        "cls_index": 0,
        "attention_cls_to_all_tokens": attn_all,
        "attention_cls_to_patches": attn_patches,
    }


def _export_logit_lens(
    bundle: Dict[str, Any],
    logits_by_layer: torch.Tensor,
    probs_by_layer: torch.Tensor,
    probs_final: torch.Tensor,
) -> Dict[str, Any]:
    """Turn per-layer logits/probabilities into the JSON-ready logit-lens payload."""
    return {
        "model": bundle["model_name"],
        "num_layers": int(logits_by_layer.shape[0]),
        "num_classes": int(logits_by_layer.shape[-1]),
        "class_names": bundle["label_names"],
        "checkpoint_best_val_auc": bundle["head_ckpt"].get("best_val_auc", None),
        "final_probs": probs_final.tolist(),
        # Batch item 0 of every layer: [L, B, C] -> [L, C].
        "logits": round_tensor(logits_by_layer[:, 0], 4).tolist(),
        "probs_by_layer": round_tensor(probs_by_layer[:, 0], 6).tolist(),
    }


def analyze_image(bundle: Dict[str, Any], pil_img: Image.Image) -> Dict[str, Any]:
    """Run one image through a model bundle and export its explanations.

    The image is preprocessed with the bundle's normalization, pushed through
    the backbone while the registered hooks capture each block's attention
    input and output, and classified by the head. From the captured tensors
    the function derives per-layer CLS attention and per-layer head outputs
    ("logit lens").

    Returns:
        ``{"attention_cls_full": ..., "logit_lens_full": ...}``, both made of
        plain Python lists and scalars.

    Raises:
        RuntimeError: if the hooks captured no block outputs or no attention
            inputs.

    NOTE(review): the original indentation is not visible in this view, so
    the exact extent of the lock / ``no_grad`` scopes could not be recovered.
    Here every model computation runs under both, and only the conversion to
    Python lists happens outside.
    NOTE(review): the per-layer logits use the raw block outputs, whereas the
    final prediction uses ``x_norm_clstoken`` — confirm this is intended.
    """
    x = preprocess_pil(pil_img, bundle["mean"], bundle["std"]).to(DEVICE)

    attn_inputs = []
    layer_tokens = []
    # Activate capture for this call only; the hooks append to these lists.
    tok_token = bundle["_tok_var"].set(layer_tokens)
    attn_token = bundle["_attn_in_var"].set(attn_inputs)
    try:
        with torch.no_grad(), bundle["lock"]:
            cls_final = extract_cls(bundle["backbone"], x)  # [1, 768]
            logits_final = bundle["head"](cls_final)  # [1, C]
            probs_final = torch.sigmoid(logits_final)[0].detach().cpu()
            probs_final = round_tensor(probs_final, 6)

            if not layer_tokens:
                raise RuntimeError("No se capturaron tokens por capa.")
            if not attn_inputs:
                raise RuntimeError("No se capturaron entradas a atención por capa.")

            logits_by_layer, probs_by_layer = compute_logit_lens_from_tokens(
                layer_tokens, bundle["head"]
            )
            attn_maps = compute_cls_attention_from_inputs(bundle["backbone"], attn_inputs)

        return {
            "attention_cls_full": _export_attention(bundle, attn_maps),
            "logit_lens_full": _export_logit_lens(
                bundle, logits_by_layer, probs_by_layer, probs_final
            ),
        }
    finally:
        # Always deactivate capture and drop the captured tensors.
        bundle["_tok_var"].reset(tok_token)
        bundle["_attn_in_var"].reset(attn_token)
        layer_tokens.clear()
        attn_inputs.clear()
| # ============================================================ | |
| # FastAPI app | |
| # ============================================================ | |
app = FastAPI(title="ChestMNIST Explainer API (RAD-DINO + DINO)", version="2.0")

# Origins come from a comma-separated env var: strip stray whitespace
# ("a, b" would otherwise yield " b", which never matches an Origin header)
# and fall back to the wildcard when nothing usable is configured.
_cors_origins = [origin.strip() for origin in ALLOW_ORIGINS if origin.strip()] or ["*"]
_cors_allow_all = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if _cors_allow_all else _cors_origins,
    # Credentialed requests must not be combined with a wildcard origin;
    # credentials are enabled only when explicit origins are configured.
    allow_credentials=not _cors_allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _no_store(resp: JSONResponse) -> JSONResponse:
    """Set no-cache headers on ``resp`` and return the same response object."""
    no_cache_headers = (
        ("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0"),
        ("Pragma", "no-cache"),
    )
    for header, value in no_cache_headers:
        resp.headers[header] = value
    return resp
def get_model_bundle(model_key: str) -> Dict[str, Any]:
    """Return the bundle registered under ``model_key``.

    Raises:
        HTTPException: 404 listing the available keys when ``model_key`` is
            not in ``MODELS``.
    """
    if model_key in MODELS:
        return MODELS[model_key]
    raise HTTPException(
        status_code=404,
        detail=f"Unknown model_key '{model_key}'. Available: {list(MODELS.keys())}",
    )
| # ============================================================ | |
| # Root / health | |
| # ============================================================ | |
def root():
    """Return a minimal status payload: device, model keys and image size.

    NOTE(review): no route decorator is visible on this handler in this
    view — confirm it is registered on ``app``.
    """
    available = list(MODELS.keys())
    return {
        "status": "ok",
        "device": DEVICE,
        "available_models": available,
        "image_size": MODEL_IMG_SIZE,
    }
def health():
    """Return the service status plus a metadata summary for every model."""

    def summarize(model_bundle):
        # Per-model metadata; "has_current" tells whether a current
        # attention result has been stored for this model.
        labels = model_bundle["label_names"]
        return {
            "model": model_bundle["model_name"],
            "num_layers": model_bundle["num_layers"],
            "num_heads": model_bundle["num_heads"],
            "num_classes": len(labels),
            "class_names": labels,
            "checkpoint_best_val_auc": model_bundle["head_ckpt"].get("best_val_auc", None),
            "has_current": model_bundle["current"]["attention_cls_full"] is not None,
        }

    return {
        "status": "ok",
        "device": DEVICE,
        "available_models": list(MODELS.keys()),
        "models": {
            model_key: summarize(model_bundle)
            for model_key, model_bundle in MODELS.items()
        },
    }
def health_model(model_key: str):
    """Return status and metadata for a single model.

    Raises:
        HTTPException: 404 (from ``get_model_bundle``) for an unknown key.
    """
    bundle = get_model_bundle(model_key)
    labels = bundle["label_names"]
    best_auc = bundle["head_ckpt"].get("best_val_auc", None)
    has_current = bundle["current"]["attention_cls_full"] is not None

    return {
        "status": "ok",
        "device": DEVICE,
        "model_key": model_key,
        "model": bundle["model_name"],
        "image_size": MODEL_IMG_SIZE,
        "num_layers": bundle["num_layers"],
        "num_heads": bundle["num_heads"],
        "num_classes": len(labels),
        "class_names": labels,
        "checkpoint_best_val_auc": best_auc,
        "has_current": has_current,
    }
| # ============================================================ | |
| # Legacy analyze with stored jobs | |
| # ============================================================ | |
# Upper bound on the number of results kept per model by `analyze(store=1)`.
# Each stored result holds the full attention export as nested Python lists,
# so an unbounded dict grows without limit on a long-running server.
_MAX_STORED_RESULTS = 64


async def analyze(
    model_key: str,
    file: UploadFile = File(...),
    store: int = Query(0, description="1 => guarda resultados y entrega endpoints /results/{model_key}/{id}/..."),
):
    """Analyze an uploaded image with the selected model.

    With ``store == 1`` the result is kept in ``bundle["results"]`` under a
    fresh UUID and the response lists the endpoints that serve it; at most
    ``_MAX_STORED_RESULTS`` results are retained per model, oldest evicted
    first. Otherwise the full result is returned inline.

    Raises:
        HTTPException: 404 for an unknown ``model_key``; 400 when the upload
            is not an image or cannot be decoded; 500 when inference fails.
    """
    bundle = get_model_bundle(model_key)

    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")

    raw = await file.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Could not decode image.") from exc

    try:
        out = analyze_image(bundle, img)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {exc}") from exc

    if store != 1:
        return out

    results = bundle["results"]
    # Dicts keep insertion order, so the first key is the oldest job.
    while len(results) >= _MAX_STORED_RESULTS:
        results.pop(next(iter(results)), None)

    job_id = str(uuid.uuid4())
    results[job_id] = out
    return {
        "model_key": model_key,
        "job_id": job_id,
        "endpoints": {
            "attention_cls_full": f"/results/{model_key}/{job_id}/attention_cls_full.json",
            "logit_lens_full": f"/results/{model_key}/{job_id}/logit_lens_full.json",
        },
    }
def get_attention(model_key: str, job_id: str):
    """Serve the stored attention export of ``job_id`` with no-cache headers.

    Raises:
        HTTPException: 404 for an unknown model key or job id.
    """
    stored = get_model_bundle(model_key)["results"].get(job_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="job_id not found")
    response = JSONResponse(stored["attention_cls_full"])
    return _no_store(response)
def get_logit(model_key: str, job_id: str):
    """Serve the stored logit-lens export of ``job_id`` with no-cache headers.

    Raises:
        HTTPException: 404 for an unknown model key or job id.
    """
    stored = get_model_bundle(model_key)["results"].get(job_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="job_id not found")
    response = JSONResponse(stored["logit_lens_full"])
    return _no_store(response)
| # ============================================================ | |
| # Preferred: current endpoints per model | |
| # ============================================================ | |
async def analyze_current(model_key: str, file: UploadFile = File(...)):
    """Analyze an upload and keep it as the model's single "current" result.

    The SHA-256 of the raw bytes identifies the image: when it matches the
    stored hash and a result is already present, nothing is recomputed and
    the status is ``"unchanged"``. Otherwise the image is analyzed and the
    bundle's ``"current"`` slot is overwritten.

    Raises:
        HTTPException: 404 for an unknown ``model_key``; 400 when the upload
            is not an image or cannot be decoded; 500 when inference fails.

    NOTE(review): no route decorator is visible on this handler in this
    view — confirm it is registered on ``app``.
    """
    bundle = get_model_bundle(model_key)
    current = bundle["current"]

    content_type = file.content_type
    if not content_type or not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")

    payload = await file.read()
    digest = hashlib.sha256(payload).hexdigest()

    already_analyzed = (
        current["hash"] == digest and current["attention_cls_full"] is not None
    )
    if already_analyzed:
        return {"status": "unchanged", "hash": digest, "model_key": model_key}

    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Could not decode image.")

    try:
        result = analyze_image(bundle, image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")

    current["hash"] = digest
    current["attention_cls_full"] = result["attention_cls_full"]
    current["logit_lens_full"] = result["logit_lens_full"]
    return {"status": "ok", "hash": digest, "model_key": model_key}
def current_attention(model_key: str):
    """Serve the model's current attention export with no-cache headers.

    Raises:
        HTTPException: 404 for an unknown model key, or when no current
            result has been stored yet.
    """
    attention = get_model_bundle(model_key)["current"]["attention_cls_full"]
    if attention is not None:
        return _no_store(JSONResponse(attention))
    raise HTTPException(
        status_code=404,
        detail=f"No current attention file for '{model_key}'. POST /analyze_current/{model_key} first.",
    )
def current_logit(model_key: str):
    """Serve the model's current logit-lens export with no-cache headers.

    Raises:
        HTTPException: 404 for an unknown model key, or when no current
            result has been stored yet.
    """
    logit_lens = get_model_bundle(model_key)["current"]["logit_lens_full"]
    if logit_lens is not None:
        return _no_store(JSONResponse(logit_lens))
    raise HTTPException(
        status_code=404,
        detail=f"No current logit file for '{model_key}'. POST /analyze_current/{model_key} first.",
    )
| # ============================================================ | |
| # Optional backward-compatible aliases for RAD-DINO | |
| # ============================================================ | |
async def analyze_current_rad_default(file: UploadFile = File(...)):
    """Backward-compatible alias: ``analyze_current`` for the "rad-dino" model."""
    response = await analyze_current(model_key="rad-dino", file=file)
    return response
def current_attention_rad_default():
    """Backward-compatible alias: ``current_attention`` for the "rad-dino" model."""
    return current_attention(model_key="rad-dino")
def current_logit_rad_default():
    """Backward-compatible alias: ``current_logit`` for the "rad-dino" model."""
    return current_logit(model_key="rad-dino")
async def analyze_rad_default(
    file: UploadFile = File(...),
    store: int = Query(0, description="1 => guarda resultados"),
):
    """Backward-compatible alias: ``analyze`` for the "rad-dino" model."""
    response = await analyze(model_key="rad-dino", file=file, store=store)
    return response
| # ============================================================ | |
| # Smoke test | |
| # ============================================================ | |
def smoke_test_local_image(image_path: str, model_key: str = "rad-dino"):
    """Analyze a local image file and print a short summary.

    Prints the number of layers, heads and patch tokens of the attention
    export, followed by the five labels with the highest final probability.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"No existe la imagen: {image_path}")

    bundle = get_model_bundle(model_key)
    image = Image.open(image_path).convert("RGB")
    result = analyze_image(bundle, image)
    attention = result["attention_cls_full"]

    print(f"\n[smoke test] model_key={model_key} OK")
    print("[smoke test] capas:", attention["num_layers"])
    print("[smoke test] heads:", attention["num_heads"])
    print("[smoke test] patch tokens:", attention["num_patch_tokens"])

    # Rank (label, probability) pairs from most to least likely.
    ranked = sorted(
        zip(bundle["label_names"], result["logit_lens_full"]["final_probs"]),
        key=lambda pair: pair[1],
        reverse=True,
    )
    print("\nTop-5 predicciones:")
    for label, prob in ranked[:5]:
        print(f" {label:<15} {prob:.4f}")
if __name__ == "__main__":
    # Optional smoke test before serving, driven by environment variables.
    test_path = os.environ.get("TEST_IMAGE_PATH", "").strip()
    test_model = os.environ.get("TEST_MODEL_KEY", "rad-dino").strip()
    if test_path:
        smoke_test_local_image(test_path, test_model)

    import uvicorn

    # Serve the already-built application object. With the "app:app" import
    # string uvicorn imports this file again as module `app` while it is
    # already loaded as `__main__`, so every model would be built a second
    # time. Passing the object is valid because reload is off.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)