# alexxtech's picture
# Update app.py
# ecb7f82 verified
from pathlib import Path
import json
import re
import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from PIL import Image, ImageFile
from torchvision import models, transforms
# Allow PIL to decode partially-written/truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
PROJECT_ROOT = Path(__file__).resolve().parent
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Standard ImageNet-style evaluation preprocessing: resize, center-crop to
# 224x224, and normalize with the usual ImageNet mean/std (presumably the
# same normalization used at training time — TODO confirm against notebook).
EVAL_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Checkpoint files expected next to app.py, one per supported architecture.
CHECKPOINT_PATHS = {
    "BaselineCNN": PROJECT_ROOT / "baseline_cnn_best.pt",
    "ResNet18": PROJECT_ROOT / "resnet18_best.pt",
    "ResNet34": PROJECT_ROOT / "resnet34_best.pt",
    "ResNet50": PROJECT_ROOT / "resnet50_best.pt",
    "ResNet101": PROJECT_ROOT / "resnet101_best.pt",
    "ResNet152": PROJECT_ROOT / "resnet152_best.pt",
}
class BaselineCNN(nn.Module):
    """Small four-stage convolutional baseline classifier.

    The layer ordering inside ``features``/``classifier`` is significant:
    saved checkpoints address parameters by Sequential index (e.g.
    ``features.0.weight``), so the module layout must stay identical to
    the training script.
    """

    def __init__(self, classes: int):
        super().__init__()
        stages = []
        in_channels = 3
        # Three conv stages with 2x downsampling, then a final conv stage
        # followed by global average pooling.
        for out_channels, pooled in ((32, True), (64, True), (128, True), (256, False)):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            if pooled:
                stages.append(nn.MaxPool2d(2))
            in_channels = out_channels
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, classes),
        )

    def forward(self, x):
        # Feature extractor followed by the dropout-regularized linear head.
        return self.classifier(self.features(x))
def build_resnet(name: str, classes: int):
    """Instantiate an untrained torchvision ResNet and resize its head.

    Raises ValueError for names outside the supported ResNet family.
    """
    factories = {
        "ResNet18": models.resnet18,
        "ResNet34": models.resnet34,
        "ResNet50": models.resnet50,
        "ResNet101": models.resnet101,
        "ResNet152": models.resnet152,
    }
    if name not in factories:
        raise ValueError(f"Unsupported model name: {name}")
    net = factories[name](weights=None)
    # Replace the ImageNet head with one sized for this dataset.
    net.fc = nn.Linear(net.fc.in_features, classes)
    return net
def load_class_names():
    """Read the ordered class-name list exported by the training notebook.

    Raises FileNotFoundError when the artifact file is missing and
    ValueError when its contents are empty or degenerate (<= 1 class).
    """
    path = ARTIFACTS_DIR / "class_names.json"
    if not path.exists():
        raise FileNotFoundError(
            "artifacts/class_names.json was not found. Run the notebook first so it can export deployment artifacts."
        )
    names = json.loads(path.read_text(encoding="utf-8"))
    if not names or len(names) <= 1:
        raise ValueError("class_names.json is empty or invalid.")
    return names
def resolve_best_model_name():
    """Pick which checkpoint to serve.

    Prefers the name recorded in artifacts/best_model_name.txt when that
    checkpoint actually exists on disk; otherwise falls back to the
    deepest ResNet found, then the baseline CNN.
    """
    marker = ARTIFACTS_DIR / "best_model_name.txt"
    if marker.exists():
        recorded = marker.read_text(encoding="utf-8").strip()
        if recorded in CHECKPOINT_PATHS and CHECKPOINT_PATHS[recorded].exists():
            return recorded
    fallback_order = ("ResNet152", "ResNet101", "ResNet50", "ResNet34", "ResNet18", "BaselineCNN")
    for name in fallback_order:
        if CHECKPOINT_PATHS[name].exists():
            return name
    raise FileNotFoundError("No checkpoint files were found next to app.py.")
def load_model(best_model_name: str, num_classes: int):
    """Build the architecture named by ``best_model_name``, restore its
    checkpoint weights, move it to DEVICE, and switch to eval mode."""
    if best_model_name == "BaselineCNN":
        net = BaselineCNN(num_classes)
    else:
        net = build_resnet(best_model_name, num_classes)
    weights = torch.load(CHECKPOINT_PATHS[best_model_name], map_location=DEVICE)
    net.load_state_dict(weights)
    # .to() and .eval() both return the module itself.
    return net.to(DEVICE).eval()
