Spaces:
Sleeping
Sleeping
Habeeb Okunade
committed on
Commit
·
aa253e2
1
Parent(s):
d8770e8
Updating model
Browse files
api.py
CHANGED
|
@@ -18,9 +18,9 @@ from utils import unzip_dataset, clean_dir
|
|
| 18 |
|
| 19 |
app = FastAPI(title="RETFound MAE – Train & Inference API")
|
| 20 |
|
| 21 |
-
# ---------- Config (
|
| 22 |
-
DATA_ROOT = os.getenv("DATA_ROOT", "
|
| 23 |
-
CKPT_DIR = os.getenv("CKPT_DIR", "
|
| 24 |
BASE_REPO = os.getenv("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
|
| 25 |
BASE_FILE = os.getenv("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
|
| 26 |
MODEL_PUSH_REPO = os.getenv("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")
|
|
@@ -44,7 +44,6 @@ _transform = transforms.Compose([
|
|
| 44 |
|
| 45 |
def _load_model_for_inference():
|
| 46 |
global _model
|
| 47 |
-
# If we already trained and have a best checkpoint, load it; else build unfitted head
|
| 48 |
if _state["best_ckpt"] and os.path.exists(_state["best_ckpt"]):
|
| 49 |
ckpt = torch.load(_state["best_ckpt"], map_location=DEVICE)
|
| 50 |
classes = ckpt.get("classes", [])
|
|
@@ -52,7 +51,6 @@ def _load_model_for_inference():
|
|
| 52 |
model = build_classifier(num_classes=len(classes), base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
|
| 53 |
model.load_state_dict(ckpt["model"], strict=False)
|
| 54 |
else:
|
| 55 |
-
# Cold start with 2 placeholder classes to allow /predict errors to be informative
|
| 56 |
model = build_classifier(num_classes=2, base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
|
| 57 |
model.eval()
|
| 58 |
_model = model
|
|
@@ -71,17 +69,13 @@ def status():
|
|
| 71 |
|
| 72 |
@app.post("/upload_dataset")
|
| 73 |
async def upload_dataset(file: UploadFile = File(...)):
|
| 74 |
-
"""Upload a ZIP that contains train/ and val/ folders.
|
| 75 |
-
Example structure inside zip:
|
| 76 |
-
train/class_a/*.jpg, val/class_a/*.jpg
|
| 77 |
-
"""
|
| 78 |
os.makedirs("/tmp/uploads", exist_ok=True)
|
| 79 |
zip_path = f"/tmp/uploads/{file.filename}"
|
| 80 |
with open(zip_path, "wb") as f:
|
| 81 |
f.write(await file.read())
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
clean_dir(DATA_ROOT)
|
| 85 |
extracted = unzip_dataset(zip_path, DATA_ROOT)
|
| 86 |
return {"ok": True, "dataset_dir": extracted}
|
| 87 |
|
|
@@ -141,9 +135,7 @@ async def predict(file: UploadFile = File(...)):
|
|
| 141 |
|
| 142 |
@app.post("/push")
|
| 143 |
async def push_to_hub(repo_id: Optional[str] = Form(None)):
|
| 144 |
-
"""Push best checkpoint + metadata to Hugging Face Hub
|
| 145 |
-
Requires write permission via HF token (Spaces secret HF_TOKEN is used automatically).
|
| 146 |
-
"""
|
| 147 |
repo_id = repo_id or MODEL_PUSH_REPO
|
| 148 |
if not repo_id:
|
| 149 |
return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)
|
|
@@ -153,10 +145,8 @@ async def push_to_hub(repo_id: Optional[str] = Form(None)):
|
|
| 153 |
api = HfApi()
|
| 154 |
create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
|
| 155 |
|
| 156 |
-
# Upload files
|
| 157 |
upload_file(path_or_fileobj=_state["best_ckpt"], path_in_repo="retfound_classifier_best.pth", repo_id=repo_id, repo_type="model")
|
| 158 |
|
| 159 |
-
# Write simple model card
|
| 160 |
card = f"""# RETFound MAE – Retinal Classifier (Fine-tuned)
|
| 161 |
|
| 162 |
- Base: `{BASE_REPO}/{BASE_FILE}`
|
|
@@ -171,4 +161,4 @@ This repo contains a PyTorch checkpoint `retfound_classifier_best.pth` compatibl
|
|
| 171 |
f.write(card)
|
| 172 |
upload_file(path_or_fileobj=card_path, path_in_repo="README.md", repo_id=repo_id, repo_type="model")
|
| 173 |
|
| 174 |
-
return {"ok": True, "pushed_to": repo_id}
|
|
|
|
| 18 |
|
| 19 |
app = FastAPI(title="RETFound MAE – Train & Inference API")
|
| 20 |
|
| 21 |
+
# ---------- Config (safe paths for HF Spaces) ----------
|
| 22 |
+
DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")
|
| 23 |
+
CKPT_DIR = os.getenv("CKPT_DIR", "/tmp/checkpoints")
|
| 24 |
BASE_REPO = os.getenv("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
|
| 25 |
BASE_FILE = os.getenv("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
|
| 26 |
MODEL_PUSH_REPO = os.getenv("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")
|
|
|
|
| 44 |
|
| 45 |
def _load_model_for_inference():
|
| 46 |
global _model
|
|
|
|
| 47 |
if _state["best_ckpt"] and os.path.exists(_state["best_ckpt"]):
|
| 48 |
ckpt = torch.load(_state["best_ckpt"], map_location=DEVICE)
|
| 49 |
classes = ckpt.get("classes", [])
|
|
|
|
| 51 |
model = build_classifier(num_classes=len(classes), base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
|
| 52 |
model.load_state_dict(ckpt["model"], strict=False)
|
| 53 |
else:
|
|
|
|
| 54 |
model = build_classifier(num_classes=2, base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
|
| 55 |
model.eval()
|
| 56 |
_model = model
|
|
|
|
| 69 |
|
| 70 |
@app.post("/upload_dataset")
|
| 71 |
async def upload_dataset(file: UploadFile = File(...)):
|
| 72 |
+
"""Upload a ZIP that contains train/ and val/ folders."""
|
|
|
|
|
|
|
|
|
|
| 73 |
os.makedirs("/tmp/uploads", exist_ok=True)
|
| 74 |
zip_path = f"/tmp/uploads/{file.filename}"
|
| 75 |
with open(zip_path, "wb") as f:
|
| 76 |
f.write(await file.read())
|
| 77 |
|
| 78 |
+
clean_dir(DATA_ROOT) # Now points to /tmp/data
|
|
|
|
| 79 |
extracted = unzip_dataset(zip_path, DATA_ROOT)
|
| 80 |
return {"ok": True, "dataset_dir": extracted}
|
| 81 |
|
|
|
|
| 135 |
|
| 136 |
@app.post("/push")
|
| 137 |
async def push_to_hub(repo_id: Optional[str] = Form(None)):
|
| 138 |
+
"""Push best checkpoint + metadata to Hugging Face Hub."""
|
|
|
|
|
|
|
| 139 |
repo_id = repo_id or MODEL_PUSH_REPO
|
| 140 |
if not repo_id:
|
| 141 |
return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)
|
|
|
|
| 145 |
api = HfApi()
|
| 146 |
create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
|
| 147 |
|
|
|
|
| 148 |
upload_file(path_or_fileobj=_state["best_ckpt"], path_in_repo="retfound_classifier_best.pth", repo_id=repo_id, repo_type="model")
|
| 149 |
|
|
|
|
| 150 |
card = f"""# RETFound MAE – Retinal Classifier (Fine-tuned)
|
| 151 |
|
| 152 |
- Base: `{BASE_REPO}/{BASE_FILE}`
|
|
|
|
| 161 |
f.write(card)
|
| 162 |
upload_file(path_or_fileobj=card_path, path_in_repo="README.md", repo_id=repo_id, repo_type="model")
|
| 163 |
|
| 164 |
+
return {"ok": True, "pushed_to": repo_id}
|