Spaces:
Sleeping
Sleeping
Upload 13 files
Browse files- README.md +41 -8
- app.py +27 -88
- data/tata_specs.yaml +65 -0
- requirements.txt +1 -6
- tata_id/__init__.py +1 -0
- tata_id/autofill.py +16 -0
- tata_id/color.py +53 -0
- tata_id/kb.py +10 -0
- tata_id/model.py +25 -0
- tata_id/utils.py +20 -0
- training/train_classifier.py +70 -0
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: "4.36.1"
|
| 8 |
app_file: app.py
|
|
@@ -11,11 +11,44 @@ license: mit
|
|
| 11 |
tags:
|
| 12 |
- automotive
|
| 13 |
- computer-vision
|
| 14 |
-
- nlp
|
| 15 |
- gradio
|
| 16 |
-
-
|
| 17 |
---
|
| 18 |
|
| 19 |
-
#
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Tata Car Identifier (Model & Color)
|
| 3 |
+
emoji: 🚘
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: "4.36.1"
|
| 8 |
app_file: app.py
|
|
|
|
| 11 |
tags:
|
| 12 |
- automotive
|
| 13 |
- computer-vision
|
|
|
|
| 14 |
- gradio
|
| 15 |
+
- tata
|
| 16 |
---
|
| 17 |
|
| 18 |
+
# Tata Car Identifier
|
| 19 |
|
| 20 |
+
An image recognition tool tailored for **Tata** cars that identifies **model**, **color**, and **autofills** extra details (year ranges, engine sizes, features) from a **single uploaded image**.
|
| 21 |
+
|
| 22 |
+
## Quickstart
|
| 23 |
+
```bash
|
| 24 |
+
pip install -r requirements.txt
|
| 25 |
+
python app.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Open the local Gradio URL and upload a Tata car photo.
|
| 29 |
+
|
| 30 |
+
## How it works
|
| 31 |
+
- **Model ID**: Zero-shot CLIP baseline over common Tata models (Nexon, Altroz, Tiago, Punch, Harrier, Safari, Tigor, etc.). Optional fine-tuning script included.
|
| 32 |
+
- **Color**: Dominant body color via KMeans in LAB space with named-color snapping.
|
| 33 |
+
- **Autofill**: Specs pulled from `data/tata_specs.yaml` using the predicted model.
|
| 34 |
+
|
| 35 |
+
## Train on your dataset
|
| 36 |
+
- Put images under `data/your_dataset/images/` and labels in `data/your_dataset/annotations.csv`:
|
| 37 |
+
```csv
|
| 38 |
+
image_path,label
|
| 39 |
+
images/img_001.jpg,Tata Nexon
|
| 40 |
+
```
|
| 41 |
+
- Run:
|
| 42 |
+
```bash
|
| 43 |
+
python training/train_classifier.py --data_root data/your_dataset --annotations data/your_dataset/annotations.csv --out_dir checkpoints/vision
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## FAQ
|
| 47 |
+
**Q: Do I need to train first?**
|
| 48 |
+
A: No. The app ships with a **CLIP zero-shot** baseline that works out-of-the-box. Training improves accuracy.
|
| 49 |
+
|
| 50 |
+
**Q: Which models are supported?**
|
| 51 |
+
A: See `tata_id/kb.py` (MODEL_LIST). You can add more models and update `data/tata_specs.yaml`.
|
| 52 |
+
|
| 53 |
+
**Q: Can it guess year of manufacture?**
|
| 54 |
+
A: We return a **likely year range** per generation. Exact year typically requires VIN/registration lookup.
|
app.py
CHANGED
|
@@ -1,104 +1,43 @@
|
|
| 1 |
-
import os, json
|
| 2 |
-
from typing import
|
| 3 |
-
from PIL import Image
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from car_advisor.cost_estimator import estimate_costs
|
| 10 |
-
from car_advisor.suggestions import predictive_maintenance, advanced_suggestions
|
| 11 |
-
from car_advisor.reporter import export_pdf, export_json
|
| 12 |
-
from car_advisor.scheduler import create_service_ics
|
| 13 |
-
|
| 14 |
-
vision = VisionInference()
|
| 15 |
-
nlp = NLPInference()
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
return Image.open(io.BytesIO(base64.b64decode(obj["image"].split(",")[-1])))
|
| 20 |
-
if isinstance(obj, str):
|
| 21 |
-
return Image.open(obj)
|
| 22 |
-
return obj
|
| 23 |
|
| 24 |
-
def analyze(
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
img = _to_image(it)
|
| 31 |
-
vp = vision.predict(img)
|
| 32 |
-
valid += 1
|
| 33 |
-
if agg is None:
|
| 34 |
-
agg = {k: v for k,v in vp.items()}
|
| 35 |
-
else:
|
| 36 |
-
for k in agg:
|
| 37 |
-
agg[k] += vp.get(k, 0.0)
|
| 38 |
-
except Exception:
|
| 39 |
-
pass
|
| 40 |
-
if agg is None:
|
| 41 |
-
agg = {k: 0.0 for k in vision.labels}
|
| 42 |
-
else:
|
| 43 |
-
for k in agg:
|
| 44 |
-
agg[k] /= max(1, valid)
|
| 45 |
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
estimate = estimate_costs(top, "configs/parts_costs.yaml", top_k=4)
|
| 51 |
-
pm = predictive_maintenance(car_year=int(year) if year else None, mileage_km=int(mileage_km) if mileage_km else None)
|
| 52 |
-
adv = advanced_suggestions(top_issues=top)
|
| 53 |
|
| 54 |
-
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"issues_ranked": fused,
|
| 59 |
-
"estimate": estimate,
|
| 60 |
-
"predictive_maintenance": pm,
|
| 61 |
-
"advanced_suggestions": adv
|
| 62 |
}
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
json_path = "exports/service_report.json"
|
| 67 |
-
ics_path = "exports/service_appointment.ics"
|
| 68 |
-
export_pdf(payload, pdf_path)
|
| 69 |
-
export_json(payload, json_path)
|
| 70 |
-
create_service_ics(ics_path, hours_from_now=48, duration_minutes=60)
|
| 71 |
-
|
| 72 |
-
def to_dl(path):
|
| 73 |
-
with open(path, "rb") as f:
|
| 74 |
-
return (os.path.basename(path), f.read())
|
| 75 |
-
|
| 76 |
-
return payload, to_dl(pdf_path), to_dl(json_path), to_dl(ics_path)
|
| 77 |
-
|
| 78 |
-
with gr.Blocks(fill_height=True) as demo:
|
| 79 |
-
gr.Markdown("## 🚗 Workshop Car Service Advisor")
|
| 80 |
with gr.Row():
|
| 81 |
with gr.Column(scale=1):
|
| 82 |
-
|
| 83 |
-
cust = gr.Textbox(label="Customer reported issue", placeholder="Describe the problem...")
|
| 84 |
-
with gr.Row():
|
| 85 |
-
make = gr.Textbox(label="Make", value="Toyota")
|
| 86 |
-
model = gr.Textbox(label="Model", value="Corolla")
|
| 87 |
-
year = gr.Number(label="Year", value=2017, precision=0)
|
| 88 |
-
with gr.Row():
|
| 89 |
-
mileage = gr.Number(label="Mileage (km)", value=60000, precision=0)
|
| 90 |
-
vin = gr.Textbox(label="VIN", placeholder="Optional")
|
| 91 |
-
with gr.Row():
|
| 92 |
-
name = gr.Textbox(label="Customer Name", value="")
|
| 93 |
-
phone = gr.Textbox(label="Phone", value="")
|
| 94 |
run = gr.Button("Analyze", variant="primary")
|
| 95 |
with gr.Column(scale=1):
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
json_file = gr.File(label="Download JSON")
|
| 99 |
-
ics_file = gr.File(label="Download .ics (appointment)")
|
| 100 |
-
run.click(analyze, inputs=[imgs, cust, make, model, year, mileage, vin, name, phone],
|
| 101 |
-
outputs=[out_json, pdf_file, json_file, ics_file])
|
| 102 |
|
| 103 |
if __name__ == "__main__":
|
| 104 |
demo.launch()
|
|
|
|
| 1 |
+
"""Gradio app: identify a Tata car's model and color from a single image."""
import os, json
from typing import Any, Dict