# Module-level bootstrap: resolve artifacts once at import time so the
# Gradio handlers below reuse a single loaded model and class list.
class_names = load_class_names()
best_model_name = resolve_best_model_name()
model = load_model(best_model_name, len(class_names))
def predict_pil_image(image, top_k=5):
    """Classify a PIL image and return the top-k classes as a DataFrame.

    Returns an empty two-column frame (Class, Probability) when no image
    is supplied; top_k is clamped to the number of known classes.
    """
    if image is None:
        return pd.DataFrame(columns=["Class", "Probability"])
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    batch = EVAL_TRANSFORM(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1).squeeze(0)
    k = min(top_k, len(class_names))
    top_probs, top_indices = torch.topk(probabilities, k=k)
    return pd.DataFrame({
        "Class": [class_names[i] for i in top_indices.tolist()],
        "Probability": [float(p) for p in top_probs.tolist()],
    })
# Lazily-populated module-level cache for the marketplace listings table.
LISTING_TABLE = None
def normalize_text(text):
    """Lowercase, strip, and collapse non-alphanumeric runs to single spaces."""
    raw = "" if text is None else str(text).lower().strip()
    alnum_only = "".join(ch if ch.isalnum() else " " for ch in raw)
    return " ".join(alnum_only.split())
def load_listing_dataset():
    """Fetch and cache the CarGurus listings table as a pandas DataFrame.

    Downloads once per process; on any failure an empty frame is cached so
    the app keeps working without marketplace data (best-effort feature).
    """
    global LISTING_TABLE
    if LISTING_TABLE is not None:
        return LISTING_TABLE
    try:
        raw = load_dataset("rebrowser/carguruscom-dataset", "car-listings", split="train").to_pandas()
    except Exception:
        # Listings are optional: cache an empty table rather than crashing.
        LISTING_TABLE = pd.DataFrame()
        return LISTING_TABLE
    wanted = [
        "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
        "transmission", "drivetrain", "fuelType", "dealRatingKey",
        "sellerCity", "sellerState", "listingUrl", "description"
    ]
    present = [name for name in wanted if name in raw.columns]
    table = raw[present].copy()
    for name in ["year", "price", "mileage"]:
        if name in table.columns:
            table[name] = pd.to_numeric(table[name], errors="coerce")
    text_columns = [
        "make", "model", "trim", "bodyStyle", "transmission",
        "drivetrain", "fuelType", "dealRatingKey",
        "sellerCity", "sellerState", "listingUrl", "description"
    ]
    for name in text_columns:
        if name in table.columns:
            table[name] = table[name].fillna("").astype(str)
    # Pre-normalized columns used when matching predicted labels to listings.
    table["make_norm"] = table.get("make", pd.Series(index=table.index, dtype=str)).apply(normalize_text)
    table["model_norm"] = table.get("model", pd.Series(index=table.index, dtype=str)).apply(normalize_text)
    LISTING_TABLE = table
    return LISTING_TABLE
def parse_predicted_car_label(class_name: str):
    """Extract year/make/model/body-style hints from a predicted class label.

    Make and model are matched against values observed in the listings
    dataset, longest names first so that e.g. a two-word make cannot be
    shadowed by a shorter one. Returns a dict with None for missing parts.
    """
    table = load_listing_dataset()
    year_hit = re.search(r"(19\d{2}|20\d{2})", class_name)
    year = int(year_hit.group(1)) if year_hit else None
    if table.empty:
        return {"year": year, "make": None, "model": None, "body_style": None}
    label = normalize_text(class_name)

    def _longest_first(values):
        # Longer names first so substrings don't win over full names.
        non_blank = [v for v in values if str(v).strip()]
        return sorted(non_blank, key=lambda v: len(str(v)), reverse=True)

    make = None
    for candidate in _longest_first(table["make"].dropna().unique()):
        if normalize_text(candidate) in label:
            make = candidate
            break

    model_name = None
    if make is not None:
        make_models = table[table["make"] == make]["model"].dropna().unique().tolist()
        for candidate in _longest_first(make_models):
            if normalize_text(candidate) in label:
                model_name = candidate
                break

    body = None
    # Order matters: "minivan" must be tested before "van".
    for style in ["sedan", "coupe", "convertible", "suv", "wagon", "hatchback", "minivan", "van", "pickup", "truck"]:
        if style in label:
            body = style
            break
    return {"year": year, "make": make, "model": model_name, "body_style": body}
