File size: 5,801 Bytes
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa253e2
 
 
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa253e2
39ec591
 
 
 
 
aa253e2
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa253e2
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa253e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations
import io
import os
from threading import Thread
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from torchvision import transforms
from huggingface_hub import HfApi, create_repo, upload_file

from train import fit
from model_loader import build_classifier
from dataset import NORM_MEAN, NORM_STD
from utils import unzip_dataset, clean_dir

app = FastAPI(title="RETFound MAE – Train & Inference API")

# ---------- Config (safe paths for HF Spaces) ----------
# Every setting below can be overridden through the environment variable
# named in its os.getenv() call; the defaults keep writable data under /tmp.
DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")  # uploaded datasets are extracted here
CKPT_DIR = os.getenv("CKPT_DIR", "/tmp/checkpoints")  # passed to fit() as out_dir
# Base model weights (Hub repo + filename) handed to build_classifier() and fit().
BASE_REPO = os.getenv("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
BASE_FILE = os.getenv("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
# Default target repository for the /push endpoint.
MODEL_PUSH_REPO = os.getenv("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")

# Use the GPU whenever PyTorch can see one; kept as a plain string so it can
# be reported verbatim by /status.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Runtime state, shared between request handlers and the training thread.
_state = dict(
    training=False,  # True while a /train run is in progress
    best_ckpt=None,  # path of the best checkpoint from the last completed run
    classes=None,    # class names, index-aligned with the model's outputs
    val_acc=None,    # best validation accuracy reported by fit()
)
_model = None  # built lazily by _load_model_for_inference()

# Inference preprocessing: fixed 224x224 resize, tensor conversion, then
# normalization with the statistics exported by the dataset module.
_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)


def _load_model_for_inference():
    """(Re)build the global inference model and store it in ``_model``.

    If ``_state["best_ckpt"]`` names an existing file, the checkpoint is
    loaded onto ``DEVICE``. Its ``"classes"`` entry (default: empty list) is
    copied into ``_state["classes"]`` and sizes the classifier head, and its
    ``"model"`` state dict is loaded non-strictly. Otherwise a 2-class
    classifier is built via ``build_classifier`` and ``_state["classes"]`` is
    left untouched. Either way the model is switched to eval mode.
    """
    global _model
    ckpt_path = _state["best_ckpt"]
    builder_kwargs = dict(base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)

    if not (ckpt_path and os.path.exists(ckpt_path)):
        # No usable checkpoint yet: fall back to a 2-class placeholder head.
        net = build_classifier(num_classes=2, **builder_kwargs)
    else:
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        labels = checkpoint.get("classes", [])
        _state["classes"] = labels
        net = build_classifier(num_classes=len(labels), **builder_kwargs)
        # strict=False: missing/unexpected keys in the checkpoint are ignored
        # rather than raising.
        net.load_state_dict(checkpoint["model"], strict=False)

    net.eval()
    _model = net


@app.get("/status")
def status():
    """Report training progress, the best checkpoint, its classes/accuracy and the device."""
    report = {key: _state[key] for key in ("training", "best_ckpt", "classes", "val_acc")}
    report["device"] = DEVICE
    return report


@app.post("/upload_dataset")
async def upload_dataset(file: UploadFile = File(...)):
    """Upload a ZIP that contains train/ and val/ folders.

    The archive is saved under ``/tmp/uploads``; ``DATA_ROOT`` is then wiped
    with ``clean_dir`` and the archive is extracted into it with
    ``unzip_dataset``.

    Returns:
        ``{"ok": True, "dataset_dir": <value returned by unzip_dataset>}``.
    """
    upload_dir = "/tmp/uploads"
    os.makedirs(upload_dir, exist_ok=True)

    # ``file.filename`` is client-controlled: keep only its final path
    # component so names like "../../app.py" or "/etc/passwd" cannot escape
    # the upload directory, and fall back to a fixed name when nothing usable
    # remains (None, "", ".", "..").
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "dataset.zip"
    zip_path = os.path.join(upload_dir, safe_name)

    # Stream the upload to disk in 1 MiB chunks instead of holding the whole
    # archive in memory at once.
    chunk_size = 1 << 20
    with open(zip_path, "wb") as f:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

    clean_dir(DATA_ROOT)  # Now points to /tmp/data
    # NOTE(review): confirm unzip_dataset rejects archive members whose paths
    # resolve outside DATA_ROOT ("zip slip"); that check is not visible here.
    extracted = unzip_dataset(zip_path, DATA_ROOT)
    return {"ok": True, "dataset_dir": extracted}


@app.post("/train")
async def start_train(epochs: int = Form(10), batch_size: int = Form(16), lr: float = Form(5e-4), freeze_backbone: bool = Form(True)):
    """Start fine-tuning on ``DATA_ROOT`` in a background thread.

    Returns immediately with the submitted hyper-parameters, or a 409 response
    if a run is already in progress. When the run finishes, the best
    checkpoint path, class list and validation accuracy returned by ``fit``
    are stored in ``_state`` and the inference model is reloaded from it.
    """
    if _state["training"]:
        return JSONResponse({"error": "Training already in progress"}, status_code=409)

    # Claim the training slot *before* spawning the worker. Previously the
    # flag was only set inside the thread, so two requests arriving
    # back-to-back could both pass the check above and start concurrent runs.
    # There is no ``await`` between the check and this assignment, so no
    # other coroutine on the event loop can run in between.
    _state["training"] = True

    def _run():
        try:
            ckpt_path, classes, best_acc = fit(
                data_root=DATA_ROOT,
                base_repo=BASE_REPO,
                base_filename=BASE_FILE,
                epochs=int(epochs),
                batch_size=int(batch_size),
                lr=float(lr),
                freeze_backbone=bool(freeze_backbone),
                out_dir=CKPT_DIR,
                device=DEVICE,
            )
            _state.update({
                "best_ckpt": ckpt_path,
                "classes": classes,
                "val_acc": best_acc,
            })
            _load_model_for_inference()
        finally:
            # Always release the slot, including when fit() raises (the
            # exception then ends the thread and is reported by threading's
            # excepthook).
            _state["training"] = False

    try:
        Thread(target=_run, daemon=True).start()
    except RuntimeError:
        # The worker never started, so nothing else will clear the flag.
        _state["training"] = False
        raise
    return {"ok": True, "message": "Training started", "params": {"epochs": epochs, "batch_size": batch_size, "lr": lr, "freeze_backbone": freeze_backbone}}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the current fine-tuned model.

    Returns ``{"prediction": <class name>, "probabilities": {class: p}}`` with
    softmax probabilities. Responds with 400 when no trained model is
    available or the upload cannot be decoded as an image, and with 500 when
    the model's output size disagrees with the stored class list.
    """
    # Load lazily, but only when there is a checkpoint to load. Without one,
    # _load_model_for_inference would build an untrained 2-class placeholder
    # that this endpoint rejects anyway.
    if _model is None and _state["best_ckpt"]:
        _load_model_for_inference()

    # Take local snapshots: the training thread may replace _model / classes
    # while this request is being served.
    model = _model
    classes = _state["classes"]
    if model is None or not classes:
        return JSONResponse({"error": "Model not trained yet. Upload dataset and call /train first."}, status_code=400)

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError (and truncated-file errors) subclass
        # OSError; a bad upload is a client error, not a server crash.
        return JSONResponse({"error": "Uploaded file is not a readable image."}, status_code=400)

    x = _transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    idx = int(probs.argmax().item())
    prob_list = probs.cpu().tolist()

    if len(prob_list) != len(classes):
        return JSONResponse({"error": "Model output size does not match the class list."}, status_code=500)

    return {
        "prediction": classes[idx],
        "probabilities": {cls: float(p) for cls, p in zip(classes, prob_list)},
    }


@app.post("/push")
async def push_to_hub(repo_id: Optional[str] = Form(None)):
    """Push best checkpoint + metadata to Hugging Face Hub.

    Creates the model repo ``repo_id`` (default ``MODEL_PUSH_REPO``) if it
    does not exist, uploads the checkpoint as ``retfound_classifier_best.pth``
    and a generated ``README.md`` model card. Returns 400 when no target repo
    is configured or no trained checkpoint exists on disk.
    """
    repo_id = repo_id or MODEL_PUSH_REPO
    if not repo_id:
        return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)

    # Snapshot the state once so a training run finishing mid-push cannot mix
    # one run's checkpoint with another run's metadata.
    ckpt_path = _state["best_ckpt"]
    classes = _state["classes"]
    val_acc = _state["val_acc"]
    if not ckpt_path or not os.path.exists(ckpt_path):
        return JSONResponse({"error": "No trained checkpoint to push."}, status_code=400)

    create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

    upload_file(path_or_fileobj=ckpt_path, path_in_repo="retfound_classifier_best.pth", repo_id=repo_id, repo_type="model")

    card = f"""# RETFound MAE – Retinal Classifier (Fine-tuned)

- Base: `{BASE_REPO}/{BASE_FILE}`
- Classes: `{classes}`
- Best val acc: `{val_acc}`

## Inference
This repo contains a PyTorch checkpoint `retfound_classifier_best.pth` compatible with RETFound MAE backbone.
"""
    # Upload the card directly as UTF-8 bytes: avoids a shared temp file that
    # concurrent pushes would clobber, and avoids locale-dependent encoding of
    # the non-ASCII dash in the title.
    upload_file(path_or_fileobj=card.encode("utf-8"), path_in_repo="README.md", repo_id=repo_id, repo_type="model")

    return {"ok": True, "pushed_to": repo_id}