Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import io | |
| import os | |
| from threading import Thread | |
| from typing import Optional | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| from huggingface_hub import HfApi, create_repo, upload_file | |
| from train import fit | |
| from model_loader import build_classifier | |
| from dataset import NORM_MEAN, NORM_STD | |
| from utils import unzip_dataset, clean_dir | |
# FastAPI application instance for this service.
app = FastAPI(title="RETFound MAE – Train & Inference API")

# ---------- Config (safe paths for HF Spaces) ----------
# Each value below can be overridden through the environment variable named in
# its lookup; the fallbacks keep data and checkpoints under /tmp.
DATA_ROOT = os.environ.get("DATA_ROOT", "/tmp/data")
CKPT_DIR = os.environ.get("CKPT_DIR", "/tmp/checkpoints")
BASE_REPO = os.environ.get("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
BASE_FILE = os.environ.get("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
MODEL_PUSH_REPO = os.environ.get("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")

# Use the GPU whenever torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Runtime state shared by the functions in this module:
#   training  - True while a background training run is active
#   best_ckpt - checkpoint path returned by the last completed fit() call
#   classes   - class list used to label predictions
#   val_acc   - best validation accuracy returned by fit()
_state = dict(
    training=False,
    best_ckpt=None,
    classes=None,
    val_acc=None,
)

# Cached inference model; stays None until _load_model_for_inference() runs.
_model = None

# Preprocessing applied to an image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize with the dataset stats.
_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)
def _load_model_for_inference():
    """Build the inference model and cache it in the module-level ``_model``.

    When ``_state["best_ckpt"]`` names an existing file, that checkpoint is
    loaded, the resolved class list is written to ``_state["classes"]``, and a
    classifier built with ``num_classes=len(classes)`` receives the
    checkpoint's ``"model"`` state dict.  Otherwise a classifier is built with
    ``num_classes=2`` and ``_state["classes"]`` is left as it was.

    In both cases the model is switched to eval mode before it is published.
    """
    global _model
    ckpt_path = _state["best_ckpt"]
    if ckpt_path and os.path.exists(ckpt_path):
        # NOTE(review): torch.load unpickles the file, so this path must only
        # ever point at checkpoints this service produced itself.
        ckpt = torch.load(ckpt_path, map_location=DEVICE)

        # Prefer the class list stored in the checkpoint.  If the checkpoint
        # carries none, keep the list already recorded in _state instead of
        # overwriting it with an empty one (which would also build a
        # zero-class classifier).
        classes = ckpt.get("classes")
        if classes is None or len(classes) == 0:
            known = _state["classes"]
            classes = known if known is not None else []
        _state["classes"] = classes

        model = build_classifier(
            num_classes=len(classes),
            base_repo=BASE_REPO,
            base_filename=BASE_FILE,
            device=DEVICE,
        )
        # strict=False: missing or unexpected keys in the state dict are
        # tolerated rather than raising.
        model.load_state_dict(ckpt["model"], strict=False)
    else:
        model = build_classifier(
            num_classes=2,
            base_repo=BASE_REPO,
            base_filename=BASE_FILE,
            device=DEVICE,
        )
    model.eval()
    _model = model
def status():
    """Return the current runtime state plus the compute device in use.

    The result holds the ``training``, ``best_ckpt``, ``classes`` and
    ``val_acc`` entries of ``_state`` followed by ``device``.
    """
    report = {
        key: _state[key]
        for key in ("training", "best_ckpt", "classes", "val_acc")
    }
    report["device"] = DEVICE
    return report
async def upload_dataset(file: UploadFile = File(...)):
    """Upload a ZIP that contains train/ and val/ folders.

    The archive is saved under ``/tmp/uploads``, ``DATA_ROOT`` is cleared with
    ``clean_dir`` and the archive is unpacked into it with ``unzip_dataset``.

    Args:
        file: The uploaded ZIP archive.

    Returns:
        ``{"ok": True, "dataset_dir": <value returned by unzip_dataset>}``.
    """
    upload_dir = "/tmp/uploads"
    os.makedirs(upload_dir, exist_ok=True)

    # The filename is supplied by the client.  Keep only its final path
    # component so a name such as "../../x" cannot escape the upload directory.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name)
    if safe_name in ("", ".", ".."):
        safe_name = "dataset.zip"
    zip_path = os.path.join(upload_dir, safe_name)

    # Copy in chunks so a large archive is never held in memory all at once.
    with open(zip_path, "wb") as f:
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)

    # NOTE(review): nothing here checks _state["training"]; replacing the
    # dataset while a run is active may disturb it - confirm this is intended.
    clean_dir(DATA_ROOT)  # DATA_ROOT defaults to /tmp/data
    extracted = unzip_dataset(zip_path, DATA_ROOT)
    return {"ok": True, "dataset_dir": extracted}