def find_matching_listings(make=None, model=None, year=None, body_style=None, max_results=12):
    """Return up to ``max_results`` listings matching the parsed prediction.

    Filters by make, model (exact normalized match first, substring as a
    fallback), year within +/- 1, and body style. If the combined filter
    matches nothing, relaxes to make-only. Results are ranked by year
    proximity, deal rating, price, and mileage (where those columns exist).
    """
    frame = load_listing_dataset()
    if frame.empty:
        return pd.DataFrame()
    filtered = frame.copy()
    if make:
        filtered = filtered[filtered["make_norm"] == normalize_text(make)]
    if model:
        model_norm = normalize_text(model)
        exact_match = filtered[filtered["model_norm"] == model_norm]
        if len(exact_match) > 0:
            filtered = exact_match
        else:
            # Fall back to substring matching (e.g. "m3" inside "m3 coupe").
            filtered = filtered[filtered["model_norm"].str.contains(model_norm, na=False)]
    if year is not None and "year" in filtered.columns:
        filtered = filtered[filtered["year"].between(year - 1, year + 1, inclusive="both")]
    if body_style and "bodyStyle" in filtered.columns:
        filtered = filtered[filtered["bodyStyle"].str.contains(body_style, case=False, na=False)]
    if len(filtered) == 0 and make:
        # Nothing survived the combined filter: relax to make-only.
        filtered = frame[frame["make_norm"] == normalize_text(make)].copy()
    if len(filtered) == 0:
        return filtered
    filtered = filtered.copy()
    # Bug fix: guard the "year" column access — the dataset snapshot may lack
    # it (column selection upstream only keeps columns that exist), and the
    # original code would raise KeyError here.
    if year is not None and "year" in filtered.columns:
        filtered["year_distance"] = (filtered["year"] - year).abs()
    else:
        filtered["year_distance"] = 0
    filtered["deal_rank"] = filtered.get("dealRatingKey", pd.Series("NA", index=filtered.index)).map({
        "GREAT_PRICE": 0,
        "GOOD_PRICE": 1,
        "FAIR_PRICE": 2,
        "POOR_PRICE": 3,
        "OVERPRICED": 4,
        "OUTLIER": 5,
        "NA": 6,
    }).fillna(6)
    # Bug fix: only sort on columns that exist — "price"/"mileage" can be
    # absent, which previously made sort_values raise KeyError.
    sort_keys = [key for key in ["year_distance", "deal_rank", "price", "mileage"] if key in filtered.columns]
    filtered = filtered.sort_values(sort_keys, ascending=[True] * len(sort_keys))
    show_columns = [
        column for column in [
            "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
            "transmission", "drivetrain", "fuelType", "dealRatingKey",
            "sellerCity", "sellerState", "listingUrl"
        ] if column in filtered.columns
    ]
    return filtered[show_columns].head(max_results).reset_index(drop=True)
def build_listing_summary(frame, parsed_car):
    """Render a short multi-line text summary of the matched listings."""
    if frame is None or len(frame) == 0:
        return "No matching marketplace listings were found."
    parts = [f"Matched listings: {len(frame)}"]
    if parsed_car.get("make"):
        parts.append(f"Make: {parsed_car['make']}")
    if parsed_car.get("model"):
        parts.append(f"Model: {parsed_car['model']}")
    if parsed_car.get("year"):
        parts.append(f"Target year: {parsed_car['year']}")
    has_prices = "price" in frame.columns and frame["price"].notna().any()
    if has_prices:
        parts.append(f"Price range: ${int(frame['price'].min()):,} - ${int(frame['price'].max()):,}")
    has_mileage = "mileage" in frame.columns and frame["mileage"].notna().any()
    if has_mileage:
        parts.append(f"Mileage range: {int(frame['mileage'].min()):,} - {int(frame['mileage'].max()):,} miles")
    return "\n".join(parts)
