Spaces:
Sleeping
Sleeping
File size: 3,881 Bytes
05c5199 96c3348 472db94 05c5199 96c3348 05c5199 96c3348 cb24c7c 96c3348 05c5199 0e0e505 05c5199 0e0e505 05c5199 0e0e505 05c5199 0e0e505 05c5199 0e0e505 05c5199 96c3348 cb24c7c 96c3348 05c5199 96c3348 cb24c7c 96c3348 cb24c7c 96c3348 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os, json, subprocess, shutil, zipfile
from fastapi import BackgroundTasks, FastAPI, UploadFile, File
from transformers import AutoImageProcessor, BeitForImageClassification
from PIL import Image
import torch
# Directory holding the fine-tuned BEiT checkpoint (overridable via env).
MODEL_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
# Directory where uploaded training datasets are extracted (overridable via env).
DATA_DIR = os.environ.get("DATA_DIR", "data2")
# Default label order; may be replaced by labels.json saved next to the model.
CLASSES = ["AMD","DMO","DR","GLC","HR","Normal"]
app = FastAPI(title="Retina Disease Classifier")
# Populated lazily by load_model(); None means "no model available yet".
processor = None
model = None
# ----------------------------
# MODEL LOADING
# ----------------------------
def load_model():
    """Load the image processor and BEiT classifier from MODEL_DIR.

    Sets the module globals ``processor``, ``model`` and (optionally)
    ``CLASSES``. If a ``labels.json`` file was saved alongside the model
    at training time, it overrides the default class list so indices match
    the trained head. On any failure both globals are reset to ``None``
    so the API keeps serving (predict returns a "not trained" error)
    instead of crashing at startup.
    """
    global processor, model, CLASSES
    try:
        processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
        model = BeitForImageClassification.from_pretrained(MODEL_DIR)
        # Prefer the label order recorded at training time, if present.
        labels_path = os.path.join(MODEL_DIR, "labels.json")
        if os.path.exists(labels_path):
            with open(labels_path) as f:
                CLASSES = json.load(f)
        print("✅ Model and processor loaded successfully")
    except Exception as e:
        # Deliberate best-effort: a missing/corrupt checkpoint should not
        # take the whole service down.
        processor, model = None, None
        print(f"⚠️ Skipping model load: {e}")
# ----------------------------
# BACKGROUND TRAINING
# ----------------------------
def run_training():
    """Run ``train2.py`` as a subprocess and reload the model on success.

    Intended as a FastAPI background task. Streams the child's combined
    stdout/stderr line-by-line with a ``TRAIN_LOG:`` prefix so training
    progress shows up in the service logs. If the subprocess exits 0 and
    MODEL_DIR exists, the freshly trained model is loaded via load_model().
    """
    try:
        print("🔹 Starting training subprocess...")
        # List-form argv (shell=False) — no shell-injection surface.
        process = subprocess.Popen(
            ["python", "train2.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into one stream
            universal_newlines=True,
        )
        # readline-until-"" terminates when the child closes its stdout.
        for line in iter(process.stdout.readline, ""):
            print("TRAIN_LOG:", line.strip())
        process.stdout.close()
        return_code = process.wait()
        if return_code == 0 and os.path.exists(MODEL_DIR):
            load_model()
            print("✅ Training complete and model reloaded")
        else:
            print(f"❌ Training failed with code {return_code}")
    except Exception as e:
        # Background task: swallow and log so the server thread survives.
        print("⚠️ Training exception:", str(e))
# ----------------------------
# FASTAPI STARTUP
# ----------------------------
@app.on_event("startup")
def startup_event():
    """On server startup, load the model if a checkpoint directory exists."""
    if os.path.exists(MODEL_DIR):
        load_model()
    else:
        print("⚠️ MODEL_DIR not found, skipping model load")
# ----------------------------
# ENDPOINTS
# ----------------------------
@app.post("/load-data")
async def load_data(file: UploadFile = File(...)):
    """
    Upload a ZIP file and extract it into DATA_DIR for training.

    The previous dataset directory is wiped first so stale classes or files
    cannot leak into the next training run. The temporary ZIP is always
    removed, even if extraction fails.
    """
    print("🔹 Received dataset ZIP upload...")
    # Clean slate: remove any previously uploaded dataset.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)
    zip_path = "dataset.zip"
    with open(zip_path, "wb") as f:
        f.write(await file.read())
    print(f"▪ Saved ZIP to {zip_path}")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            # NOTE(review): extractall trusts member paths — a malicious ZIP
            # with "../" entries could escape DATA_DIR (zip-slip). Validate
            # member names if uploads are not trusted.
            zip_ref.extractall(DATA_DIR)
        print(f"✅ Dataset extracted to {DATA_DIR}")
    finally:
        # Don't leave the temporary archive behind on failure.
        os.remove(zip_path)
    return {"status": "Dataset uploaded and extracted"}
@app.post("/train")
async def train_endpoint(background_tasks: BackgroundTasks):
    """Kick off run_training() in the background and return immediately."""
    background_tasks.add_task(run_training)
    return {"status": "Training started in background"}
@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded retina image.

    Returns the predicted class name plus the full softmax distribution
    over CLASSES, or an ``error`` payload when no model is loaded or the
    upload is not a decodable image.
    """
    if model is None:
        return {"error": "Model not trained yet"}
    try:
        # Force RGB so grayscale/RGBA uploads match the model's 3-channel input.
        img = Image.open(file.file).convert("RGB")
    except Exception as e:
        return {"error": f"Invalid image: {str(e)}"}
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():  # inference only — no autograd bookkeeping
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].tolist()
    pred_id = int(torch.argmax(logits, dim=1).item())
    return {
        # NOTE(review): key is named "class_id" but carries the class NAME;
        # kept as-is because API consumers may depend on this key.
        "class_id": CLASSES[pred_id],
        "probabilities": {CLASSES[i]: float(p) for i, p in enumerate(probs)}
    }
|