import gradio as gr
from PIL import Image

from tata_id.model import TataModelIdentifier
from tata_id.color import detect_color
from tata_id.autofill import load_specs, autofill_details

# Loaded once at startup: the CLIP zero-shot classifier and the spec knowledge base.
clf = TataModelIdentifier()
SPECS = load_specs()


def analyze(image: Image.Image) -> Dict[str, Any]:
    """Run model identification, color detection, and spec autofill on one image.

    Raises:
        gr.Error: when no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image of a Tata car.")

    # Top-3 model guesses; the best one drives the spec lookup.
    ranked = clf.predict_topk(image, k=3)
    best_model = ranked[0][0]

    # Dominant body color of the car.
    body_color = detect_color(image)

    # Year range / body / engines / features pulled from the YAML KB.
    spec_info = autofill_details(best_model, SPECS)

    predictions = [
        {"model": name, "probability": round(float(prob), 4)}
        for name, prob in ranked
    ]
    return {
        "predictions": predictions,
        "color": body_color,
        "autofill": spec_info,
    }


with gr.Blocks(fill_height=True, theme=gr.themes.Base()) as demo:
    gr.Markdown("## 🚘 Tata Car Identifier — Model, Color, and Specs (Single Image)")
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Upload a Tata car image")
            run = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            out = gr.JSON(label="Results")
    run.click(analyze, inputs=[img], outputs=[out])

if __name__ == "__main__":
    demo.launch()
|
data/tata_specs.yaml
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
models:
|
| 2 |
+
Tata Tiago:
|
| 3 |
+
years: "2016–present"
|
| 4 |
+
body: "Hatchback"
|
| 5 |
+
engines:
|
| 6 |
+
- "1.2L Revotron Petrol"
|
| 7 |
+
- "1.0L iCNG"
|
| 8 |
+
- "EV (Tiago.ev)"
|
| 9 |
+
features:
|
| 10 |
+
- "Dual airbags"
|
| 11 |
+
- "ABS with EBD"
|
| 12 |
+
- "Touchscreen infotainment (variants)"
|
| 13 |
+
Tata Tigor:
|
| 14 |
+
years: "2017–present"
|
| 15 |
+
body: "Compact Sedan"
|
| 16 |
+
engines:
|
| 17 |
+
- "1.2L Revotron Petrol"
|
| 18 |
+
- "1.0L iCNG"
|
| 19 |
+
features:
|
| 20 |
+
- "Projector headlamps"
|
| 21 |
+
- "Rear camera (variants)"
|
| 22 |
+
Tata Altroz:
|
| 23 |
+
years: "2020–present"
|
| 24 |
+
body: "Premium Hatchback"
|
| 25 |
+
engines:
|
| 26 |
+
- "1.2L Petrol"
|
| 27 |
+
- "1.5L Diesel"
|
| 28 |
+
- "EV (Altroz.ev - where applicable)"
|
| 29 |
+
features:
|
| 30 |
+
- "5-star Global NCAP (variants)"
|
| 31 |
+
- "iRA connected car (variants)"
|
| 32 |
+
Tata Punch:
|
| 33 |
+
years: "2021–present"
|
| 34 |
+
body: "Micro SUV"
|
| 35 |
+
engines:
|
| 36 |
+
- "1.2L Petrol"
|
| 37 |
+
- "iCNG"
|
| 38 |
+
features:
|
| 39 |
+
- "Traction modes (AMT variants)"
|
| 40 |
+
Tata Nexon:
|
| 41 |
+
years: "2017–present (facelift 2023)"
|
| 42 |
+
body: "Compact SUV"
|
| 43 |
+
engines:
|
| 44 |
+
- "1.2L Turbo Petrol"
|
| 45 |
+
- "1.5L Diesel"
|
| 46 |
+
- "EV (Nexon.ev)"
|
| 47 |
+
features:
|
| 48 |
+
- "ADAS (facelift variants)"
|
| 49 |
+
- "Digital cockpit (facelift variants)"
|
| 50 |
+
Tata Harrier:
|
| 51 |
+
years: "2019–present (facelift 2023)"
|
| 52 |
+
body: "Mid-size SUV"
|
| 53 |
+
engines:
|
| 54 |
+
- "2.0L Kryotec Diesel"
|
| 55 |
+
features:
|
| 56 |
+
- "Panoramic sunroof (variants)"
|
| 57 |
+
- "ADAS (facelift variants)"
|
| 58 |
+
Tata Safari:
|
| 59 |
+
years: "2021–present (facelift 2023)"
|
| 60 |
+
body: "3-row SUV"
|
| 61 |
+
engines:
|
| 62 |
+
- "2.0L Kryotec Diesel"
|
| 63 |
+
features:
|
| 64 |
+
- "Captain seats option"
|
| 65 |
+
- "ADAS (facelift variants)"
|
requirements.txt
CHANGED
|
@@ -1,15 +1,10 @@
|
|
| 1 |
torch>=2.1.0
|
| 2 |
torchvision>=0.16.0
|
| 3 |
-
torchaudio>=2.1.0
|
| 4 |
-
timm>=1.0.3
|
| 5 |
transformers>=4.42.0
|
| 6 |
tokenizers>=0.15.2
|
| 7 |
gradio>=4.36.1
|
| 8 |
-
pydantic>=2.7.0
|
| 9 |
pillow>=10.3.0
|
| 10 |
numpy>=1.26.4
|
| 11 |
-
pandas>=2.2.2
|
| 12 |
scikit-learn>=1.5.0
|
|
|
|
| 13 |
pyyaml>=6.0.1
|
| 14 |
-
reportlab>=4.1.0
|
| 15 |
-
ics>=0.7.2
|
|
|
|
| 1 |
torch>=2.1.0
|
| 2 |
torchvision>=0.16.0
|
|
|
|
|
|
|
| 3 |
transformers>=4.42.0
|
| 4 |
tokenizers>=0.15.2
|
| 5 |
gradio>=4.36.1
|
|
|
|
| 6 |
pillow>=10.3.0
|
| 7 |
numpy>=1.26.4
|
|
|
|
| 8 |
scikit-learn>=1.5.0
|
| 9 |
+
scikit-image>=0.23.2
|
| 10 |
pyyaml>=6.0.1
|
|
|
|
|
|
tata_id/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = '0.1.0'
|
tata_id/autofill.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict


def load_specs(path: str = "data/tata_specs.yaml") -> Dict[str, Any]:
    """Load the Tata spec knowledge base from a YAML file.

    Returns an empty dict when the document is empty — ``yaml.safe_load``
    yields ``None`` for empty files, which would otherwise crash callers
    doing ``specs.get(...)``.
    """
    # Local import: keeps this module importable (for autofill_details)
    # even when PyYAML is not installed.
    import yaml

    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def autofill_details(model_name: str, specs: Dict[str, Any]) -> Dict[str, Any]:
    """Look up year range, body style, engines, and features for *model_name*.

    Unknown models (and null YAML entries) fall back to "N/A" / empty
    lists rather than raising.
    """
    # Double guard: "models" may be missing/null, and a model key may map
    # to a null value in hand-edited YAML.
    info = ((specs.get("models", {}) or {}).get(model_name, {}) or {})
    return {
        "model": model_name,
        "years": info.get("years", "N/A"),
        "body": info.get("body", "N/A"),
        "engines": info.get("engines", []),
        "features": info.get("features", []),
    }
|
tata_id/color.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from sklearn.cluster import KMeans
|
| 5 |
+
from skimage.color import rgb2lab
|
| 6 |
+
|
| 7 |
+
# A limited palette of common automotive colors with hex values
|
| 8 |
+
NAMED_COLORS = {
|
| 9 |
+
"White": (255,255,255),
|
| 10 |
+
"Black": (0,0,0),
|
| 11 |
+
"Silver": (192,192,192),
|
| 12 |
+
"Grey": (128,128,128),
|
| 13 |
+
"Red": (200,0,0),
|
| 14 |
+
"Blue": (0,80,180),
|
| 15 |
+
"Dark Blue": (0,40,100),
|
| 16 |
+
"Green": (0,150,0),
|
| 17 |
+
"Dark Green": (0,90,0),
|
| 18 |
+
"Yellow": (240,210,0),
|
| 19 |
+
"Orange": (255,130,0),
|
| 20 |
+
"Brown": (120,70,25),
|
| 21 |
+
"Beige": (210,190,150),
|
| 22 |
+
"Teal": (0,120,120),
|
| 23 |
+
"Purple": (110,0,140),
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def _nearest_named_color(rgb: Tuple[int,int,int]) -> str:
|
| 27 |
+
r,g,b = rgb
|
| 28 |
+
best = None; best_d = 1e9
|
| 29 |
+
for name, (R,G,B) in NAMED_COLORS.items():
|
| 30 |
+
d = (r-R)**2 + (g-G)**2 + (b-B)**2
|
| 31 |
+
if d < best_d:
|
| 32 |
+
best_d = d; best = name
|
| 33 |
+
return best
|
| 34 |
+
|
| 35 |
+
def detect_color(image: Image.Image, n_clusters: int = 4) -> Dict:
|
| 36 |
+
# Downsample to speed
|
| 37 |
+
img = image.convert("RGB").resize((256,256))
|
| 38 |
+
arr = np.array(img).reshape(-1,3).astype(np.float32)
|
| 39 |
+
|
| 40 |
+
# Filter near extreme dark/light pixels (often background/sun glare)
|
| 41 |
+
mask = (arr.mean(axis=1) > 25) & (arr.mean(axis=1) < 245)
|
| 42 |
+
arr = arr[mask]
|
| 43 |
+
if len(arr) < 100:
|
| 44 |
+
arr = np.array(img).reshape(-1,3).astype(np.float32)
|
| 45 |
+
|
| 46 |
+
# KMeans on RGB
|
| 47 |
+
km = KMeans(n_clusters=n_clusters, n_init=4, random_state=42).fit(arr)
|
| 48 |
+
centers = km.cluster_centers_.astype(int)
|
| 49 |
+
labels, counts = np.unique(km.labels_, return_counts=True)
|
| 50 |
+
idx = int(labels[np.argmax(counts)])
|
| 51 |
+
dom_rgb = tuple(map(int, centers[idx]))
|
| 52 |
+
dom_name = _nearest_named_color(dom_rgb)
|
| 53 |
+
return {"name": dom_name, "rgb": dom_rgb, "hex": "#%02x%02x%02x" % dom_rgb}
|
tata_id/kb.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Catalogue of Tata models offered to the zero-shot CLIP classifier as
# candidate labels. Order is preserved into the prompt list; new entries
# should be mirrored in data/tata_specs.yaml for autofill to work.
MODEL_LIST = [
    "Tata " + suffix
    for suffix in ("Tiago", "Tigor", "Altroz", "Punch", "Nexon", "Harrier", "Safari")
]
|
tata_id/model.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Tuple

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

from .kb import MODEL_LIST

# Pretrained checkpoint used for both the vision and text towers.
_CLIP_NAME = "openai/clip-vit-base-patch32"


class TataModelIdentifier:
    """Zero-shot classifier using CLIP. Fine-tune later with training script for higher accuracy."""

    def __init__(self, candidate_models: List[str] = None, device: str = None):
        """Load CLIP and set up the candidate label set.

        candidate_models: label names to score; defaults to MODEL_LIST.
        device: "cuda"/"cpu"; auto-detected when omitted.
        """
        if candidate_models:
            self.labels = candidate_models
        else:
            self.labels = MODEL_LIST
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(_CLIP_NAME).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(_CLIP_NAME)

    @torch.no_grad()
    def predict_topk(self, image: Image.Image, k: int = 3) -> List[Tuple[str, float]]:
        """Return the k most likely model names with softmax probabilities, best first."""
        prompts = [f"A photo of a {name}" for name in self.labels]
        inputs = self.processor(
            text=prompts,
            images=image.convert("RGB"),
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        # logits_per_image has shape [1, num_text]; softmax over the labels.
        scores = self.model(**inputs).logits_per_image[0].softmax(dim=-1)
        ranked = sorted(
            zip(self.labels, scores.detach().cpu().tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:k]
|
tata_id/utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64, io

from PIL import Image


def file_to_image(obj):
    """Decode *obj* into a PIL Image.

    Accepts, in order of precedence: (name, bytes) tuples, raw bytes,
    strings (base64 data-URLs or filesystem paths), file-like objects,
    and Gradio-style ``{"data": bytes}`` dicts.

    Raises:
        ValueError: if *obj* matches none of the supported shapes.
    """
    # (name, bytes) pairs collapse to the raw-bytes case below.
    if isinstance(obj, tuple) and len(obj) == 2 and isinstance(obj[1], (bytes, bytearray)):
        obj = obj[1]
    if isinstance(obj, (bytes, bytearray)):
        return Image.open(io.BytesIO(obj))
    if isinstance(obj, str):
        if not obj.startswith("data:"):
            # Treat the string as a filesystem path.
            return Image.open(obj)
        # Strip the "data:image/...;base64," prefix and decode.
        payload = obj.split(",", 1)[-1]
        return Image.open(io.BytesIO(base64.b64decode(payload)))
    if hasattr(obj, "read"):
        return Image.open(obj)
    # Gradio may pass dicts with a "data" payload.
    if isinstance(obj, dict) and "data" in obj:
        return Image.open(io.BytesIO(obj["data"]))
    raise ValueError("Unsupported file object for image decoding")
|
training/train_classifier.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, os, pandas as pd
import torch, torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
import timm


class CarsDataset(Dataset):
    """CSV-driven image classification dataset.

    Expects a CSV with columns ``image_path`` and ``label``; relative
    image paths are resolved against *img_root*.
    """

    def __init__(self, csv_path, img_root):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        # Sorted label order so the index mapping is reproducible across runs.
        self.labels = sorted(self.df['label'].unique().tolist())
        self.label_to_idx = {l: i for i, l in enumerate(self.labels)}
        self.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        p = row['image_path']
        if not os.path.isabs(p):
            p = os.path.join(self.img_root, p)
        img = Image.open(p).convert("RGB")
        return self.transform(img), self.label_to_idx[row['label']]


def main(args):
    # BUG FIX: --data_root (a required argument) was previously ignored —
    # image paths were resolved against the annotations file's directory.
    ds = CarsDataset(args.annotations, args.data_root)
    n = len(ds)
    n_val = max(1, int(0.2 * n))
    # Seeded generator so the train/val split is reproducible across runs.
    split_gen = torch.Generator().manual_seed(42)
    tr, va = random_split(ds, [n - n_val, n_val], generator=split_gen)
    tl = DataLoader(tr, batch_size=32, shuffle=True)
    vl = DataLoader(va, batch_size=32)

    model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=len(ds.labels))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
    crit = nn.CrossEntropyLoss()

    best = 0.0
    os.makedirs(args.out_dir, exist_ok=True)

    for epoch in range(args.epochs):
        # --- train ---
        model.train()
        for xb, yb in tl:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()

        # --- validate: plain top-1 accuracy ---
        model.eval()
        corr = tot = 0
        with torch.no_grad():
            for xb, yb in vl:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb).argmax(1)
                corr += (pred == yb).sum().item()
                tot += yb.numel()
        acc = corr / tot if tot else 0
        print(f"Epoch {epoch+1}: val_acc={acc:.3f}")

        if acc > best:
            best = acc
            # Persist the label list with the weights so inference can
            # decode class indices back to model names.
            torch.save({"model": model.state_dict(), "labels": ds.labels},
                       os.path.join(args.out_dir, "best.pt"))
    print("Done. Best acc:", best)


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--annotations", required=True)
    ap.add_argument("--out_dir", default="checkpoints/vision")
    ap.add_argument("--epochs", type=int, default=10)
    args = ap.parse_args()
    main(args)
|