def format_listing_table(frame):
    """Rename and format listing columns for display.

    An empty/None input yields an empty frame that still carries the full
    set of display headers so the UI table renders its column names.
    """
    display_order = [
        "Year", "Make", "Model", "Trim", "Body Style", "Price", "Mileage",
        "Transmission", "Drivetrain", "Fuel Type", "Deal Rating",
        "City", "State", "Listing URL",
    ]
    if frame is None or len(frame) == 0:
        return pd.DataFrame(columns=display_order)
    renamed = frame.copy().rename(columns={
        "year": "Year",
        "make": "Make",
        "model": "Model",
        "trim": "Trim",
        "bodyStyle": "Body Style",
        "price": "Price",
        "mileage": "Mileage",
        "transmission": "Transmission",
        "drivetrain": "Drivetrain",
        "fuelType": "Fuel Type",
        "dealRatingKey": "Deal Rating",
        "sellerCity": "City",
        "sellerState": "State",
        "listingUrl": "Listing URL",
    })
    # Humanize numeric columns; em dash for missing values.
    if "Price" in renamed.columns:
        renamed["Price"] = renamed["Price"].apply(
            lambda value: "—" if pd.isna(value) else f"${int(value):,}"
        )
    if "Mileage" in renamed.columns:
        renamed["Mileage"] = renamed["Mileage"].apply(
            lambda value: "—" if pd.isna(value) else f"{int(value):,} mi"
        )
    visible = [name for name in display_order if name in renamed.columns]
    return renamed[visible].reset_index(drop=True)
def run_demo(image):
    """Gradio handler: classify the uploaded image and look up listings.

    Returns (prediction summary text, predictions DataFrame, marketplace
    summary text, formatted listings DataFrame) — placeholder values when
    no image was provided.
    """
    if image is None:
        return (
            "Please upload a car image.",
            pd.DataFrame(),
            "Marketplace summary will appear here.",
            format_listing_table(pd.DataFrame()),
        )
    predictions = predict_pil_image(image)
    top_class = predictions.iloc[0]["Class"]
    top_prob = predictions.iloc[0]["Probability"]
    parsed_car = parse_predicted_car_label(top_class)
    listings = find_matching_listings(
        make=parsed_car.get("make"),
        model=parsed_car.get("model"),
        year=parsed_car.get("year"),
        body_style=parsed_car.get("body_style"),
        max_results=12,
    )
    summary = (
        f"Best model: {best_model_name}\n"
        f"Top prediction: {top_class}\n"
        f"Confidence: {top_prob:.4f}"
    )
    listing_summary = build_listing_summary(listings, parsed_car)
    return summary, predictions, listing_summary, format_listing_table(listings)
