efnanaladagg committed on
Commit 6f6eb85 · 0 Parent(s)

Clean push

.dockerignore ADDED
@@ -0,0 +1,11 @@
+ # Local env
+ .venv/
+ __pycache__/
+ *.pyc
+
+ # Local data (do not ship)
+ data/
+
+ # OS / IDE
+ .vscode/
+ .DS_Store
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ # Git LFS
+ artifacts/model.pt filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,29 @@
+ # Python virtual env
+ .venv/
+ venv/
+
+ # Python cache
+ __pycache__/
+ *.pyc
+
+ # Local datasets / splits
+ data/
+
+ # OS / IDE
+ .DS_Store
+ Thumbs.db
+ .vscode/
+ .idea/
+
+ # Dataset artifacts / HF datasets shards
+ data/splits/**/*.arrow
+ data/splits/**/dataset_info.json
+ data/splits/**/state.json
+ data/splits/
+
+ # Jupyter / checkpoints
+ .ipynb_checkpoints/
+ .venv/
+ data/splits/.git_old/
+ *.arrow
+
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.11-slim
+
+ # HF Spaces containers run as UID 1000; creating a user avoids permission issues.
+ RUN useradd -m -u 1000 user
+ USER user
+ # As a non-root user, pip places console scripts (e.g. uvicorn) in ~/.local/bin.
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ # Copy only the requirements file first for better Docker layer caching (it lives under src/app in this commit)
+ COPY src/app/requirements.txt /app/requirements.txt
+ RUN pip install --no-cache-dir -r /app/requirements.txt
+
+ # Copy the app code + artifacts
+ COPY src /app/src
+ COPY artifacts /app/artifacts
+ COPY README.md /app/README.md
+
+ # Spaces expect the app to listen on port 7860
+ EXPOSE 7860
+
+ # Ensure python can import "src.*"
+ ENV PYTHONPATH=/app
+
+ CMD ["uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Vehicle Damage Classifier
+ emoji: 📉
+ colorFrom: indigo
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ short_description: MIS453 Midterm Project
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
artifacts/label_names.json ADDED
@@ -0,0 +1,8 @@
+ [
+   "F_Breakage",
+   "F_Crushed",
+   "F_Normal",
+   "R_Breakage",
+   "R_Crushed",
+   "R_Normal"
+ ]
artifacts/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64fd1fb63cb60fb27e8347308536d2d39df25e40b621d0e3a20699d81670d8ba
+ size 44789387
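
The three lines above are a Git LFS pointer, not the model weights themselves: the repository stores only this small text stub, and the real `model.pt` is fetched by LFS. As a minimal sketch of how such a pointer can be read, the hypothetical helper `parse_lfs_pointer` below (not part of this repository) splits each line into a key and a value and extracts the hash method and size.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the key/value lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        # Each pointer line is "<key> <value>", separated by a single space.
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid value has the form "<hash-method>:<hex digest>".
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


# Pointer content copied from artifacts/model.pt in this commit.
POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:64fd1fb63cb60fb27e8347308536d2d39df25e40b621d0e3a20699d81670d8ba
size 44789387
"""

info = parse_lfs_pointer(POINTER)
print(f"{info['hash_algo']} digest, {info['size_bytes'] / 1e6:.1f} MB")
```

If `torch.load` ever fails on a file of only a few hundred bytes, a check like this can confirm that the checkout contains the pointer rather than the downloaded weights.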
src/app/README.md ADDED
@@ -0,0 +1,82 @@
+ # Vehicle Damage Classification Using Computer Vision
+
+ This project was developed as part of the **MIS453 Midterm Project**.
+ The goal is to build an **end-to-end machine learning application** that classifies vehicle damage images into predefined categories and serves predictions via a backend API.
+
+ ---
+
+ ## 📌 Problem Definition
+
+ Given a single RGB image of a vehicle, the system predicts **one damage class** among the following:
+
+ - F_Normal
+ - F_Breakage
+ - F_Crushed
+ - R_Normal
+ - R_Breakage
+ - R_Crushed
+
+ The task is strictly **multi-class image classification**.
+
+ ---
+
+ ## 📊 Dataset
+
+ - **Source:** DrBimmer / Comprehensive Car Damage (Hugging Face)
+ - **Total samples:** 2300 images
+ - **Split:** Stratified Train / Validation (80% / 20%)
+ - **Classes:** 6 (verified programmatically)
+
+ The dataset was visually inspected and class distributions were analyzed before model training.
+
+ ---
+
+ ## 🧠 Model & Training
+
+ - **Architecture:** ResNet18
+ - **Training mode:** Offline (no pretrained weight download)
+ - **Loss:** CrossEntropyLoss with class weighting
+ - **Input size:** 224 × 224
+ - **Device:** CPU
+ - **Output:** Trained model artifact saved to disk
+
+ Artifacts generated by training and evaluation:
+ - `artifacts/model.pt`
+ - `artifacts/label_names.json`
+ - `artifacts/confusion_matrix.pt`
+
+ ---
+
+ ## 📈 Evaluation
+
+ - **Validation Accuracy:** ~0.57
+ - **Metrics:** Confusion matrix + class-wise Precision / Recall / F1
+ - Performance is stronger on **Front** classes; recall is lower on **Rear** damage types, which is expected given their visual similarity and the class imbalance.
+
+ The goal of this project is **not accuracy maximization**, but demonstrating a **correct and reproducible ML pipeline**.
+
+ ---
+
+ ## 🚀 Backend API (FastAPI)
+
+ A FastAPI backend serves predictions using the trained model artifact.
+
+ ### Available Endpoints
+
+ - `GET /health`
+   Returns API status and class information.
+
+ - `POST /predict`
+   Accepts an image file and returns:
+   - predicted class
+   - confidence score
+   - top-3 predictions
+
+ ---
+
+ ## ▶️ Local Setup & Run
+
+ ### 1. Install dependencies
+ ```bash
+ pip install -r requirements.txt
+ ```
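
The six class names in the README combine a view prefix with a damage type. The sketch below illustrates that structure, assuming `F` and `R` stand for the Front and Rear groups mentioned in the Evaluation section. `split_label` and `VIEW_NAMES` are hypothetical names introduced for illustration; the project code treats the labels as opaque strings.

```python
from typing import Tuple

# Assumed mapping, inferred from the README's "Front" / "Rear" wording.
VIEW_NAMES = {"F": "Front", "R": "Rear"}

LABELS = ["F_Breakage", "F_Crushed", "F_Normal",
          "R_Breakage", "R_Crushed", "R_Normal"]


def split_label(label: str) -> Tuple[str, str]:
    """Split a class name such as 'F_Breakage' into (view, damage type)."""
    prefix, _, damage = label.partition("_")
    return VIEW_NAMES[prefix], damage


for name in LABELS:
    view, damage = split_label(name)
    print(f"{name:<11} -> view={view:<5} damage={damage}")
```

Grouping predictions by view in this way is one possible route to checking the Front versus Rear gap that the Evaluation section reports.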
src/app/inference.py ADDED
@@ -0,0 +1,109 @@
+ # src/app/inference.py
+ import json
+ from pathlib import Path
+ from typing import List, Tuple, Dict, Any
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from PIL import Image
+
+
+ # --- Paths (relative to project root) ---
+ ARTIFACTS_DIR = Path("artifacts")
+ CKPT_PATH = ARTIFACTS_DIR / "model.pt"
+ LABELS_PATH = ARTIFACTS_DIR / "label_names.json"
+
+ IMG_SIZE = 224
+
+
+ def get_device() -> torch.device:
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def load_label_names() -> List[str]:
+     if not LABELS_PATH.exists():
+         raise FileNotFoundError(f"Missing {LABELS_PATH}. Run training first to create artifacts.")
+     return json.loads(LABELS_PATH.read_text(encoding="utf-8"))
+
+
+ def build_model(num_classes: int) -> nn.Module:
+     # OFFLINE SAFE: no pretrained downloads
+     model = models.resnet18(weights=None)
+     model.fc = nn.Linear(model.fc.in_features, num_classes)
+     return model
+
+
+ def load_model() -> Tuple[nn.Module, List[str], torch.device]:
+     """
+     Loads the trained model artifact and label names once.
+     Returns (model, label_names, device).
+     """
+     if not CKPT_PATH.exists():
+         raise FileNotFoundError(f"Missing {CKPT_PATH}. Train and save model first.")
+
+     label_names = load_label_names()
+     num_classes = len(label_names)
+
+     device = get_device()
+     model = build_model(num_classes)
+
+     ckpt = torch.load(CKPT_PATH, map_location="cpu")
+     model.load_state_dict(ckpt["model_state_dict"])
+
+     model.to(device)
+     model.eval()
+     return model, label_names, device
+
+
+ def get_preprocess() -> transforms.Compose:
+     # Must match training/evaluation preprocessing
+     return transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+
+ @torch.no_grad()
+ def predict_image(
+     model: nn.Module,
+     label_names: List[str],
+     device: torch.device,
+     image: Image.Image,
+     top_k: int = 3
+ ) -> Dict[str, Any]:
+     """
+     Predicts class probabilities for a single PIL image.
+     Returns predicted class, confidence, and top-k list.
+     """
+     tf = get_preprocess()
+     x = tf(image.convert("RGB")).unsqueeze(0).to(device)  # (1,3,H,W)
+
+     logits = model(x)
+     probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu()
+
+     pred_id = int(torch.argmax(probs).item())
+     pred_label = label_names[pred_id]
+     pred_conf = float(probs[pred_id].item())
+
+     k = min(top_k, len(label_names))
+     top = torch.topk(probs, k=k)
+     topk: List[Dict[str, Any]] = []  # each entry holds a str label and a float confidence
+     for score, idx in zip(top.values.tolist(), top.indices.tolist()):
+         topk.append({"label": label_names[int(idx)], "confidence": float(score)})
+
+     # all_probs is sometimes useful for debugging/UI charts
+     all_probs = {label_names[i]: float(probs[i].item()) for i in range(len(label_names))}
+
+     return {
+         "predicted_class": pred_label,
+         "confidence": pred_conf,
+         "top_k": topk,
+         "all_probs": all_probs,
+     }
+ # --- IGNORE ---
+ # This module provides functions to load a trained ResNet18 model,
+ # preprocess images, and perform inference to obtain class predictions
+ # and confidence scores for the "comprehensive-car-damage" dataset.
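
To make the post-processing in `predict_image` easier to follow, here is a dependency-free sketch of the same two steps: softmax over the logits, then a top-k ranking clipped to the number of classes. It uses plain Python lists instead of tensors, and the logit values are made up for illustration, so it mirrors the logic rather than reproducing the module.

```python
import math
from typing import Dict, List, Union

LABELS = ["F_Breakage", "F_Crushed", "F_Normal",
          "R_Breakage", "R_Crushed", "R_Normal"]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: List[float], labels: List[str], k: int = 3) -> List[Dict[str, Union[str, float]]]:
    """Return the k most probable labels, highest first (k is clipped to len(labels))."""
    k = min(k, len(labels))
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [{"label": labels[i], "confidence": probs[i]} for i in order]


logits = [2.0, 0.5, 1.0, -1.0, 0.0, -0.5]  # illustrative values, not model output
probs = softmax(logits)
ranked = top_k(probs, LABELS)
for entry in ranked:
    print(f"{entry['label']:<11} {entry['confidence']:.3f}")
```

The clipping step matters for the same reason as `k = min(top_k, len(label_names))` in the module: asking for more entries than there are classes would otherwise fail.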
src/app/main.py ADDED
@@ -0,0 +1,51 @@
+ # src/app/main.py
+ from io import BytesIO
+ from typing import Any, Dict
+
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from PIL import Image
+
+ from src.app.inference import load_model, predict_image
+
+ app = FastAPI(
+     title="Vehicle Damage Classifier API",
+     version="1.0.0",
+     description="Predicts one of 6 vehicle damage classes from an uploaded image."
+ )
+
+ # Load once at startup (no retraining, no repeated loading per request)
+ MODEL, LABEL_NAMES, DEVICE = load_model()
+
+
+ @app.get("/health")
+ def health() -> Dict[str, Any]:
+     return {
+         "status": "ok",
+         "device": str(DEVICE),
+         "num_classes": len(LABEL_NAMES),
+         "classes": LABEL_NAMES,
+     }
+
+
+ @app.post("/predict")
+ async def predict(file: UploadFile = File(...)) -> Dict[str, Any]:
+     # Basic content-type guard (not perfect, but prevents obvious non-images)
+     if file.content_type is None or not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="Please upload an image file.")
+
+     try:
+         content = await file.read()
+         img = Image.open(BytesIO(content))
+         img.load()  # Image.open is lazy; force decoding so corrupt files fail here with a 400
+     except Exception:
+         raise HTTPException(status_code=400, detail="Invalid image file.")
+
+     result = predict_image(MODEL, LABEL_NAMES, DEVICE, img, top_k=3)
+
+     return {
+         "filename": file.filename,
+         **result,
+     }
+ # --- IGNORE ---
+ # This is the main FastAPI application defining endpoints for health check
+ # and image prediction using a pre-loaded vehicle damage classification model.
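
The content-type guard in `/predict` trusts a header supplied by the client, which the code comment itself calls "not perfect". One possible complement, sketched here as an assumption rather than something the app does, is to look at the first bytes of the upload. `sniff_image_type` and `SIGNATURES` are hypothetical names, and only the JPEG and PNG signatures are covered; decoding with Pillow remains the real validation.

```python
from typing import Optional

# Well-known leading byte sequences ("magic numbers") for two common formats.
SIGNATURES = {
    b"\xff\xd8\xff": "image/jpeg",
    b"\x89PNG\r\n\x1a\n": "image/png",
}


def sniff_image_type(data: bytes) -> Optional[str]:
    """Return a MIME type if the payload starts with a known signature, else None."""
    for magic, mime in SIGNATURES.items():
        if data.startswith(magic):
            return mime
    return None


jpeg_like = b"\xff\xd8\xff\xe0" + b"\x00" * 8   # only the header bytes, not a full image
png_like = b"\x89PNG\r\n\x1a\n" + b"\x00" * 8
print(sniff_image_type(jpeg_like), sniff_image_type(png_like), sniff_image_type(b"hello"))
```

A check like this is cheap because it needs only the first few bytes, but it would reject other formats that Pillow can open, so it fits best as an early hint rather than a gate.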
src/app/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi
+ uvicorn
+ python-multipart
+ pillow
+ torch
+ torchvision
+ datasets
src/step10_demo_request.py ADDED
@@ -0,0 +1,49 @@
+ import sys
+ import requests
+ import mimetypes
+ import os
+
+ def main():
+     if len(sys.argv) < 2:
+         print("Usage: python src/step10_demo_request.py path/to/image.jpg")
+         sys.exit(1)
+
+     img_path = sys.argv[1]
+     url = "http://127.0.0.1:8000/predict"  # local uvicorn default port; the Docker image listens on 7860
+
+     mime_type, _ = mimetypes.guess_type(img_path)
+     if mime_type is None:
+         mime_type = "application/octet-stream"
+
+     print(f"Sending: {img_path} -> {url} (mime: {mime_type})")
+
+     try:
+         with open(img_path, "rb") as f:
+             files = {"file": (os.path.basename(img_path), f, mime_type)}
+             r = requests.post(url, files=files, timeout=60)
+
+         print("Request sent.")
+         print("Status code:", r.status_code)
+         print("Response headers:", dict(r.headers))
+
+         try:
+             print("JSON response:", r.json())
+         except ValueError:
+             print("Non-JSON response text:", r.text)
+
+         # Optional: raise for non-2xx to get exception trace if desired
+         # r.raise_for_status()
+
+     except requests.exceptions.RequestException as e:
+         print("Request failed:", repr(e))
+     except FileNotFoundError:
+         print("File not found:", img_path)
+     except Exception as e:
+         print("Unexpected error:", repr(e))
+
+ if __name__ == "__main__":
+     main()
+ # --- IGNORE ---
+ # This script demonstrates how to send an image file to the FastAPI
+ # prediction endpoint and print the response containing predicted
+ # vehicle damage classes and confidence scores.
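
The demo script's MIME guess interacts with the server's content-type guard: a file whose extension is unknown falls back to `application/octet-stream`, which the `/predict` endpoint rejects because it does not start with `image/`. The sketch below reproduces both sides with the standard library only. `guess_upload_mime` and `passes_server_guard` are illustrative names, not functions from the repository.

```python
import mimetypes


def guess_upload_mime(path: str) -> str:
    """Mirror the demo script: guess from the extension, else fall back to octet-stream."""
    mime, _ = mimetypes.guess_type(path)
    return mime or "application/octet-stream"


def passes_server_guard(mime: str) -> bool:
    """Mirror the check in /predict: only content types starting with 'image/' pass."""
    return mime.startswith("image/")


for path in ["car.jpg", "car.png", "car"]:
    mime = guess_upload_mime(path)
    print(f"{path:<8} -> {mime:<26} accepted={passes_server_guard(mime)}")
```

In practice this means an image saved without an extension would get a 400 response from the API even though Pillow could decode it.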
src/step2_show_sample.py ADDED
@@ -0,0 +1,23 @@
+ from datasets import load_dataset
+ import matplotlib.pyplot as plt
+
+ def main():
+     ds = load_dataset("DrBimmer/comprehensive-car-damage")
+     sample = ds["train"][0]
+
+     image = sample["image"]
+     label = sample["label"]
+
+     print("Label ID:", label)
+     print("Label Name:", ds["train"].features["label"].names[label])
+
+     plt.imshow(image)
+     plt.axis("off")
+     plt.title(ds["train"].features["label"].names[label])
+     plt.show()
+
+ if __name__ == "__main__":
+     main()
+ # This script loads a sample from the "comprehensive-car-damage" dataset,
+ # prints its label ID and name, and displays the image using matplotlib.
+
src/step3_show_all_classes.py ADDED
@@ -0,0 +1,36 @@
+ from datasets import load_dataset
+ import matplotlib.pyplot as plt
+
+ def main():
+     ds = load_dataset("DrBimmer/comprehensive-car-damage")
+     train_ds = ds["train"]
+     label_names = train_ds.features["label"].names
+
+     shown = set()
+     images = []
+     titles = []
+
+     for sample in train_ds:
+         label = sample["label"]
+         if label not in shown:
+             images.append(sample["image"])
+             titles.append(label_names[label])
+             shown.add(label)
+         if len(shown) == len(label_names):
+             break
+
+     plt.figure(figsize=(12, 8))
+     for i, (img, title) in enumerate(zip(images, titles)):
+         plt.subplot(2, 3, i + 1)
+         plt.imshow(img)
+         plt.title(title)
+         plt.axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
+ if __name__ == "__main__":
+     main()
+ # This script loads the "comprehensive-car-damage" dataset,
+ # iterates through the training set to find and display one image for each damage class
+ # using matplotlib in a grid layout.
src/step4_make_splits.py ADDED
@@ -0,0 +1,45 @@
+ from collections import Counter
+ from datasets import load_dataset, DatasetDict
+ from pathlib import Path
+
+ SEED = 42
+ TEST_SIZE = 0.2
+ OUT_DIR = Path("data/splits/comprehensive-car-damage_seed42_test0p2")
+
+ def dist(ds_split):
+     c = Counter(ds_split["label"])
+     names = ds_split.features["label"].names
+     total = len(ds_split)
+     rows = []
+     for k in range(len(names)):
+         v = c.get(k, 0)
+         rows.append((names[k], v, v/total if total else 0))
+     return rows
+
+ def print_dist(title, ds_split):
+     print(f"\n{title} (n={len(ds_split)})")
+     for name, v, p in dist(ds_split):
+         print(f"- {name:<10}: {v:>4} ({p*100:>5.1f}%)")
+
+ def main():
+     ds = load_dataset("DrBimmer/comprehensive-car-damage")
+     train = ds["train"]
+
+     split = train.train_test_split(
+         test_size=TEST_SIZE,
+         seed=SEED,
+         stratify_by_column="label"
+     )
+
+     # Rename for clarity: test -> val
+     splits = DatasetDict({"train": split["train"], "val": split["test"]})
+
+     print_dist("TRAIN", splits["train"])
+     print_dist("VAL", splits["val"])
+
+     OUT_DIR.mkdir(parents=True, exist_ok=True)
+     splits.save_to_disk(str(OUT_DIR))
+     print(f"\nSaved splits to: {OUT_DIR}")
+
+ if __name__ == "__main__":
+     main()
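
The script relies on `stratify_by_column="label"` so that the train and validation sets keep the same class proportions. As a conceptual sketch only (this is not how the `datasets` library implements it), a stratified split can be written by shuffling each class separately and taking the same fraction from every class. The function name `stratified_split` and the toy label counts are assumptions for illustration.

```python
import random
from collections import Counter, defaultdict
from typing import List, Sequence, Tuple


def stratified_split(labels: Sequence[int], test_size: float, seed: int) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) with per-class proportions preserved."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        n_val = round(len(members) * test_size)  # same fraction taken from every class
        val_idx.extend(members[:n_val])
        train_idx.extend(members[n_val:])
    return sorted(train_idx), sorted(val_idx)


labels = [0] * 50 + [1] * 30 + [2] * 20  # toy, imbalanced label list
train_idx, val_idx = stratified_split(labels, test_size=0.2, seed=42)
print("train:", Counter(labels[i] for i in train_idx))
print("val:  ", Counter(labels[i] for i in val_idx))
```

Fixing the seed and saving the result to disk, as the script does with `SEED = 42` and `save_to_disk`, is what lets later steps evaluate on exactly the same validation images.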
src/step5_verify_load_splits.py ADDED
@@ -0,0 +1,15 @@
+ from datasets import load_from_disk
+
+ SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
+
+ def main():
+     splits = load_from_disk(SPLIT_DIR)
+     print("Loaded keys:", list(splits.keys()))
+     print("train:", len(splits["train"]), "val:", len(splits["val"]))
+
+     # sanity: label names
+     names = splits["train"].features["label"].names
+     print("Label names:", names)
+
+ if __name__ == "__main__":
+     main()
src/step6_dataloaders.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from datasets import load_from_disk
+ from collections import Counter
+
+ SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
+ BATCH_SIZE = 16
+ NUM_WORKERS = 0  # For Windows compatibility; set higher for Linux/Mac
+ IMG_SIZE = 224
+
+ def compute_class_weights(labels, num_classes):
+     c = Counter(labels)
+     total = len(labels)
+     # simple inverse frequency weighting
+     weights = []
+     for k in range(num_classes):
+         freq = c.get(k, 1) / total
+         weights.append(1.0 / freq)
+     w = torch.tensor(weights, dtype=torch.float)
+     # normalize (optional)
+     w = w / w.mean()
+     return w
+
+ def main():
+     splits = load_from_disk(SPLIT_DIR)
+     train_ds = splits["train"]
+     val_ds = splits["val"]
+
+     label_names = train_ds.features["label"].names
+     num_classes = len(label_names)
+     print("Classes:", label_names)
+
+     train_tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.RandomHorizontalFlip(p=0.5),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     val_tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     def transform_batch(examples, tf):
+         images = [tf(img.convert("RGB")) for img in examples["image"]]
+         labels = torch.tensor(examples["label"], dtype=torch.long)
+         return {"pixel_values": torch.stack(images), "labels": labels}
+
+     def collate_train(batch):
+         # batch: list of dicts from HF dataset rows
+         imgs = [row["image"] for row in batch]
+         labels = [row["label"] for row in batch]
+         return transform_batch({"image": imgs, "label": labels}, train_tf)
+
+     def collate_val(batch):
+         imgs = [row["image"] for row in batch]
+         labels = [row["label"] for row in batch]
+         return transform_batch({"image": imgs, "label": labels}, val_tf)
+
+     train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
+                               num_workers=NUM_WORKERS, collate_fn=collate_train)
+     val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
+                             num_workers=NUM_WORKERS, collate_fn=collate_val)
+
+     # sanity check: one batch
+     batch = next(iter(train_loader))
+     print("Batch keys:", batch.keys())
+     print("pixel_values shape:", batch["pixel_values"].shape)  # (B, C, H, W)
+     print("labels shape:", batch["labels"].shape)
+     print("labels sample:", batch["labels"][:8].tolist())
+     print("labels sample names:", [label_names[i] for i in batch["labels"][:8].tolist()])
+
+     # class weights (train)
+     w = compute_class_weights(train_ds["label"], num_classes)
+     print("Class weights:", {label_names[i]: float(w[i]) for i in range(num_classes)})
+
+ if __name__ == "__main__":
+     main()
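
A worked example may help show what `compute_class_weights` produces. The sketch below repeats the same arithmetic with plain Python floats instead of a tensor, on a toy label list with counts 60 / 30 / 10 (made-up numbers, not the dataset's real distribution). The function name `inverse_frequency_weights` is introduced here for illustration.

```python
from collections import Counter
from typing import List


def inverse_frequency_weights(labels: List[int], num_classes: int) -> List[float]:
    """Weight each class by 1 / relative frequency, then rescale so the mean weight is 1."""
    counts = Counter(labels)
    total = len(labels)
    # A class missing from the labels is treated as if it had one sample,
    # matching the c.get(k, 1) default in the script and avoiding division by zero.
    raw = [1.0 / (counts.get(k, 1) / total) for k in range(num_classes)]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]


labels = [0] * 60 + [1] * 30 + [2] * 10  # toy counts
weights = inverse_frequency_weights(labels, num_classes=3)
# Raw weights are 100/60, 100/30 and 100/10; their mean is 5, so the
# normalized weights come out at roughly 0.33, 0.67 and 2.0.
print([round(w, 3) for w in weights])
```

The rarest class ends up with the largest weight, so its errors contribute more to `CrossEntropyLoss`, while dividing by the mean keeps the overall loss scale comparable to the unweighted case.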
src/step7_train_resnet18.py ADDED
@@ -0,0 +1,194 @@
+ import json
+ from pathlib import Path
+ from collections import Counter
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torchvision import transforms, models
+ from datasets import load_from_disk
+
+ # --- Fixed inputs for reproducibility and consistent artifact paths ---
+ SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
+ ART_DIR = Path("artifacts")
+ ART_DIR.mkdir(parents=True, exist_ok=True)
+
+ IMG_SIZE = 224
+ BATCH_SIZE = 16
+ NUM_WORKERS = 0  # Windows-safe default (avoid multiprocessing issues)
+ EPOCHS = 8
+ LR = 3e-4
+ SEED = 42
+
+ def set_seed(seed: int):
+     # Seeds PyTorch's RNGs so shuffling and weight initialization repeat across runs
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def compute_class_weights(labels, num_classes):
+     # Inverse-frequency weights to reduce bias toward majority classes
+     c = Counter(labels)
+     total = len(labels)
+     weights = []
+     for k in range(num_classes):
+         freq = c.get(k, 1) / total
+         weights.append(1.0 / freq)
+     w = torch.tensor(weights, dtype=torch.float)
+     # Normalize weights so average weight ≈ 1 (stable loss scale)
+     w = w / w.mean()
+     return w
+
+ def accuracy(logits, labels):
+     # Simple top-1 accuracy
+     preds = logits.argmax(dim=1)
+     return (preds == labels).float().mean().item()
+
+ def main():
+     set_seed(SEED)
+
+     # --- Device selection (CPU is fine; CUDA if available) ---
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print("Device:", device)
+
+     # --- Load the *saved* stratified splits (do NOT re-split each run) ---
+     splits = load_from_disk(SPLIT_DIR)
+     train_ds = splits["train"]
+     val_ds = splits["val"]
+
+     # --- Label metadata (source of truth for class order) ---
+     label_names = train_ds.features["label"].names
+     num_classes = len(label_names)
+     print("Classes:", label_names)
+
+     # Save label map alongside the model artifact (needed for inference/API)
+     with open(ART_DIR / "label_names.json", "w", encoding="utf-8") as f:
+         json.dump(label_names, f, ensure_ascii=False, indent=2)
+
+     # --- Image preprocessing ---
+     # Train: small augmentation (flip) to improve generalization
+     # Val: deterministic transforms only
+     train_tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.RandomHorizontalFlip(p=0.5),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     val_tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     # --- HF Dataset -> PyTorch batch conversion (collate_fn) ---
+     # We apply torchvision transforms inside collate_fn because HF stores PIL Images.
+     def transform_batch(examples, tf):
+         images = [tf(img.convert("RGB")) for img in examples["image"]]
+         labels = torch.tensor(examples["label"], dtype=torch.long)
+         return {"pixel_values": torch.stack(images), "labels": labels}
+
+     def collate_train(batch):
+         imgs = [row["image"] for row in batch]
+         labels = [row["label"] for row in batch]
+         return transform_batch({"image": imgs, "label": labels}, train_tf)
+
+     def collate_val(batch):
+         imgs = [row["image"] for row in batch]
+         labels = [row["label"] for row in batch]
+         return transform_batch({"image": imgs, "label": labels}, val_tf)
+
+     # --- DataLoaders (train shuffled, val not shuffled) ---
+     train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
+                               num_workers=NUM_WORKERS, collate_fn=collate_train)
+     val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
+                             num_workers=NUM_WORKERS, collate_fn=collate_val)
+
+     # --- Model (transfer learning is optional) ---
+     # With USE_PRETRAINED = False the network trains from scratch (no download); True would start from ImageNet weights. The final layer is replaced for num_classes outputs.
+     USE_PRETRAINED = False
+     weights = models.ResNet18_Weights.DEFAULT if USE_PRETRAINED else None
+     model = models.resnet18(weights=weights)
+     in_features = model.fc.in_features
+     model.fc = nn.Linear(in_features, num_classes)
+     model = model.to(device)
+
+     # --- Loss with class weights (handles mild imbalance) ---
+     class_w = compute_class_weights(train_ds["label"], num_classes).to(device)
+     criterion = nn.CrossEntropyLoss(weight=class_w)
+
+     # --- Optimizer ---
+     optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
+
+     # --- Checkpointing: keep the best model by validation accuracy ---
+     best_val_acc = -1.0
+     best_path = ART_DIR / "model.pt"
+
+     for epoch in range(1, EPOCHS + 1):
+         # ===== TRAIN LOOP =====
+         model.train()
+         train_loss = 0.0
+         train_acc = 0.0
+         n_train = 0
+
+         for batch in train_loader:
+             x = batch["pixel_values"].to(device)
+             y = batch["labels"].to(device)
+
+             optimizer.zero_grad(set_to_none=True)
+             logits = model(x)
+             loss = criterion(logits, y)
+             loss.backward()
+             optimizer.step()
+
+             bs = y.size(0)
+             train_loss += loss.item() * bs
+             train_acc += accuracy(logits.detach(), y) * bs
+             n_train += bs
+
+         train_loss /= n_train
+         train_acc /= n_train
+
+         # ===== VALIDATION LOOP =====
+         model.eval()
+         val_loss = 0.0
+         val_acc = 0.0
+         n_val = 0
+
+         with torch.no_grad():
+             for batch in val_loader:
+                 x = batch["pixel_values"].to(device)
+                 y = batch["labels"].to(device)
+
+                 logits = model(x)
+                 loss = criterion(logits, y)
+
+                 bs = y.size(0)
+                 val_loss += loss.item() * bs
+                 val_acc += accuracy(logits, y) * bs
+                 n_val += bs
+
+         val_loss /= n_val
+         val_acc /= n_val
+
+         print(f"Epoch {epoch:02d}/{EPOCHS} | "
+               f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
+               f"val loss {val_loss:.4f} acc {val_acc:.4f}")
+
+         # Save best checkpoint
+         if val_acc > best_val_acc:
+             best_val_acc = val_acc
+             torch.save({
+                 "model_state_dict": model.state_dict(),
+                 "label_names": label_names,
+                 "img_size": IMG_SIZE,
+                 "arch": "resnet18",
+             }, best_path)
+             print(f"  -> saved best to {best_path} (val_acc={best_val_acc:.4f})")
+
+     print("\nTraining complete.")
+     print("Best val acc:", best_val_acc)
+
+ if __name__ == "__main__":
+     main()
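
The training and validation loops accumulate `metric * bs` and divide by the total sample count instead of averaging per-batch values directly. The sketch below shows why that matters when the last batch is smaller than the others. The batch accuracies and sizes are made-up numbers, and `epoch_mean` is a hypothetical helper name.

```python
from typing import List, Tuple


def epoch_mean(batch_stats: List[Tuple[float, int]]) -> float:
    """Sample-weighted mean of per-batch values given as (value, batch_size) pairs."""
    total = sum(value * size for value, size in batch_stats)
    n = sum(size for _, size in batch_stats)
    return total / n


# Two full batches of 16 at 50% accuracy, then a final partial batch of 4 at 100%.
batches = [(0.50, 16), (0.50, 16), (1.00, 4)]

weighted = epoch_mean(batches)                       # (8 + 8 + 4) correct out of 36 samples
naive = sum(v for v, _ in batches) / len(batches)    # treats the small batch as equal to the others

print(f"weighted={weighted:.4f} naive={naive:.4f}")
```

The weighted value equals the true fraction of correctly classified samples, whereas the naive average lets the four-image batch count as much as a full one.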
src/step8_evaluate.py ADDED
@@ -0,0 +1,116 @@
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torchvision import transforms, models
+ from datasets import load_from_disk
+
+ SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
+ CKPT_PATH = Path("artifacts/model.pt")
+ LABELS_PATH = Path("artifacts/label_names.json")
+
+ IMG_SIZE = 224
+ BATCH_SIZE = 32
+ NUM_WORKERS = 0
+
+ def confusion_matrix_torch(y_true, y_pred, num_classes):
+     cm = torch.zeros((num_classes, num_classes), dtype=torch.int64)
+     for t, p in zip(y_true, y_pred):
+         cm[t, p] += 1
+     return cm
+
+ def precision_recall_f1(cm):
+     # cm rows: true, cols: pred
+     num_classes = cm.size(0)
+     metrics = []
+     for i in range(num_classes):
+         tp = cm[i, i].item()
+         fp = cm[:, i].sum().item() - tp
+         fn = cm[i, :].sum().item() - tp
+
+         prec = tp / (tp + fp) if (tp + fp) else 0.0
+         rec = tp / (tp + fn) if (tp + fn) else 0.0
+         f1 = (2 * prec * rec / (prec + rec)) if (prec + rec) else 0.0
+         metrics.append((prec, rec, f1))
+     return metrics
+
+ def main():
+     # Load label names (source of truth for readable reporting)
+     label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
+     num_classes = len(label_names)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print("Device:", device)
+     print("Classes:", label_names)
+
+     # Load val split from disk
+     splits = load_from_disk(SPLIT_DIR)
+     val_ds = splits["val"]
+
+     # Deterministic val transforms
+     val_tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     def collate_val(batch):
+         imgs = [val_tf(row["image"].convert("RGB")) for row in batch]
+         labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
+         return {"pixel_values": torch.stack(imgs), "labels": labels}
+
+     val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
+                             num_workers=NUM_WORKERS, collate_fn=collate_val)
+
+     # Rebuild model architecture and load checkpoint weights
+     model = models.resnet18(weights=None)
+     in_features = model.fc.in_features
+     model.fc = nn.Linear(in_features, num_classes)
+
+     ckpt = torch.load(CKPT_PATH, map_location="cpu")
+     model.load_state_dict(ckpt["model_state_dict"])
+     model = model.to(device)
+     model.eval()
+
+     y_true_all = []
+     y_pred_all = []
+
+     with torch.no_grad():
+         for batch in val_loader:
+             x = batch["pixel_values"].to(device)
+             y = batch["labels"].to(device)
+
+             logits = model(x)
+             preds = logits.argmax(dim=1)
+
+             y_true_all.append(y.cpu())
+             y_pred_all.append(preds.cpu())
+
+     y_true = torch.cat(y_true_all)
+     y_pred = torch.cat(y_pred_all)
+
+     acc = (y_true == y_pred).float().mean().item()
+     print(f"\nVAL Accuracy: {acc:.4f}")
+
+     cm = confusion_matrix_torch(y_true, y_pred, num_classes)
+     print("\nConfusion Matrix (rows=true, cols=pred):")
+     print(cm)
+
+     metrics = precision_recall_f1(cm)
+     print("\nPer-class metrics:")
+     for i, (prec, rec, f1) in enumerate(metrics):
+         print(f"- {label_names[i]:<10} | P {prec:.3f} | R {rec:.3f} | F1 {f1:.3f}")
+
+     # Save CM for later reporting
+     out_path = Path("artifacts/confusion_matrix.pt")
+     torch.save({"confusion_matrix": cm, "label_names": label_names, "val_acc": acc}, out_path)
+     print(f"\nSaved confusion matrix to: {out_path}")
+
+ if __name__ == "__main__":
+     main()
+ # This script evaluates a trained ResNet18 model on the validation split of the
+ # "comprehensive-car-damage" dataset, computes accuracy, confusion matrix,
+ # precision, recall, and F1-score for each class, and saves the confusion matrix to disk.
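
To make the per-class metric formulas concrete, the sketch below applies the same definitions as `precision_recall_f1` to a small hand-made confusion matrix stored as nested lists. The 3 × 3 matrix is invented for illustration and is not this model's result. Rows are true classes and columns are predictions, as in the script.

```python
from typing import List, Tuple


def per_class_metrics(cm: List[List[int]]) -> List[Tuple[float, float, float]]:
    """Return (precision, recall, f1) per class for a square confusion matrix."""
    n = len(cm)
    out = []
    for i in range(n):
        tp = cm[i][i]
        fp = sum(cm[r][i] for r in range(n)) - tp   # column sum minus the diagonal
        fn = sum(cm[i]) - tp                         # row sum minus the diagonal
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
        out.append((prec, rec, f1))
    return out


cm = [
    [8, 2, 0],   # true class 0: 8 correct, 2 predicted as class 1
    [1, 6, 3],   # true class 1
    [0, 2, 8],   # true class 2
]
metrics = per_class_metrics(cm)
accuracy = sum(cm[i][i] for i in range(len(cm))) / sum(map(sum, cm))

# Class 0: tp=8, fp=1, fn=2, so precision is 8/9 and recall is 8/10.
for i, (p, r, f) in enumerate(metrics):
    print(f"class {i}: P={p:.3f} R={r:.3f} F1={f:.3f}")
print(f"accuracy={accuracy:.3f}")
```

Reading the matrix this way also explains how a class can have decent precision but low recall, which is the pattern the README describes for the Rear damage types.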
src/step9_infer_from_dataset.py ADDED
@@ -0,0 +1,78 @@
+ import json
+ import random
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from datasets import load_from_disk
+
+ SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
+ CKPT_PATH = Path("artifacts/model.pt")
+ LABELS_PATH = Path("artifacts/label_names.json")
+
+ IMG_SIZE = 224
+ SEED = 42
+
+ def softmax_probs(logits: torch.Tensor) -> torch.Tensor:
+     return torch.softmax(logits, dim=1)
+
+ def main():
+     random.seed(SEED)
+     torch.manual_seed(SEED)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
+     num_classes = len(label_names)
+
+     # Load saved splits and pick one random sample from VAL (more meaningful than train)
+     splits = load_from_disk(SPLIT_DIR)
+     val_ds = splits["val"]
+     idx = random.randint(0, len(val_ds) - 1)
+     sample = val_ds[idx]
+
+     true_id = sample["label"]
+     true_name = label_names[true_id]
+
+     tf = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     x = tf(sample["image"].convert("RGB")).unsqueeze(0)  # (1,3,H,W)
+
+     # Rebuild model arch and load weights
+     model = models.resnet18(weights=None)
+     model.fc = nn.Linear(model.fc.in_features, num_classes)
+
+     ckpt = torch.load(CKPT_PATH, map_location="cpu")
+     model.load_state_dict(ckpt["model_state_dict"])
+     model = model.to(device)
+     model.eval()
+
+     with torch.no_grad():
+         logits = model(x.to(device))
+         probs = softmax_probs(logits).cpu().squeeze(0)
+
+     pred_id = int(torch.argmax(probs).item())
+     pred_name = label_names[pred_id]
+     pred_conf = float(probs[pred_id].item())
+
+     # top-3
+     topk = torch.topk(probs, k=3)
+     top3 = [(label_names[int(i)], float(v)) for v, i in zip(topk.values, topk.indices)]
+
+     print(f"Sample index (val): {idx}")
+     print(f"TRUE: {true_name} ({true_id})")
+     print(f"PRED: {pred_name} ({pred_id}) conf={pred_conf:.4f}")
+     print("TOP-3:")
+     for name, p in top3:
+         print(f"- {name:<10} : {p:.4f}")
+
+ if __name__ == "__main__":
+     main()
+ # This script performs inference on a single random sample from the validation split
+ # of the "comprehensive-car-damage" dataset using a trained ResNet18 model,
+ # and prints the true label, predicted label, confidence, and top-3 predictions.