habeebCycle commited on
Commit
96c3348
·
verified ·
1 Parent(s): 5b2aa01

Upload 5 files

Browse files

Adding training files

Files changed (5) hide show
  1. Dockerfile +33 -0
  2. app.py +48 -0
  3. requirements.txt +12 -0
  4. startup.sh +4 -0
  5. train.py +71 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Create non-root user
4
+ RUN adduser --disabled-password --gecos '' user
5
+ USER user
6
+
7
+ # Environment variables
8
+ ENV HOME=/home/user \
9
+ PATH=/home/user/.local/bin:$PATH \
10
+ PORT=7860
11
+
12
+ WORKDIR $HOME/app
13
+
14
+ # Copy requirements first (better for Docker layer caching)
15
+ COPY --chown=user requirements.txt ./
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy the rest of the application
19
+ COPY --chown=user . .
20
+
21
+ # Expose FastAPI default port for Hugging Face Spaces
22
+ EXPOSE 7860
23
+
24
+ # HF auth picked automatically from env (Spaces provides HF_TOKEN)
25
+ ENV HF_HOME=/home/user/.cache/huggingface \
26
+ TRANSFORMERS_CACHE=/home/user/.cache/huggingface/transformers \
27
+ TORCH_HOME=/home/user/.cache/torch
28
+
29
+ RUN mkdir -p $HF_HOME $TRANSFORMERS_CACHE $TORCH_HOME
30
+ RUN chmod +x startup.sh
31
+
32
+ # Start API
33
+ CMD ["bash", "startup.sh"]
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os, json
3
+ from fastapi import FastAPI, UploadFile
4
+ from transformers import AutoImageProcessor, BeitForImageClassification
5
+ from PIL import Image
6
+ import torch
7
+
8
+ MODEL_DIR = "outputs/beit-retina"
9
+ CLASSES = ["AMD","DMO","DR","GLC","HR","Normal"]
10
+
11
+ app = FastAPI(title="Retina Disease Classifier")
12
+
13
+ # Lazy load model & processor
14
+ processor = None
15
+ model = None
16
+
17
+ def load_model():
18
+ global processor, model, CLASSES
19
+ processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
20
+ model = BeitForImageClassification.from_pretrained(MODEL_DIR)
21
+ with open(os.path.join(MODEL_DIR, "labels.json")) as f:
22
+ CLASSES = json.load(f)
23
+
24
+ @app.on_event("startup")
25
+ def startup_event():
26
+ if os.path.exists(MODEL_DIR):
27
+ load_model()
28
+
29
+ @app.post("/predict")
30
+ async def predict(file: UploadFile):
31
+ if model is None:
32
+ return {"error": "Model not trained yet"}
33
+ img = Image.open(file.file).convert("RGB")
34
+ inputs = processor(images=img, return_tensors="pt")
35
+ with torch.no_grad():
36
+ logits = model(**inputs).logits
37
+ probs = torch.softmax(logits, dim=1)[0].tolist()
38
+ pred_id = int(torch.argmax(logits, dim=1).item())
39
+ return {
40
+ "class_id": CLASSES[pred_id],
41
+ "probabilities": [{CLASSES[i]: float(p) for i, p in enumerate(probs)}]
42
+ }
43
+
44
+ @app.post("/train")
45
+ async def train_endpoint():
46
+ os.system("python train.py") # blocking training run
47
+ load_model()
48
+ return {"status": "Training complete and model reloaded"}
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.0+cpu
2
+ torchvision==0.17.0+cpu
3
+ transformers
4
+ datasets
5
+ accelerate
6
+ scikit-learn
7
+ fastapi
8
+ uvicorn[standard]
9
+ pillow
10
+ pydantic==2.8.2
11
+ python-multipart==0.0.9
12
+ huggingface_hub==0.24.6
startup.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ # In HF Spaces with Docker, CUDA is available if a GPU is provisioned.
4
+ exec uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}
train.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ import os, json
3
+ from transformers import AutoImageProcessor, BeitForImageClassification, TrainingArguments, Trainer
4
+ from datasets import load_dataset
5
+ from sklearn.metrics import accuracy_score, f1_score
6
+ import numpy as np
7
+
8
+ CLASSES = ["AMD","DMO","DR","GLC","HR","Normal"]
9
+ MODEL_NAME = "microsoft/beit-base-patch16-224"
10
+
11
+ print("HOME dir:", os.environ.get("HOME"))
12
+ print("HF cache:", os.environ.get("HF_HOME", os.path.join(os.environ["HOME"], ".cache", "huggingface")))
13
+
14
+
15
+ def compute_metrics(eval_pred):
16
+ logits, labels = eval_pred
17
+ preds = np.argmax(logits, axis=1)
18
+ return {
19
+ "accuracy": accuracy_score(labels, preds),
20
+ "f1_weighted": f1_score(labels, preds, average="weighted")
21
+ }
22
+
23
+ def train(output_dir="/outputs/beit-retina", train_dir="data/train", val_dir="data/val", epochs=5, batch_size=16):
24
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
25
+ dataset = load_dataset("imagefolder", data_dir={"train": train_dir, "validation": val_dir})
26
+
27
+ def transform(examples):
28
+ images = [processor(Image.open(p).convert("RGB"), return_tensors="pt")["pixel_values"][0] for p in examples["image"]]
29
+ return {"pixel_values": images}
30
+
31
+ dataset = dataset.cast_column("label", dataset["train"].features["label"].cast(type="ClassLabel", names=CLASSES))
32
+
33
+ model = BeitForImageClassification.from_pretrained(
34
+ MODEL_NAME,
35
+ num_labels=len(CLASSES),
36
+ id2label={i: c for i, c in enumerate(CLASSES)},
37
+ label2id={c: i for i, c in enumerate(CLASSES)}
38
+ )
39
+
40
+ args = TrainingArguments(
41
+ output_dir=output_dir,
42
+ per_device_train_batch_size=batch_size,
43
+ per_device_eval_batch_size=batch_size,
44
+ num_train_epochs=epochs,
45
+ evaluation_strategy="epoch",
46
+ save_strategy="epoch",
47
+ load_best_model_at_end=True,
48
+ metric_for_best_model="f1_weighted",
49
+ logging_steps=50,
50
+ report_to="none"
51
+ )
52
+
53
+ trainer = Trainer(
54
+ model=model,
55
+ args=args,
56
+ train_dataset=dataset["train"],
57
+ eval_dataset=dataset["validation"],
58
+ tokenizer=processor,
59
+ compute_metrics=compute_metrics
60
+ )
61
+
62
+ trainer.train()
63
+ model.save_pretrained(output_dir)
64
+ processor.save_pretrained(output_dir)
65
+
66
+ with open(os.path.join(output_dir, "labels.json"), "w") as f:
67
+ json.dump(CLASSES, f)
68
+ print("✅ Training complete. Model saved at:", output_dir)
69
+
70
+ if __name__ == "__main__":
71
+ train()