simple_css = """
:root {
--sc-accent: #C0504D;
--sc-accent-dark: #A2413F;
--sc-accent-soft: #F6E7E6;
--sc-white: #FFFFFF;
--sc-bg: #F4F5F7;
--sc-surface: #FAFAFB;
--sc-graphite: #2C2C31;
--sc-graphite-2: #3A3A40;
--sc-border: #D9DDE2;
--sc-muted: #6D7278;
--sc-text: #202327;
}
.gradio-container {
background:
linear-gradient(180deg, var(--sc-graphite) 0 118px, var(--sc-bg) 118px 100%);
font-family: Arial, Helvetica, sans-serif !important;
color: var(--sc-text);
}
#page {
max-width: 1220px;
margin: 0 auto;
padding: 24px 18px 40px 18px;
}
#topbar {
display: flex;
align-items: center;
justify-content: space-between;
gap: 24px;
color: white;
margin-bottom: 20px;
}
#brand {
display: flex;
flex-direction: column;
gap: 4px;
}
#brand h1 {
margin: 0;
font-size: 24px;
line-height: 1;
letter-spacing: 0.01em;
font-style: italic;
text-transform: uppercase;
font-weight: 900;
color: white;
}
#brand h1 span {
color: var(--sc-accent);
}
#brand p {
margin: 0;
font-size: 12px;
letter-spacing: 0.04em;
color: #C9CDD2;
text-transform: uppercase;
}
#hero {
background:
linear-gradient(135deg, rgba(255,255,255,0.06), rgba(255,255,255,0.02)),
linear-gradient(180deg, var(--sc-graphite-2), var(--sc-graphite));
color: white;
padding: 22px 26px;
margin-bottom: 18px;
border-left: 4px solid var(--sc-accent);
box-shadow: 0 12px 30px rgba(20, 20, 24, 0.15);
}
#hero p {
margin: 0;
max-width: 860px;
font-size: 15px;
color: #D5D8DC;
line-height: 1.55;
}
.panel {
background: var(--sc-white);
border: 1px solid var(--sc-border);
box-shadow: 0 10px 24px rgba(32, 35, 39, 0.06);
padding: 14px;
}
.section-label {
margin: 0 0 8px 0;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--sc-muted);
font-weight: 700;
}
.section-label-tight {
margin: 0 0 4px 0;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--sc-muted);
font-weight: 700;
}
.highlight-card {
background: linear-gradient(180deg, #34363B, #25272B);
border: 1px solid #43464D;
color: white;
padding: 12px;
}
.highlight-card textarea,
.highlight-card input {
background: transparent !important;
color: white !important;
}
button.primary {
background: var(--sc-accent) !important;
color: white !important;
border: 1px solid var(--sc-accent-dark) !important;
border-radius: 0 !important;
font-weight: 700 !important;
text-transform: uppercase;
letter-spacing: 0.04em;
box-shadow: none !important;
}
button.primary:hover {
background: #AA4542 !important;
}
.gradio-container button.secondary,
.gradio-container .block,
.gradio-container .gr-box,
.gradio-image,
.gradio-dataframe,
textarea,
input {
border-radius: 0 !important;
}
.gradio-image {
border: 1px solid var(--sc-border) !important;
background: white !important;
}
.gradio-dataframe table thead tr th {
background: var(--sc-graphite) !important;
color: white !important;
border-color: var(--sc-graphite) !important;
font-weight: 700 !important;
}
.gradio-dataframe table tbody tr:nth-child(even) td {
background: var(--sc-surface) !important;
}
.gradio-dataframe table tbody tr td {
border-color: var(--sc-border) !important;
}
.gradio-container .wrap.svelte-1ipelgc,
.gradio-container .contain {
background: transparent !important;
}
#results-grid {
gap: 18px;
align-items: start;
}
.compact-box {
margin-top: 0 !important;
padding-top: 0 !important;
}
@media (max-width: 900px) {
#topbar {
flex-direction: column;
align-items: flex-start;
}
}
"""
# Assemble the Gradio UI: an upload panel (left) plus prediction and
# marketplace panels (right), all wired to run_demo via the Predict button.
with gr.Blocks(css=simple_css) as demo:
    with gr.Column(elem_id="page"):
        # Static branded header and hero blurb.
        gr.HTML(
            """
            <div id="topbar">
              <div id="brand">
                <h1>Stanford<span>Cars</span></h1>
                <p>Image Classification Capstone Project</p>
              </div>
            </div>
            <div id="hero">
              <p>
                Upload an image from a computer and get a prediction view with model confidence and matched marketplace listings.
              </p>
            </div>
            """
        )
        with gr.Row(elem_id="results-grid"):
            with gr.Column(scale=5):
                # Left column: image upload + trigger button.
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Image Upload</div>')
                    image_input = gr.Image(
                        type="pil",
                        sources=["upload"],
                        label="Upload a car image from your computer",
                        height=360,
                    )
                    predict_button = gr.Button("Predict", variant="primary")
            with gr.Column(scale=7):
                # Right column: prediction summary, marketplace summary,
                # top-k predictions table, and matched listings table.
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Prediction Summary</div>')
                    summary_output = gr.Textbox(label="", show_label=False, lines=3)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Marketplace Summary</div>')
                    listing_summary_output = gr.Textbox(label="", show_label=False, lines=5)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label-tight" style="color:#CFD3D8;">Top Predictions</div>')
                    predictions_output = gr.Dataframe(label="", show_label=False, interactive=False)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Matched Marketplace Listings</div>')
                    listings_output = gr.Dataframe(
                        label="",
                        show_label=False,
                        interactive=False,
                        wrap=True,
                    )
    # Wire the button: one handler produces all four output widgets.
    predict_button.click(
        fn=run_demo,
        inputs=image_input,
        outputs=[
            summary_output,
            predictions_output,
            listing_summary_output,
            listings_output,
        ],
    )
if __name__ == "__main__":
    demo.launch()