Habeeb Okunade committed on
Commit
05c5199
·
1 Parent(s): f119f72

Update the training script

Browse files
Files changed (2) hide show
  1. app.py +54 -15
  2. train2.py +10 -10
app.py CHANGED
@@ -1,19 +1,21 @@
1
- # app.py
2
- import os, json, subprocess
3
- from fastapi import BackgroundTasks, FastAPI, UploadFile
4
  from transformers import AutoImageProcessor, BeitForImageClassification
5
  from PIL import Image
6
  import torch
7
 
8
  MODEL_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
 
9
  CLASSES = ["AMD","DMO","DR","GLC","HR","Normal"]
10
 
11
  app = FastAPI(title="Retina Disease Classifier")
12
 
13
- # Lazy load model & processor
14
  processor = None
15
  model = None
16
 
 
 
 
17
  def load_model():
18
  global processor, model, CLASSES
19
  try:
@@ -28,21 +30,34 @@ def load_model():
28
  processor, model = None, None
29
  print(f"โš ๏ธ Skipping model load: {e}")
30
 
 
 
 
31
  def run_training():
32
  try:
33
- result = subprocess.run(
 
34
  ["python", "train2.py"],
35
- capture_output=True,
36
- text=True
 
37
  )
38
- if result.returncode == 0 and os.path.exists(MODEL_DIR):
 
 
 
 
 
39
  load_model()
40
  print("โœ… Training complete and model reloaded")
41
  else:
42
- print("โŒ Training failed:", result.stderr)
43
  except Exception as e:
44
  print("โš ๏ธ Training exception:", str(e))
45
 
 
 
 
46
  @app.on_event("startup")
47
  def startup_event():
48
  if os.path.exists(MODEL_DIR):
@@ -50,6 +65,36 @@ def startup_event():
50
  else:
51
  print("โš ๏ธ MODEL_DIR not found, skipping model load")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @app.post("/predict")
54
  async def predict(file: UploadFile):
55
  if model is None:
@@ -67,9 +112,3 @@ async def predict(file: UploadFile):
67
  "class_id": CLASSES[pred_id],
68
  "probabilities": {CLASSES[i]: float(p) for i, p in enumerate(probs)}
69
  }
70
-
71
- @app.post("/train")
72
- async def train_endpoint(background_tasks: BackgroundTasks):
73
- # Schedule the training in the background
74
- background_tasks.add_task(run_training)
75
- return {"status": "Training started in background"}
 
1
+ import os, json, subprocess, shutil, zipfile
2
+ from fastapi import BackgroundTasks, FastAPI, UploadFile, File
 
3
  from transformers import AutoImageProcessor, BeitForImageClassification
4
  from PIL import Image
5
  import torch
6
 
7
  MODEL_DIR = os.environ.get("OUTPUT_DIR", "/home/user/outputs/beit-retina")
8
+ DATA_DIR = os.environ.get("DATA_DIR", "data2")
9
  CLASSES = ["AMD","DMO","DR","GLC","HR","Normal"]
10
 
11
  app = FastAPI(title="Retina Disease Classifier")
12
 
 
13
  processor = None
14
  model = None
15
 
16
+ # ----------------------------
17
+ # MODEL LOADING
18
+ # ----------------------------
19
  def load_model():
20
  global processor, model, CLASSES
21
  try:
 
30
  processor, model = None, None
31
  print(f"โš ๏ธ Skipping model load: {e}")
32
 
