"""FastAPI service for retina disease classification with a fine-tuned BEiT model.

Endpoints:
    POST /load-data  -- upload a dataset ZIP; extracted into DATA_DIR for training
    POST /train      -- launch ``train2.py`` as a background subprocess
    POST /predict    -- classify a single uploaded retina image

The model/processor are module-level globals loaded at startup (if MODEL_DIR
exists) and reloaded after a successful training run.
"""
import json
import os
import shutil
import subprocess
import sys
import zipfile

import torch
from fastapi import BackgroundTasks, FastAPI, File, UploadFile
from PIL import Image
from transformers import AutoImageProcessor, BeitForImageClassification

# Directory holding the fine-tuned model artifacts (weights, config, labels.json).
MODEL_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
# Directory the uploaded training dataset is extracted into.
DATA_DIR = os.environ.get("DATA_DIR", "data2")
# Default label order; overridden by MODEL_DIR/labels.json when present.
CLASSES = ["AMD", "DMO", "DR", "GLC", "HR", "Normal"]

app = FastAPI(title="Retina Disease Classifier")
processor = None
model = None


# ----------------------------
# MODEL LOADING
# ----------------------------
def load_model():
    """Load processor/model from MODEL_DIR and refresh CLASSES from labels.json.

    On any failure both globals are reset to None so /predict reports
    "Model not trained yet" rather than using a half-initialized model.
    """
    global processor, model, CLASSES
    try:
        processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
        model = BeitForImageClassification.from_pretrained(MODEL_DIR)
        model.eval()  # inference only: disable dropout / BN updates
        labels_path = os.path.join(MODEL_DIR, "labels.json")
        if os.path.exists(labels_path):
            with open(labels_path) as f:
                CLASSES = json.load(f)
        print("✅ Model and processor loaded successfully")
    except Exception as e:
        # Keep the service up even without a model; /predict will refuse requests.
        processor, model = None, None
        print(f"⚠️ Skipping model load: {e}")


# ----------------------------
# BACKGROUND TRAINING
# ----------------------------
def run_training():
    """Run train2.py in a subprocess, stream its logs, and reload the model on success."""
    try:
        print("🔹 Starting training subprocess...")
        # sys.executable guarantees the same interpreter/venv as the server
        # (a bare "python" on PATH may be a different installation).
        process = subprocess.Popen(
            [sys.executable, "train2.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into the same stream
            universal_newlines=True,
        )
        for line in iter(process.stdout.readline, ""):
            print("TRAIN_LOG:", line.strip())
        process.stdout.close()
        return_code = process.wait()
        if return_code == 0 and os.path.exists(MODEL_DIR):
            load_model()
            print("✅ Training complete and model reloaded")
        else:
            print(f"❌ Training failed with code {return_code}")
    except Exception as e:
        print("⚠️ Training exception:", str(e))


# ----------------------------
# FASTAPI STARTUP
# ----------------------------
@app.on_event("startup")
def startup_event():
    """Load the model at startup when a previously-trained MODEL_DIR exists."""
    if os.path.exists(MODEL_DIR):
        load_model()
    else:
        print("⚠️ MODEL_DIR not found, skipping model load")


# ----------------------------
# ENDPOINTS
# ----------------------------
@app.post("/load-data")
async def load_data(file: UploadFile = File(...)):
    """
    Upload a ZIP file, extract into `data/` folder for training.
    """
    print("🔹 Received dataset ZIP upload...")
    # Start from a clean dataset directory on every upload.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    zip_path = "dataset.zip"
    with open(zip_path, "wb") as f:
        f.write(await file.read())
    print(f" ↪ Saved ZIP to {zip_path}")

    try:
        dest_root = os.path.abspath(DATA_DIR)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            # Guard against "zip slip": the archive is untrusted user input, so
            # reject any member whose resolved path escapes DATA_DIR.
            for member in zip_ref.namelist():
                target = os.path.abspath(os.path.join(dest_root, member))
                if os.path.commonpath([dest_root, target]) != dest_root:
                    return {"error": f"Unsafe path in ZIP: {member}"}
            zip_ref.extractall(DATA_DIR)
    except zipfile.BadZipFile:
        return {"error": "Uploaded file is not a valid ZIP archive"}
    finally:
        # Remove the temp archive even when extraction fails.
        os.remove(zip_path)

    print(f"✅ Dataset extracted to {DATA_DIR}")
    return {"status": "Dataset uploaded and extracted"}


@app.post("/train")
async def train_endpoint(background_tasks: BackgroundTasks):
    """Kick off training in the background and return immediately."""
    background_tasks.add_task(run_training)
    return {"status": "Training started in background"}


@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded retina image; returns the top class and all probabilities."""
    if model is None:
        return {"error": "Model not trained yet"}
    try:
        img = Image.open(file.file).convert("RGB")
    except Exception as e:
        return {"error": f"Invalid image: {str(e)}"}

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradient bookkeeping
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].tolist()
    pred_id = int(torch.argmax(logits, dim=1).item())
    return {
        # NOTE: key name "class_id" is kept for API compatibility, though the
        # value is the class label string.
        "class_id": CLASSES[pred_id],
        "probabilities": {CLASSES[i]: float(p) for i, p in enumerate(probs)},
    }