Habeeb Okunade commited on
Commit
0e0e505
·
1 Parent(s): cb24c7c

Update Training script

Browse files
Files changed (2) hide show
  1. app.py +21 -6
  2. train2.py +110 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # app.py
2
- import os, json
3
- from fastapi import FastAPI, UploadFile
4
  from transformers import AutoImageProcessor, BeitForImageClassification
5
  from PIL import Image
6
  import torch
@@ -28,6 +28,21 @@ def load_model():
28
  processor, model = None, None
29
  print(f"⚠️ Skipping model load: {e}")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @app.on_event("startup")
32
  def startup_event():
33
  if os.path.exists(MODEL_DIR):
@@ -54,7 +69,7 @@ async def predict(file: UploadFile):
54
  }
55
 
56
  @app.post("/train")
57
- async def train_endpoint():
58
- os.system("python train.py") # blocking training run
59
- load_model()
60
- return {"status": "Training complete and model reloaded"}
 
1
  # app.py
2
+ import os, json, subprocess
3
+ from fastapi import BackgroundTasks, FastAPI, UploadFile
4
  from transformers import AutoImageProcessor, BeitForImageClassification
5
  from PIL import Image
6
  import torch
 
28
  processor, model = None, None
29
  print(f"⚠️ Skipping model load: {e}")
30
 
31
+ def run_training():
32
+ try:
33
+ result = subprocess.run(
34
+ ["python", "train2.py"],
35
+ capture_output=True,
36
+ text=True
37
+ )
38
+ if result.returncode == 0 and os.path.exists(MODEL_DIR):
39
+ load_model()
40
+ print("✅ Training complete and model reloaded")
41
+ else:
42
+ print("❌ Training failed:", result.stderr)
43
+ except Exception as e:
44
+ print("⚠️ Training exception:", str(e))
45
+
46
  @app.on_event("startup")
47
  def startup_event():
48
  if os.path.exists(MODEL_DIR):
 
69
  }
70
 
71
  @app.post("/train")
72
+ async def train_endpoint(background_tasks: BackgroundTasks):
73
+ # Schedule the training in the background
74
+ background_tasks.add_task(run_training)
75
+ return {"status": "Training started in background"}
train2.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoImageProcessor,
7
+ BeitForImageClassification,
8
+ TrainingArguments,
9
+ Trainer
10
+ )
11
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12
+
13
+ # ----------------------------
14
+ # CONFIG
15
+ # ----------------------------
16
+ MODEL_NAME = "microsoft/beit-base-patch16-224"
17
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/outputs/beit-retina"))
18
+ NUM_CLASSES = 6 # retina disease classes
19
+
20
+ # Make sure output directory exists
21
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
22
+
23
+ # ----------------------------
24
+ # LOAD DATASET
25
+ # ----------------------------
26
+ # Example: Replace this with your retina dataset
27
+ # You can load a Hugging Face dataset or your own image folder dataset
28
+ # Dataset format: train/valid/test folders each containing subfolders by class name
29
+ dataset = load_dataset("imagefolder", data_dir="data")
30
+
31
+ # ----------------------------
32
+ # PREPROCESSOR
33
+ # ----------------------------
34
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
35
+
36
+ def transform(example):
37
+ inputs = processor(example["image"], return_tensors="pt")
38
+ inputs["label"] = example["label"]
39
+ return inputs
40
+
41
+ # Map preprocessing
42
+ dataset = dataset.with_transform(transform)
43
+
44
+ # ----------------------------
45
+ # MODEL
46
+ # ----------------------------
47
+ model = BeitForImageClassification.from_pretrained(
48
+ MODEL_NAME,
49
+ num_labels=NUM_CLASSES,
50
+ ignore_mismatched_sizes=True
51
+ )
52
+
53
+ # ----------------------------
54
+ # METRICS
55
+ # ----------------------------
56
+ def compute_metrics(eval_pred):
57
+ logits, labels = eval_pred
58
+ preds = logits.argmax(axis=-1)
59
+ return {
60
+ "accuracy": accuracy_score(labels, preds),
61
+ "precision": precision_score(labels, preds, average="macro"),
62
+ "recall": recall_score(labels, preds, average="macro"),
63
+ "f1": f1_score(labels, preds, average="macro"),
64
+ }
65
+
66
+ # ----------------------------
67
+ # TRAINING ARGS
68
+ # ----------------------------
69
+ args = TrainingArguments(
70
+ output_dir=OUTPUT_DIR,
71
+ evaluation_strategy="epoch",
72
+ save_strategy="epoch",
73
+ learning_rate=5e-5,
74
+ per_device_train_batch_size=16,
75
+ per_device_eval_batch_size=16,
76
+ num_train_epochs=5,
77
+ weight_decay=0.01,
78
+ logging_dir=os.path.join(OUTPUT_DIR, "logs"),
79
+ push_to_hub=False
80
+ )
81
+
82
+ # ----------------------------
83
+ # TRAINER
84
+ # ----------------------------
85
+ trainer = Trainer(
86
+ model=model,
87
+ args=args,
88
+ train_dataset=dataset["train"],
89
+ eval_dataset=dataset["validation"],
90
+ tokenizer=processor,
91
+ compute_metrics=compute_metrics
92
+ )
93
+
94
+ # ----------------------------
95
+ # TRAIN
96
+ # ----------------------------
97
+ trainer.train()
98
+
99
+ # ----------------------------
100
+ # SAVE FINAL MODEL + LABELS
101
+ # ----------------------------
102
+ trainer.save_model(OUTPUT_DIR)
103
+ processor.save_pretrained(OUTPUT_DIR)
104
+
105
+ # Save class labels mapping
106
+ labels = dataset["train"].features["label"].names
107
+ with open(os.path.join(OUTPUT_DIR, "labels.json"), "w") as f:
108
+ json.dump(labels, f)
109
+
110
+ print(f"✅ Model and processor saved to {OUTPUT_DIR}")