# Spaces: Sleeping
import json
import os
import shutil
import subprocess
import sys
import tempfile
import zipfile

import torch
from fastapi import BackgroundTasks, FastAPI, File, UploadFile
from PIL import Image
from transformers import AutoImageProcessor, BeitForImageClassification
# ----------------------------
# CONFIGURATION & GLOBAL STATE
# ----------------------------
# Where the fine-tuned processor/model are read from (override with OUTPUT_DIR).
MODEL_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
# Where an uploaded dataset ZIP is extracted (override with DATA_DIR).
DATA_DIR = os.environ.get("DATA_DIR", "data2")
# Default label names, indexed by class id; load_model() may replace them
# with the contents of MODEL_DIR/labels.json.
CLASSES = "AMD DMO DR GLC HR Normal".split()

app = FastAPI(title="Retina Disease Classifier")

# Filled in by load_model(); both stay None until a model loads successfully.
processor = model = None
# ----------------------------
# MODEL LOADING
# ----------------------------
def _normalize_labels(raw):
    """Return label names from a parsed labels.json as a list indexed by class id.

    Accepts a JSON list (``["AMD", ...]``) or an id->name mapping
    (``{"0": "AMD", ...}``, keys as numeric strings or ints, ordered
    numerically). Anything else raises ``ValueError``.
    """
    if isinstance(raw, list):
        return list(raw)
    if isinstance(raw, dict):
        # Sort numerically so "10" comes after "2", not after "1".
        return [raw[key] for key in sorted(raw, key=int)]
    raise ValueError(
        f"labels.json must be a list or an id->label mapping, got {type(raw).__name__}"
    )


def load_model():
    """Load the image processor, model and optional label names from MODEL_DIR.

    On success the module globals ``processor``, ``model`` and ``CLASSES`` are
    replaced with the freshly loaded values. ``CLASSES`` is only changed when
    MODEL_DIR/labels.json exists.

    Failures are best effort: they are reported on stdout and leave
    ``processor`` and ``model`` set to None, with ``CLASSES`` unchanged. This
    lets the server keep running, and the predict endpoint then reports that
    no model is available.
    """
    global processor, model, CLASSES
    try:
        new_processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
        new_model = BeitForImageClassification.from_pretrained(MODEL_DIR)
        new_classes = CLASSES
        labels_path = os.path.join(MODEL_DIR, "labels.json")
        if os.path.exists(labels_path):
            # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
            with open(labels_path, encoding="utf-8") as f:
                new_classes = _normalize_labels(json.load(f))
    except Exception as e:  # deliberate best effort: service must start without a model
        processor, model = None, None
        print(f"⚠️ Skipping model load: {e}")
        return
    # Publish only after every part has loaded, so a failure part-way through
    # cannot leave a half-updated set of globals.
    processor, model, CLASSES = new_processor, new_model, new_classes
    print("✅ Model and processor loaded successfully")
