File size: 3,881 Bytes
05c5199
 
96c3348
 
 
 
472db94
05c5199
96c3348
 
 
 
 
 
 
05c5199
 
 
96c3348
 
cb24c7c
 
 
 
 
 
 
 
 
 
 
96c3348
05c5199
 
 
0e0e505
 
05c5199
 
0e0e505
05c5199
 
 
0e0e505
05c5199
 
 
 
 
 
0e0e505
 
 
05c5199
0e0e505
 
 
05c5199
 
 
96c3348
 
 
 
cb24c7c
 
96c3348
05c5199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c3348
 
 
 
cb24c7c
 
 
 
96c3348
 
 
 
 
 
 
cb24c7c
96c3348
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import asyncio
import io
import json
import os
import shutil
import subprocess
import sys
import zipfile

from fastapi import BackgroundTasks, FastAPI, UploadFile, File
from transformers import AutoImageProcessor, BeitForImageClassification
from PIL import Image
import torch

# Directory the trained processor/model are loaded from (and, presumably,
# where train2.py writes them — TODO confirm). Overridable via OUTPUT_DIR.
MODEL_DIR = os.getenv("OUTPUT_DIR", "/home/user/outputs/beit-retina")
# Directory uploaded datasets are extracted into. Overridable via DATA_DIR.
DATA_DIR = os.getenv("DATA_DIR", "data2")
# Default class names, indexed by model output id; load_model() replaces
# this list with MODEL_DIR/labels.json when that file exists.
CLASSES = [
    "AMD",
    "DMO",
    "DR",
    "GLC",
    "HR",
    "Normal",
]

app = FastAPI(title="Retina Disease Classifier")

# Filled in by load_model(); both stay None until a model has been loaded.
processor = model = None

# ----------------------------
# MODEL LOADING
# ----------------------------
def load_model():
    """Load the image processor, model and class labels from ``MODEL_DIR``.

    Everything is loaded into locals first. Only after all loads succeed are
    the module-level ``processor``, ``model`` and ``CLASSES`` rebound, together
    at the end. ``run_training`` calls this from a background task, so
    ``/predict`` requests served during a (slow) reload keep seeing the
    previous, consistent set instead of a new processor paired with an old
    model.

    On any failure the error is printed and ``processor``/``model`` are set to
    ``None`` (``CLASSES`` is left unchanged). No exception is raised.
    """
    global processor, model, CLASSES
    try:
        new_processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
        new_model = BeitForImageClassification.from_pretrained(MODEL_DIR)
        new_classes = CLASSES
        labels_path = os.path.join(MODEL_DIR, "labels.json")
        if os.path.exists(labels_path):
            # Explicit UTF-8 so non-ASCII label names do not depend on the
            # server's locale encoding.
            # NOTE(review): assumed to be a JSON list indexed by class id —
            # confirm against what train2.py writes.
            with open(labels_path, encoding="utf-8") as f:
                new_classes = json.load(f)
    except Exception as e:
        processor, model = None, None
        print(f"โš ๏ธ Skipping model load: {e}")
        return
    # Commit all three together only after every load has succeeded.
    processor, model, CLASSES = new_processor, new_model, new_classes
    print("โœ… Model and processor loaded successfully")