async def start_train(epochs: int = Form(10), batch_size: int = Form(16), lr: float = Form(5e-4), freeze_backbone: bool = Form(True)):
    """Start a training run in a background thread.

    Args:
        epochs: Number of epochs passed to ``fit``.
        batch_size: Batch size passed to ``fit``.
        lr: Learning rate passed to ``fit``.
        freeze_backbone: Forwarded to ``fit`` as ``freeze_backbone``.

    Returns:
        A confirmation dict echoing the parameters, or a 409 ``JSONResponse``
        when a run is already in progress.
    """
    if _state["training"]:
        return JSONResponse({"error": "Training already in progress"}, status_code=409)

    # Claim the flag before the worker thread exists.  Setting it only inside
    # the thread leaves a window in which a second request still sees
    # training == False and launches a concurrent run.
    _state["training"] = True

    def _run():
        try:
            ckpt_path, classes, best_acc = fit(
                data_root=DATA_ROOT,
                base_repo=BASE_REPO,
                base_filename=BASE_FILE,
                epochs=int(epochs),
                batch_size=int(batch_size),
                lr=float(lr),
                freeze_backbone=bool(freeze_backbone),
                out_dir=CKPT_DIR,
                device=DEVICE,
            )
            _state.update({
                "best_ckpt": ckpt_path,
                "classes": classes,
                "val_acc": best_acc,
            })
            # Rebuild the cached inference model from the new checkpoint.
            _load_model_for_inference()
        finally:
            # Always release the flag, even when fit() raises.
            _state["training"] = False

    try:
        # daemon=True so the worker does not block interpreter shutdown.
        Thread(target=_run, daemon=True).start()
    except RuntimeError:
        # The thread could not be started, so _run's finally will never run;
        # release the flag here to avoid blocking every later request.
        _state["training"] = False
        raise

    return {
        "ok": True,
        "message": "Training started",
        "params": {
            "epochs": epochs,
            "batch_size": batch_size,
            "lr": lr,
            "freeze_backbone": freeze_backbone,
        },
    }
async def predict(file: UploadFile = File(...)):
    """Classify one uploaded image with the cached model.

    Args:
        file: The uploaded image file.

    Returns:
        ``{"prediction": <class>, "probabilities": {<class>: <float>, ...}}``
        on success.  A 400 ``JSONResponse`` is returned when no class list is
        available yet or when the upload cannot be decoded as an image.
    """
    if _model is None:
        _load_model_for_inference()

    # Take one reference to each shared object so the whole request works on
    # the same model/class list even if a training run replaces them meanwhile.
    model = _model
    classes = _state["classes"]
    if classes is None:
        return JSONResponse({"error": "Model not trained yet. Upload dataset and call /train first."}, status_code=400)

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError:
        # PIL reports unidentifiable or truncated image data as OSError
        # (UnidentifiedImageError is a subclass); that is a client error.
        return JSONResponse({"error": "Uploaded file is not a readable image."}, status_code=400)

    # unsqueeze(0) adds the leading batch dimension for a single image.
    x = _transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        scores = torch.softmax(logits, dim=1)[0]

    # Pick the winner on the tensor itself, then convert once for the response.
    idx = int(torch.argmax(scores).item())
    probs = scores.cpu().tolist()
    return {
        "prediction": classes[idx],
        "probabilities": {cls: float(probs[i]) for i, cls in enumerate(classes)},
    }
async def push_to_hub(repo_id: Optional[str] = Form(None)):
    """Push best checkpoint + metadata to Hugging Face Hub.

    Args:
        repo_id: Target model repository; falls back to ``MODEL_PUSH_REPO``
            when omitted or empty.

    Returns:
        ``{"ok": True, "pushed_to": repo_id}`` on success, or a 400
        ``JSONResponse`` when no repository is configured or no checkpoint
        file exists.
    """
    repo_id = repo_id or MODEL_PUSH_REPO
    if not repo_id:
        return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)

    best_ckpt = _state["best_ckpt"]
    if not best_ckpt or not os.path.exists(best_ckpt):
        return JSONResponse({"error": "No trained checkpoint to push."}, status_code=400)

    # exist_ok=True makes the call a no-op when the repository already exists.
    create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    upload_file(
        path_or_fileobj=best_ckpt,
        path_in_repo="retfound_classifier_best.pth",
        repo_id=repo_id,
        repo_type="model",
    )

    card = (
        "# RETFound MAE – Retinal Classifier (Fine-tuned)\n"
        f"- Base: `{BASE_REPO}/{BASE_FILE}`\n"
        f"- Classes: `{_state['classes']}`\n"
        f"- Best val acc: `{_state['val_acc']}`\n"
        "## Inference\n"
        "This repo contains a PyTorch checkpoint `retfound_classifier_best.pth` compatible with RETFound MAE backbone.\n"
    )
    card_path = "/tmp/MODEL_CARD.md"
    # The card contains non-ASCII text (an en dash), so write it as UTF-8
    # explicitly instead of relying on the locale's default encoding.
    with open(card_path, "w", encoding="utf-8") as f:
        f.write(card)
    # The card is uploaded as the repository's README.md.
    upload_file(
        path_or_fileobj=card_path,
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="model",
    )
    return {"ok": True, "pushed_to": repo_id}