# ----------------------------
# BACKGROUND TRAINING
# ----------------------------
def run_training():
    """Run ``train2.py`` in a child process, relay its output, then reload the model.

    Every line the script writes (stdout and stderr merged) is echoed to this
    process's stdout with a ``TRAIN_LOG:`` prefix. If the script exits with
    status 0 and MODEL_DIR exists afterwards, load_model() is called to pick up
    the new weights. Otherwise the failure code is reported. Exceptions are
    printed rather than raised because this is scheduled as a background task
    with no caller to receive them.
    """
    try:
        print("🔹 Starting training subprocess...")
        # sys.executable runs the script under the same interpreter/virtualenv
        # as this server; a bare "python" may resolve to something else on
        # PATH. "-u" stops the child from block-buffering its piped stdout,
        # so log lines arrive as they are printed.
        with subprocess.Popen(
            [sys.executable, "-u", "train2.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable output must not abort log relaying
        ) as process:
            try:
                for line in process.stdout:
                    print("TRAIN_LOG:", line.rstrip())
            except BaseException:
                # Nobody drains the pipe any more: kill the child so it cannot
                # block on a full pipe and so the with-block's wait() returns.
                process.kill()
                raise
        # Leaving the with-block closed the pipe and waited for the child.
        return_code = process.returncode
        if return_code == 0 and os.path.exists(MODEL_DIR):
            load_model()
            print("✅ Training complete and model reloaded")
        else:
            print(f"❌ Training failed with code {return_code}")
    except Exception as e:
        print("⚠️ Training exception:", str(e))
# ----------------------------
# FASTAPI STARTUP
# ----------------------------
# Registered on the app so it actually runs when the server starts; an
# undecorated function is never invoked by FastAPI.
@app.on_event("startup")
def startup_event():
    """Load a previously trained model at server start, if MODEL_DIR exists."""
    if os.path.exists(MODEL_DIR):
        load_model()
    else:
        print("⚠️ MODEL_DIR not found, skipping model load")
# ----------------------------
# ENDPOINTS
# ----------------------------
# NOTE(review): no route decorators (e.g. @app.post(...)) are visible on the
# endpoint functions below, so FastAPI will not serve them as written.
# Confirm the intended paths and register them.
async def load_data(file: UploadFile = File(...)):
    """Upload a ZIP archive and extract it into DATA_DIR for training.

    The upload is first written to a private temporary file and opened as a
    ZIP. Only when that succeeds is any existing DATA_DIR deleted and replaced
    with the archive's contents, so an invalid upload leaves the previous
    dataset intact. The temporary file is always removed.

    Returns:
        ``{"status": ...}`` on success, or ``{"error": ...}`` when the upload
        is not a readable ZIP archive.
    """
    print("🔹 Received dataset ZIP upload...")
    # A unique temp file (rather than a fixed "dataset.zip" in the CWD) keeps
    # concurrent uploads from overwriting each other.
    fd, zip_path = tempfile.mkstemp(suffix=".zip")
    try:
        with os.fdopen(fd, "wb") as f:
            # Copy in 1 MiB chunks so a large dataset is not held in memory at once.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)
        print(f"   ↪ Saved ZIP to {zip_path}")
        try:
            archive = zipfile.ZipFile(zip_path, "r")
        except zipfile.BadZipFile as e:
            return {"error": f"Invalid ZIP file: {e}"}
        with archive:
            # Only wipe the old dataset once the new archive is known to open.
            if os.path.exists(DATA_DIR):
                shutil.rmtree(DATA_DIR)
            os.makedirs(DATA_DIR, exist_ok=True)
            # ZipFile.extractall strips absolute paths and ".." components
            # from member names, so entries stay inside DATA_DIR.
            archive.extractall(DATA_DIR)
    finally:
        os.remove(zip_path)
    print(f"✅ Dataset extracted to {DATA_DIR}")
    return {"status": "Dataset uploaded and extracted"}
async def train_endpoint(background_tasks: BackgroundTasks):
    """Schedule run_training() as a background task and acknowledge the request.

    The response does not wait for training. Progress only appears in the
    server's stdout.
    """
    # NOTE(review): nothing here prevents a second request from starting
    # another training run while one is still in progress.
    message = "Training started in background"
    background_tasks.add_task(run_training)
    return {"status": message}
async def predict(file: UploadFile):
    """Classify an uploaded retina image with the currently loaded model.

    Returns:
        ``{"class_id": <predicted label name>, "probabilities": {label: p}}``
        on success, where ``p`` are softmax probabilities over the model's
        outputs. Returns ``{"error": ...}`` if no model is loaded or the upload
        cannot be decoded as an image.
    """
    # Snapshot the globals once: load_model() can run concurrently (e.g. from
    # the background training task) and swap them, or reset them to None,
    # mid-request.
    current_model, current_processor, labels = model, processor, CLASSES
    if current_model is None or current_processor is None:
        return {"error": "Model not trained yet"}
    try:
        img = Image.open(file.file).convert("RGB")
    except Exception as e:
        return {"error": f"Invalid image: {str(e)}"}
    inputs = current_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = current_model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].tolist()
    pred_id = int(torch.argmax(logits, dim=1).item())

    def label_for(idx):
        # If CLASSES is shorter than the model's output layer, fall back to the
        # model's own id2label (or the bare index) instead of raising IndexError.
        if idx < len(labels):
            return labels[idx]
        return current_model.config.id2label.get(idx, str(idx))

    # NOTE: "class_id" carries the predicted label *name*; the key is kept
    # unchanged for existing clients.
    return {
        "class_id": label_for(pred_id),
        "probabilities": {label_for(i): float(p) for i, p in enumerate(probs)},
    }