# ----------------------------
# BACKGROUND TRAINING
# ----------------------------
def run_training():
    """Run ``train2.py`` in a child process and reload the model if it succeeds.

    The child's stdout and stderr are merged and echoed line by line with a
    ``TRAIN_LOG:`` prefix. When the child exits with status 0 and
    ``MODEL_DIR`` exists afterwards, ``load_model()`` is called to pick up the
    new weights. Failures are printed rather than raised: this runs as a
    FastAPI background task, so there is no caller to propagate them to.
    """
    try:
        print("๐Ÿ”น Starting training subprocess...")
        # A piped Python child is block-buffered by default, so its log would
        # arrive in bursts. Force line-by-line flushing and UTF-8 output so the
        # stream is live and decodes regardless of the server's locale.
        child_env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}
        # sys.executable rather than a bare "python": run the trainer with the
        # same interpreter (and virtualenv) as this server.
        # NOTE(review): "train2.py" is resolved against the current working
        # directory, not this file's directory.
        with subprocess.Popen(
            [sys.executable, "train2.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=child_env,
            encoding="utf-8",
            errors="replace",
        ) as process:
            # The with-block closes the pipe and reaps the child even if
            # streaming is interrupted by an exception.
            for line in process.stdout:
                # rstrip keeps the indentation of tracebacks/progress lines.
                print("TRAIN_LOG:", line.rstrip())
            return_code = process.wait()

        if return_code != 0:
            print(f"โŒ Training failed with code {return_code}")
        elif not os.path.exists(MODEL_DIR):
            # Previously reported as "failed with code 0", which was misleading.
            print(f"Training exited cleanly but {MODEL_DIR} does not exist; model not reloaded")
        else:
            load_model()
            print("โœ… Training complete and model reloaded")
    except Exception as e:
        print("โš ๏ธ Training exception:", str(e))

# ----------------------------
# FASTAPI STARTUP
# ----------------------------
@app.on_event("startup")
def startup_event():
    """Load a previously trained model at boot if ``MODEL_DIR`` exists."""
    if not os.path.exists(MODEL_DIR):
        print("โš ๏ธ MODEL_DIR not found, skipping model load")
        return
    load_model()

# ----------------------------
# ENDPOINTS
# ----------------------------
@app.post("/load-data")
async def load_data(file: UploadFile = File(...)):
    """
    Upload a ZIP file and extract it into ``DATA_DIR`` for training.

    The archive is opened, which validates its central directory, before the
    existing ``DATA_DIR`` is deleted. An upload that is not a ZIP therefore
    leaves the current dataset untouched and gets an ``{"error": ...}``
    response instead of an unhandled 500.

    NOTE(review): ``ZipFile.extractall`` strips absolute paths and ``..``
    components (per the zipfile docs), but nothing here bounds the
    uncompressed size of the archive.
    """
    print("๐Ÿ”น Received dataset ZIP upload...")
    # The upload is read fully into memory anyway, so parse it from a BytesIO.
    # This avoids a fixed "dataset.zip" in the CWD, which concurrent uploads
    # could clobber and which was left behind whenever extraction raised.
    payload = io.BytesIO(await file.read())
    try:
        zip_ref = zipfile.ZipFile(payload, "r")
    except zipfile.BadZipFile as e:
        return {"error": f"Invalid ZIP file: {e}"}

    with zip_ref:
        # Replace the old dataset only now that the upload is known to be a ZIP.
        if os.path.exists(DATA_DIR):
            shutil.rmtree(DATA_DIR)
        os.makedirs(DATA_DIR, exist_ok=True)
        zip_ref.extractall(DATA_DIR)
    print(f"โœ… Dataset extracted to {DATA_DIR}")

    return {"status": "Dataset uploaded and extracted"}

@app.post("/train")
async def train_endpoint(background_tasks: BackgroundTasks):
    """Queue ``run_training`` as a background task and respond immediately.

    NOTE(review): nothing here prevents a second POST from starting another
    training run while one is still in progress.
    """
    background_tasks.add_task(run_training)
    status = "Training started in background"
    return {"status": status}

def _classify(image_processor, classifier, img):
    """Run *classifier* on one PIL image; return ``(argmax id, softmax probs)``."""
    inputs = image_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = classifier(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].tolist()
    # argmax over logits (not probs) so ties resolve exactly as before.
    pred_id = int(torch.argmax(logits, dim=1).item())
    return pred_id, probs


def _label_for(idx, classes, classifier):
    """Return the name of class *idx*, falling back to the model's config.

    ``classes`` is either the hard-coded default or ``labels.json`` and may not
    match the loaded model's head size. In that case the model's own
    ``config.id2label`` (or the bare index as a string) is used, instead of
    failing the request with an IndexError.
    """
    try:
        return classes[idx]
    except (IndexError, KeyError, TypeError):
        return classifier.config.id2label.get(idx, str(idx))


@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded retina image.

    Returns ``{"class_id": <label name>, "probabilities": {label: p, ...}}``,
    or ``{"error": ...}`` when no model is loaded or the upload is not a
    readable image.
    """
    # Snapshot the globals once: load_model() can rebind them from the
    # training background task while this request is in flight.
    current_processor, current_model, classes = processor, model, CLASSES
    if current_model is None or current_processor is None:
        return {"error": "Model not trained yet"}
    try:
        img = Image.open(file.file).convert("RGB")
    except Exception as e:
        return {"error": f"Invalid image: {str(e)}"}
    # Inference is compute-bound. Run it in the default executor so it does
    # not block the event loop (and every other request) meanwhile.
    loop = asyncio.get_running_loop()
    pred_id, probs = await loop.run_in_executor(
        None, _classify, current_processor, current_model, img
    )
    return {
        "class_id": _label_for(pred_id, classes, current_model),
        "probabilities": {
            _label_for(i, classes, current_model): float(p)
            for i, p in enumerate(probs)
        },
    }