Habeeb Okunade commited on
Commit
aa253e2
·
1 Parent(s): d8770e8

Updating model

Browse files
Files changed (1) hide show
  1. api.py +7 -17
api.py CHANGED
@@ -18,9 +18,9 @@ from utils import unzip_dataset, clean_dir
18
 
19
  app = FastAPI(title="RETFound MAE – Train & Inference API")
20
 
21
- # ---------- Config (env vars for Spaces) ----------
22
- DATA_ROOT = os.getenv("DATA_ROOT", "./workspace/data")
23
- CKPT_DIR = os.getenv("CKPT_DIR", "./workspace/checkpoints")
24
  BASE_REPO = os.getenv("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
25
  BASE_FILE = os.getenv("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
26
  MODEL_PUSH_REPO = os.getenv("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")
@@ -44,7 +44,6 @@ _transform = transforms.Compose([
44
 
45
  def _load_model_for_inference():
46
  global _model
47
- # If we already trained and have a best checkpoint, load it; else build unfitted head
48
  if _state["best_ckpt"] and os.path.exists(_state["best_ckpt"]):
49
  ckpt = torch.load(_state["best_ckpt"], map_location=DEVICE)
50
  classes = ckpt.get("classes", [])
@@ -52,7 +51,6 @@ def _load_model_for_inference():
52
  model = build_classifier(num_classes=len(classes), base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
53
  model.load_state_dict(ckpt["model"], strict=False)
54
  else:
55
- # Cold start with 2 placeholder classes to allow /predict errors to be informative
56
  model = build_classifier(num_classes=2, base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
57
  model.eval()
58
  _model = model
@@ -71,17 +69,13 @@ def status():
71
 
72
  @app.post("/upload_dataset")
73
  async def upload_dataset(file: UploadFile = File(...)):
74
- """Upload a ZIP that contains train/ and val/ folders.
75
- Example structure inside zip:
76
- train/class_a/*.jpg, val/class_a/*.jpg
77
- """
78
  os.makedirs("/tmp/uploads", exist_ok=True)
79
  zip_path = f"/tmp/uploads/{file.filename}"
80
  with open(zip_path, "wb") as f:
81
  f.write(await file.read())
82
 
83
- # Clean previous dataset and extract
84
- clean_dir(DATA_ROOT)
85
  extracted = unzip_dataset(zip_path, DATA_ROOT)
86
  return {"ok": True, "dataset_dir": extracted}
87
 
@@ -141,9 +135,7 @@ async def predict(file: UploadFile = File(...)):
141
 
142
  @app.post("/push")
143
  async def push_to_hub(repo_id: Optional[str] = Form(None)):
144
- """Push best checkpoint + metadata to Hugging Face Hub model repo.
145
- Requires write permission via HF token (Spaces secret HF_TOKEN is used automatically).
146
- """
147
  repo_id = repo_id or MODEL_PUSH_REPO
148
  if not repo_id:
149
  return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)
@@ -153,10 +145,8 @@ async def push_to_hub(repo_id: Optional[str] = Form(None)):
153
  api = HfApi()
154
  create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
155
 
156
- # Upload files
157
  upload_file(path_or_fileobj=_state["best_ckpt"], path_in_repo="retfound_classifier_best.pth", repo_id=repo_id, repo_type="model")
158
 
159
- # Write simple model card
160
  card = f"""# RETFound MAE – Retinal Classifier (Fine-tuned)
161
 
162
  - Base: `{BASE_REPO}/{BASE_FILE}`
@@ -171,4 +161,4 @@ This repo contains a PyTorch checkpoint `retfound_classifier_best.pth` compatibl
171
  f.write(card)
172
  upload_file(path_or_fileobj=card_path, path_in_repo="README.md", repo_id=repo_id, repo_type="model")
173
 
174
- return {"ok": True, "pushed_to": repo_id}
 
18
 
19
  app = FastAPI(title="RETFound MAE – Train & Inference API")
20
 
21
+ # ---------- Config (safe paths for HF Spaces) ----------
22
+ DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")
23
+ CKPT_DIR = os.getenv("CKPT_DIR", "/tmp/checkpoints")
24
  BASE_REPO = os.getenv("HF_BASE_MODEL_REPO", "YukunZhou/RETFound_mae_meh")
25
  BASE_FILE = os.getenv("HF_BASE_MODEL_FILE", "RETFound_mae_meh.pth")
26
  MODEL_PUSH_REPO = os.getenv("HF_PUSH_REPO", "habeebCycle/RETFound_mae_meh_1")
 
44
 
45
  def _load_model_for_inference():
46
  global _model
 
47
  if _state["best_ckpt"] and os.path.exists(_state["best_ckpt"]):
48
  ckpt = torch.load(_state["best_ckpt"], map_location=DEVICE)
49
  classes = ckpt.get("classes", [])
 
51
  model = build_classifier(num_classes=len(classes), base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
52
  model.load_state_dict(ckpt["model"], strict=False)
53
  else:
 
54
  model = build_classifier(num_classes=2, base_repo=BASE_REPO, base_filename=BASE_FILE, device=DEVICE)
55
  model.eval()
56
  _model = model
 
69
 
70
  @app.post("/upload_dataset")
71
  async def upload_dataset(file: UploadFile = File(...)):
72
+ """Upload a ZIP that contains train/ and val/ folders."""
 
 
 
73
  os.makedirs("/tmp/uploads", exist_ok=True)
74
  zip_path = f"/tmp/uploads/{file.filename}"
75
  with open(zip_path, "wb") as f:
76
  f.write(await file.read())
77
 
78
+ clean_dir(DATA_ROOT) # Now points to /tmp/data
 
79
  extracted = unzip_dataset(zip_path, DATA_ROOT)
80
  return {"ok": True, "dataset_dir": extracted}
81
 
 
135
 
136
  @app.post("/push")
137
  async def push_to_hub(repo_id: Optional[str] = Form(None)):
138
+ """Push best checkpoint + metadata to Hugging Face Hub."""
 
 
139
  repo_id = repo_id or MODEL_PUSH_REPO
140
  if not repo_id:
141
  return JSONResponse({"error": "Set HF_PUSH_REPO env var or pass repo_id."}, status_code=400)
 
145
  api = HfApi()
146
  create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
147
 
 
148
  upload_file(path_or_fileobj=_state["best_ckpt"], path_in_repo="retfound_classifier_best.pth", repo_id=repo_id, repo_type="model")
149
 
 
150
  card = f"""# RETFound MAE – Retinal Classifier (Fine-tuned)
151
 
152
  - Base: `{BASE_REPO}/{BASE_FILE}`
 
161
  f.write(card)
162
  upload_file(path_or_fileobj=card_path, path_in_repo="README.md", repo_id=repo_id, repo_type="model")
163
 
164
+ return {"ok": True, "pushed_to": repo_id}