33
+ # ----------------------------
34
+ # BACKGROUND TRAINING
35
+ # ----------------------------
36
def run_training():
    """Run train2.py as a subprocess, streaming its logs, then reload the model.

    Intended to be scheduled as a FastAPI background task.  On a zero exit
    code (and an existing MODEL_DIR) the freshly trained weights are loaded
    into the module-level globals via load_model().  All failures are
    logged rather than raised, since nothing awaits a background task.
    """
    try:
        print("🔹 Starting training subprocess...")
        # stderr is merged into stdout so a single stream carries all logs.
        # The context manager guarantees the pipe is closed and the child
        # reaped even if an exception is raised while streaming output —
        # a bare Popen would leak the pipe / leave a zombie process.
        with subprocess.Popen(
            ["python", "train2.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        ) as process:
            # Stream output line by line so training progress shows up in
            # the server logs instead of buffering until completion.
            for line in iter(process.stdout.readline, ""):
                print("TRAIN_LOG:", line.strip())
        # __exit__ has waited on the child; returncode is now populated.
        return_code = process.returncode

        if return_code == 0 and os.path.exists(MODEL_DIR):
            load_model()
            print("✅ Training complete and model reloaded")
        else:
            print(f"❌ Training failed with code {return_code}")
    except Exception as e:
        # Best-effort background job: log and swallow so the server stays up.
        print("⚠️ Training exception:", str(e))
57
 
58
+ # ----------------------------
59
+ # FASTAPI STARTUP
60
+ # ----------------------------
61
  @app.on_event("startup")
62
  def startup_event():
63
  if os.path.exists(MODEL_DIR):
 
65
  else:
66
  print("โš ๏ธ MODEL_DIR not found, skipping model load")
67
 
68
+ # ----------------------------
69
+ # ENDPOINTS
70
+ # ----------------------------
71
@app.post("/load-data")
async def load_data(file: UploadFile = File(...)):
    """
    Upload a ZIP file and extract it into DATA_DIR for training.

    Any previously extracted dataset is wiped first so a stale dataset can
    never be mixed with a new upload.  Returns a small status dict; a
    corrupt/non-zip upload raises zipfile.BadZipFile (surfaced as a 500).
    """
    print("🔹 Received dataset ZIP upload...")
    # Start from a clean slate: remove any previously extracted dataset.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    zip_path = "dataset.zip"
    try:
        # Spool the upload to disk before extraction.
        # NOTE(review): this reads the whole upload into memory; fine for
        # modest datasets, consider chunked copy for very large ones.
        with open(zip_path, "wb") as f:
            f.write(await file.read())
        print(f" ↪ Saved ZIP to {zip_path}")

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(DATA_DIR)
        print(f"✅ Dataset extracted to {DATA_DIR}")
    finally:
        # Always remove the temporary archive — previously it was left on
        # disk whenever extraction raised (e.g. BadZipFile).
        if os.path.exists(zip_path):
            os.remove(zip_path)

    return {"status": "Dataset uploaded and extracted"}
92
+
93
@app.post("/train")
async def train_endpoint(background_tasks: BackgroundTasks):
    """Schedule run_training as a background task and return immediately."""
    # The heavy lifting (subprocess + model reload) happens after the
    # response is sent, so the client never blocks on training.
    background_tasks.add_task(run_training)
    return {"status": "Training started in background"}
97
+
98
  @app.post("/predict")
99
  async def predict(file: UploadFile):
100
  if model is None:
 
112
  "class_id": CLASSES[pred_id],
113
  "probabilities": {CLASSES[i]: float(p) for i, p in enumerate(probs)}
114
  }
 
 
 
 
 
 
train2.py CHANGED
@@ -16,16 +16,17 @@ from PIL import Image
16
  # ----------------------------
17
  MODEL_NAME = "microsoft/beit-base-patch16-224"
18
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
19
- NUM_CLASSES = 6 # retina disease classes
20
 
21
  print(f"๐Ÿ”น OUTPUT_DIR set to: {OUTPUT_DIR}")
 
22
  os.makedirs(OUTPUT_DIR, exist_ok=True)
23
 
24
  # ----------------------------
25
  # LOAD DATASET
26
  # ----------------------------
27
- print("๐Ÿ”น Loading dataset from 'data/' folder...")
28
- dataset = load_dataset("imagefolder", data_dir="data")
29
  print(f"๐Ÿ”น Dataset loaded. Columns: {dataset['train'].column_names}")
30
  print(f"๐Ÿ”น Dataset splits: {list(dataset.keys())}")
31
  print(f"๐Ÿ”น Number of training samples: {len(dataset['train'])}")
@@ -38,29 +39,27 @@ print(f"๐Ÿ”น Loading processor from {MODEL_NAME}...")
38
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
39
 
40
  def transform(example):
41
- # Determine correct image column
42
  image_column = "image" if "image" in example else [c for c in example.keys() if c != "label"][0]
43
-
44
  images = example[image_column]
45
 
46
- # Ensure we always have a list
47
  if not isinstance(images, list):
48
  images = [images]
49
 
50
  processed_images = []
51
  for img in images:
52
  if isinstance(img, str):
 
53
  img = Image.open(img).convert("RGB")
54
  elif isinstance(img, Image.Image):
 
55
  img = img.convert("RGB")
56
  else:
57
  raise ValueError(f"Unknown type for image: {type(img)}")
58
  processed_images.append(img)
59
 
60
- # Convert to tensors (batched)
61
  inputs = processor(images=processed_images, return_tensors="pt")
62
 
63
- # Handle labels
64
  labels = example["label"]
65
  if not isinstance(labels, list):
66
  labels = [labels]
@@ -75,10 +74,10 @@ print("๐Ÿ”น Transform applied successfully.")
75
  # ----------------------------
76
  # MODEL
77
  # ----------------------------
78
- print(f"๐Ÿ”น Loading BEiT model ({MODEL_NAME}) with {NUM_CLASSES} classes...")
79
  model = BeitForImageClassification.from_pretrained(
80
  MODEL_NAME,
81
- num_labels=NUM_CLASSES,
82
  ignore_mismatched_sizes=True
83
  )
84
  print("๐Ÿ”น Model loaded successfully.")
@@ -111,6 +110,7 @@ args = TrainingArguments(
111
  num_train_epochs=5,
112
  weight_decay=0.01,
113
  logging_dir=os.path.join(OUTPUT_DIR, "logs"),
 
114
  push_to_hub=False
115
  )
116
  print("๐Ÿ”น TrainingArguments configured.")
 
16
  # ----------------------------
17
  MODEL_NAME = "microsoft/beit-base-patch16-224"
18
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
19
+ DATA_DIR = os.environ.get("DATA_DIR", "data2") # dynamic dataset path
20
 
21
  print(f"๐Ÿ”น OUTPUT_DIR set to: {OUTPUT_DIR}")
22
+ print(f"๐Ÿ”น DATA_DIR set to: {DATA_DIR}")
23
  os.makedirs(OUTPUT_DIR, exist_ok=True)
24
 
25
  # ----------------------------
26
  # LOAD DATASET
27
  # ----------------------------
28
+ print(f"๐Ÿ”น Loading dataset from '{DATA_DIR}' folder...")
29
+ dataset = load_dataset("imagefolder", data_dir=DATA_DIR)
30
  print(f"๐Ÿ”น Dataset loaded. Columns: {dataset['train'].column_names}")
31
  print(f"๐Ÿ”น Dataset splits: {list(dataset.keys())}")
32
  print(f"๐Ÿ”น Number of training samples: {len(dataset['train'])}")
 
39
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
40
 
41
  def transform(example):
42
+ # Detect image column
43
  image_column = "image" if "image" in example else [c for c in example.keys() if c != "label"][0]
 
44
  images = example[image_column]
45
 
 
46
  if not isinstance(images, list):
47
  images = [images]
48
 
49
  processed_images = []
50
  for img in images:
51
  if isinstance(img, str):
52
+ print(f" โ†ช Opening image from path: {img}")
53
  img = Image.open(img).convert("RGB")
54
  elif isinstance(img, Image.Image):
55
+ print(" โ†ช Using PIL.Image directly")
56
  img = img.convert("RGB")
57
  else:
58
  raise ValueError(f"Unknown type for image: {type(img)}")
59
  processed_images.append(img)
60
 
 
61
  inputs = processor(images=processed_images, return_tensors="pt")
62
 
 
63
  labels = example["label"]
64
  if not isinstance(labels, list):
65
  labels = [labels]
 
74
  # ----------------------------
75
  # MODEL
76
  # ----------------------------
77
+ print(f"๐Ÿ”น Loading BEiT model ({MODEL_NAME}) with {len(dataset['train'].features['label'].names)} classes...")
78
  model = BeitForImageClassification.from_pretrained(
79
  MODEL_NAME,
80
+ num_labels=len(dataset["train"].features["label"].names),
81
  ignore_mismatched_sizes=True
82
  )
83
  print("๐Ÿ”น Model loaded successfully.")
 
110
  num_train_epochs=5,
111
  weight_decay=0.01,
112
  logging_dir=os.path.join(OUTPUT_DIR, "logs"),
113
+ logging_steps=10,
114
  push_to_hub=False
115
  )
116
  print("๐Ÿ”น TrainingArguments configured.")