"""Gradio demo: Stanford Cars image classification with marketplace matching.

Loads the best available checkpoint exported by the training notebook,
classifies an uploaded car photo, parses the predicted class label into
make / model / year, matches it against a CarGurus listings dataset, and
renders everything in a themed Gradio Blocks UI.

Expected sibling files (produced by the notebook):
  - artifacts/class_names.json      ordered list of class labels
  - artifacts/best_model_name.txt   (optional) name of the best checkpoint
  - <model>_best.pt                 one or more model checkpoints
"""

import json
import re
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from PIL import Image, ImageFile
from torchvision import models, transforms

# Some user-supplied / scraped photos are truncated; let PIL load them
# instead of raising an exception mid-inference.
ImageFile.LOAD_TRUNCATED_IMAGES = True

PROJECT_ROOT = Path(__file__).resolve().parent
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Standard ImageNet evaluation preprocessing (must match training).
EVAL_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Checkpoints are expected next to app.py; missing files are tolerated —
# resolve_best_model_name() falls back to whatever exists.
CHECKPOINT_PATHS = {
    "BaselineCNN": PROJECT_ROOT / "baseline_cnn_best.pt",
    "ResNet18": PROJECT_ROOT / "resnet18_best.pt",
    "ResNet34": PROJECT_ROOT / "resnet34_best.pt",
    "ResNet50": PROJECT_ROOT / "resnet50_best.pt",
    "ResNet101": PROJECT_ROOT / "resnet101_best.pt",
    "ResNet152": PROJECT_ROOT / "resnet152_best.pt",
}


class BaselineCNN(nn.Module):
    """Small 4-stage conv net used as the from-scratch baseline model."""

    def __init__(self, classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Global average pool so the classifier is input-size agnostic.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


def build_resnet(name: str, classes: int):
    """Instantiate an untrained torchvision ResNet with a ``classes``-way head.

    Raises:
        ValueError: if ``name`` is not one of the supported ResNet variants.
    """
    builders = {
        "ResNet18": models.resnet18,
        "ResNet34": models.resnet34,
        "ResNet50": models.resnet50,
        "ResNet101": models.resnet101,
        "ResNet152": models.resnet152,
    }
    if name not in builders:
        raise ValueError(f"Unsupported model name: {name}")
    # weights=None: architecture only — weights come from the checkpoint.
    model = builders[name](weights=None)
    model.fc = nn.Linear(model.fc.in_features, classes)
    return model


def load_class_names():
    """Load the ordered class-label list exported by the notebook.

    Raises:
        FileNotFoundError: if the artifact is missing.
        ValueError: if the file contains fewer than two classes.
    """
    class_names_path = ARTIFACTS_DIR / "class_names.json"
    if not class_names_path.exists():
        raise FileNotFoundError(
            "artifacts/class_names.json was not found. Run the notebook first so it can export deployment artifacts."
        )
    with open(class_names_path, "r", encoding="utf-8") as file:
        class_names = json.load(file)
    if not class_names or len(class_names) <= 1:
        raise ValueError("class_names.json is empty or invalid.")
    return class_names


def resolve_best_model_name():
    """Pick which checkpoint to serve.

    Prefers the name recorded in artifacts/best_model_name.txt (when that
    checkpoint exists), otherwise falls back to the deepest available model.

    Raises:
        FileNotFoundError: if no checkpoint file is present at all.
    """
    best_model_path = ARTIFACTS_DIR / "best_model_name.txt"
    if best_model_path.exists():
        name = best_model_path.read_text(encoding="utf-8").strip()
        if name in CHECKPOINT_PATHS and CHECKPOINT_PATHS[name].exists():
            return name
    # Fallback priority: deepest ResNet first, baseline last.
    for candidate in ["ResNet152", "ResNet101", "ResNet50", "ResNet34", "ResNet18", "BaselineCNN"]:
        if CHECKPOINT_PATHS[candidate].exists():
            return candidate
    raise FileNotFoundError("No checkpoint files were found next to app.py.")


def load_model(best_model_name: str, num_classes: int):
    """Build the chosen architecture, load its checkpoint, and set eval mode."""
    if best_model_name == "BaselineCNN":
        model = BaselineCNN(num_classes)
    else:
        model = build_resnet(best_model_name, num_classes)
    state_dict = torch.load(CHECKPOINT_PATHS[best_model_name], map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()
    return model


# Module-level initialization: fail fast at startup if artifacts are missing.
class_names = load_class_names()
best_model_name = resolve_best_model_name()
model = load_model(best_model_name, len(class_names))


def predict_pil_image(image, top_k=5):
    """Classify a PIL image and return the top-k classes as a DataFrame.

    Returns an empty two-column frame when ``image`` is None so the UI can
    render a blank table instead of erroring.
    """
    if image is None:
        return pd.DataFrame(columns=["Class", "Probability"])
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = EVAL_TRANSFORM(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(image_tensor)
    probabilities = torch.softmax(logits, dim=1).squeeze(0)
    top_k = min(top_k, len(class_names))
    top_probs, top_indices = torch.topk(probabilities, k=top_k)
    rows = [
        {"Class": class_names[idx], "Probability": float(prob)}
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    ]
    return pd.DataFrame(rows)


# Lazily-loaded listings table (None until first use; empty frame on failure).
LISTING_TABLE = None


def normalize_text(text):
    """Lowercase, strip non-alphanumerics, and collapse whitespace for matching."""
    text = "" if text is None else str(text).lower().strip()
    text = "".join(ch if ch.isalnum() else " " for ch in text)
    return " ".join(text.split())


def load_listing_dataset():
    """Download and cache the CarGurus listings table.

    Best effort: any download/parsing failure caches an empty DataFrame so the
    classifier demo still works without marketplace data.
    """
    global LISTING_TABLE
    if LISTING_TABLE is not None:
        return LISTING_TABLE
    try:
        frame = load_dataset("rebrowser/carguruscom-dataset", "car-listings", split="train").to_pandas()
    except Exception:
        # Deliberate best-effort swallow: the marketplace panel is optional.
        LISTING_TABLE = pd.DataFrame()
        return LISTING_TABLE
    keep_columns = [
        "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
        "transmission", "drivetrain", "fuelType", "dealRatingKey",
        "sellerCity", "sellerState", "listingUrl", "description",
    ]
    # Only keep columns that actually exist in this dataset revision.
    keep_columns = [column for column in keep_columns if column in frame.columns]
    frame = frame[keep_columns].copy()
    for column in ["year", "price", "mileage"]:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
    for column in [
        "make", "model", "trim", "bodyStyle", "transmission", "drivetrain",
        "fuelType", "dealRatingKey", "sellerCity", "sellerState",
        "listingUrl", "description",
    ]:
        if column in frame.columns:
            frame[column] = frame[column].fillna("").astype(str)
    # Precompute normalized make/model for fast equality matching.
    frame["make_norm"] = frame.get("make", pd.Series(index=frame.index, dtype=str)).apply(normalize_text)
    frame["model_norm"] = frame.get("model", pd.Series(index=frame.index, dtype=str)).apply(normalize_text)
    LISTING_TABLE = frame
    return LISTING_TABLE


def parse_predicted_car_label(class_name: str):
    """Extract year / make / model / body style from a Stanford Cars label.

    Matching is done against the listing dataset's own make/model vocabulary,
    longest names first so e.g. "Land Rover" beats "Rover".
    """
    frame = load_listing_dataset()
    year_match = re.search(r"(19\d{2}|20\d{2})", class_name)
    year = int(year_match.group(1)) if year_match else None
    if frame.empty:
        return {"year": year, "make": None, "model": None, "body_style": None}
    label = normalize_text(class_name)
    matched_make = None
    matched_model = None
    body_style = None
    makes = sorted(
        [value for value in frame["make"].dropna().unique() if str(value).strip()],
        key=lambda value: len(str(value)),
        reverse=True,
    )
    for value in makes:
        if normalize_text(value) in label:
            matched_make = value
            break
    if matched_make is not None:
        models_for_make = frame[frame["make"] == matched_make]["model"].dropna().unique().tolist()
        models_for_make = sorted(
            [value for value in models_for_make if str(value).strip()],
            key=lambda value: len(str(value)),
            reverse=True,
        )
        for value in models_for_make:
            if normalize_text(value) in label:
                matched_model = value
                break
    for value in ["sedan", "coupe", "convertible", "suv", "wagon", "hatchback", "minivan", "van", "pickup", "truck"]:
        if value in label:
            body_style = value
            break
    return {"year": year, "make": matched_make, "model": matched_model, "body_style": body_style}


def find_matching_listings(make=None, model=None, year=None, body_style=None, max_results=12):
    """Filter and rank marketplace listings for the parsed prediction.

    Filters by make (exact), model (exact, falling back to substring), year
    (+/- 1), and body style; relaxes back to make-only when nothing matches.
    Results are ranked by year proximity, deal rating, price, then mileage.
    """
    frame = load_listing_dataset()
    if frame.empty:
        return pd.DataFrame()
    filtered = frame.copy()
    if make:
        filtered = filtered[filtered["make_norm"] == normalize_text(make)]
    if model:
        model_norm = normalize_text(model)
        exact_match = filtered[filtered["model_norm"] == model_norm]
        if len(exact_match) > 0:
            filtered = exact_match
        else:
            filtered = filtered[filtered["model_norm"].str.contains(model_norm, na=False)]
    if year is not None and "year" in filtered.columns:
        filtered = filtered[filtered["year"].between(year - 1, year + 1, inclusive="both")]
    if body_style and "bodyStyle" in filtered.columns:
        filtered = filtered[filtered["bodyStyle"].str.contains(body_style, case=False, na=False)]
    # Relax to make-only when the combined filters matched nothing.
    if len(filtered) == 0 and make:
        filtered = frame[frame["make_norm"] == normalize_text(make)].copy()
    if len(filtered) == 0:
        return filtered
    filtered = filtered.copy()
    # Guard: "year" may be absent from this dataset revision.
    if year is None or "year" not in filtered.columns:
        filtered["year_distance"] = 0
    else:
        filtered["year_distance"] = (filtered["year"] - year).abs()
    filtered["deal_rank"] = filtered.get("dealRatingKey", pd.Series("NA", index=filtered.index)).map({
        "GREAT_PRICE": 0,
        "GOOD_PRICE": 1,
        "FAIR_PRICE": 2,
        "POOR_PRICE": 3,
        "OVERPRICED": 4,
        "OUTLIER": 5,
        "NA": 6,
    }).fillna(6)
    # Guard: only sort by columns that exist ("price"/"mileage" are optional),
    # otherwise sort_values raises KeyError.
    sort_keys = [
        column for column in ["year_distance", "deal_rank", "price", "mileage"]
        if column in filtered.columns
    ]
    filtered = filtered.sort_values(sort_keys, ascending=[True] * len(sort_keys))
    show_columns = [
        column for column in [
            "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
            "transmission", "drivetrain", "fuelType", "dealRatingKey",
            "sellerCity", "sellerState", "listingUrl",
        ]
        if column in filtered.columns
    ]
    return filtered[show_columns].head(max_results).reset_index(drop=True)


def build_listing_summary(frame, parsed_car):
    """Build a short plain-text summary of the matched listings."""
    if frame is None or len(frame) == 0:
        return "No matching marketplace listings were found."
    lines = [f"Matched listings: {len(frame)}"]
    if parsed_car.get("make"):
        lines.append(f"Make: {parsed_car['make']}")
    if parsed_car.get("model"):
        lines.append(f"Model: {parsed_car['model']}")
    if parsed_car.get("year"):
        lines.append(f"Target year: {parsed_car['year']}")
    if "price" in frame.columns and frame["price"].notna().any():
        lines.append(f"Price range: ${int(frame['price'].min()):,} - ${int(frame['price'].max()):,}")
    if "mileage" in frame.columns and frame["mileage"].notna().any():
        lines.append(f"Mileage range: {int(frame['mileage'].min()):,} - {int(frame['mileage'].max()):,} miles")
    return "\n".join(lines)


def format_listing_table(frame):
    """Rename/format listing columns for display (pretty prices and mileage)."""
    if frame is None or len(frame) == 0:
        return pd.DataFrame(
            columns=[
                "Year", "Make", "Model", "Trim", "Body Style", "Price",
                "Mileage", "Transmission", "Drivetrain", "Fuel Type",
                "Deal Rating", "City", "State", "Listing URL",
            ]
        )
    frame = frame.copy().rename(
        columns={
            "year": "Year",
            "make": "Make",
            "model": "Model",
            "trim": "Trim",
            "bodyStyle": "Body Style",
            "price": "Price",
            "mileage": "Mileage",
            "transmission": "Transmission",
            "drivetrain": "Drivetrain",
            "fuelType": "Fuel Type",
            "dealRatingKey": "Deal Rating",
            "sellerCity": "City",
            "sellerState": "State",
            "listingUrl": "Listing URL",
        }
    )
    if "Price" in frame.columns:
        frame["Price"] = frame["Price"].apply(
            lambda value: "—" if pd.isna(value) else f"${int(value):,}"
        )
    if "Mileage" in frame.columns:
        frame["Mileage"] = frame["Mileage"].apply(
            lambda value: "—" if pd.isna(value) else f"{int(value):,} mi"
        )
    order = [
        "Year", "Make", "Model", "Trim", "Body Style", "Price", "Mileage",
        "Transmission", "Drivetrain", "Fuel Type", "Deal Rating",
        "City", "State", "Listing URL",
    ]
    order = [column for column in order if column in frame.columns]
    return frame[order].reset_index(drop=True)


def run_demo(image):
    """Gradio click handler: classify, match listings, and format all outputs."""
    if image is None:
        return (
            "Please upload a car image.",
            pd.DataFrame(),
            "Marketplace summary will appear here.",
            format_listing_table(pd.DataFrame()),
        )
    predictions = predict_pil_image(image)
    parsed_car = parse_predicted_car_label(predictions.iloc[0]["Class"])
    listings = find_matching_listings(
        make=parsed_car.get("make"),
        model=parsed_car.get("model"),
        year=parsed_car.get("year"),
        body_style=parsed_car.get("body_style"),
        max_results=12,
    )
    summary = (
        f"Best model: {best_model_name}\n"
        f"Top prediction: {predictions.iloc[0]['Class']}\n"
        f"Confidence: {predictions.iloc[0]['Probability']:.4f}"
    )
    listing_summary = build_listing_summary(listings, parsed_car)
    return summary, predictions, listing_summary, format_listing_table(listings)


# Custom theme: dark graphite top bar, accent red, square corners.
simple_css = """
:root {
  --sc-accent: #C0504D; --sc-accent-dark: #A2413F; --sc-accent-soft: #F6E7E6;
  --sc-white: #FFFFFF; --sc-bg: #F4F5F7; --sc-surface: #FAFAFB;
  --sc-graphite: #2C2C31; --sc-graphite-2: #3A3A40; --sc-border: #D9DDE2;
  --sc-muted: #6D7278; --sc-text: #202327;
}
.gradio-container {
  background: linear-gradient(180deg, var(--sc-graphite) 0 118px, var(--sc-bg) 118px 100%);
  font-family: Arial, Helvetica, sans-serif !important;
  color: var(--sc-text);
}
#page { max-width: 1220px; margin: 0 auto; padding: 24px 18px 40px 18px; }
#topbar {
  display: flex; align-items: center; justify-content: space-between;
  gap: 24px; color: white; margin-bottom: 20px;
}
#brand { display: flex; flex-direction: column; gap: 4px; }
#brand h1 {
  margin: 0; font-size: 24px; line-height: 1; letter-spacing: 0.01em;
  font-style: italic; text-transform: uppercase; font-weight: 900; color: white;
}
#brand h1 span { color: var(--sc-accent); }
#brand p {
  margin: 0; font-size: 12px; letter-spacing: 0.04em;
  color: #C9CDD2; text-transform: uppercase;
}
#hero {
  background: linear-gradient(135deg, rgba(255,255,255,0.06), rgba(255,255,255,0.02)),
    linear-gradient(180deg, var(--sc-graphite-2), var(--sc-graphite));
  color: white; padding: 22px 26px; margin-bottom: 18px;
  border-left: 4px solid var(--sc-accent);
  box-shadow: 0 12px 30px rgba(20, 20, 24, 0.15);
}
#hero p { margin: 0; max-width: 860px; font-size: 15px; color: #D5D8DC; line-height: 1.55; }
.panel {
  background: var(--sc-white); border: 1px solid var(--sc-border);
  box-shadow: 0 10px 24px rgba(32, 35, 39, 0.06); padding: 14px;
}
.section-label {
  margin: 0 0 8px 0; font-size: 12px; text-transform: uppercase;
  letter-spacing: 0.08em; color: var(--sc-muted); font-weight: 700;
}
.section-label-tight {
  margin: 0 0 4px 0; font-size: 12px; text-transform: uppercase;
  letter-spacing: 0.08em; color: var(--sc-muted); font-weight: 700;
}
.highlight-card {
  background: linear-gradient(180deg, #34363B, #25272B);
  border: 1px solid #43464D; color: white; padding: 12px;
}
.highlight-card textarea, .highlight-card input {
  background: transparent !important; color: white !important;
}
button.primary {
  background: var(--sc-accent) !important; color: white !important;
  border: 1px solid var(--sc-accent-dark) !important; border-radius: 0 !important;
  font-weight: 700 !important; text-transform: uppercase;
  letter-spacing: 0.04em; box-shadow: none !important;
}
button.primary:hover { background: #AA4542 !important; }
.gradio-container button.secondary, .gradio-container .block,
.gradio-container .gr-box, .gradio-image, .gradio-dataframe,
textarea, input { border-radius: 0 !important; }
.gradio-image { border: 1px solid var(--sc-border) !important; background: white !important; }
.gradio-dataframe table thead tr th {
  background: var(--sc-graphite) !important; color: white !important;
  border-color: var(--sc-graphite) !important; font-weight: 700 !important;
}
.gradio-dataframe table tbody tr:nth-child(even) td { background: var(--sc-surface) !important; }
.gradio-dataframe table tbody tr td { border-color: var(--sc-border) !important; }
.gradio-container .wrap.svelte-1ipelgc, .gradio-container .contain { background: transparent !important; }
#results-grid { gap: 18px; align-items: start; }
.compact-box { margin-top: 0 !important; padding-top: 0 !important; }
@media (max-width: 900px) {
  #topbar { flex-direction: column; align-items: flex-start; }
}
"""

with gr.Blocks(css=simple_css) as demo:
    with gr.Column(elem_id="page"):
        # NOTE(review): the HTML markup below was reconstructed from a garbled
        # source (tags were stripped by extraction); the structure matches the
        # CSS selectors above (#topbar, #brand h1 span, #hero p) and the
        # visible text — verify against the original file.
        gr.HTML(
            """
            <div id="topbar">
              <div id="brand">
                <h1>Stanford<span>Cars</span></h1>
                <p>Image Classification Capstone Project</p>
              </div>
            </div>
            <div id="hero">
              <p>Upload an image from a computer and get a prediction view with model confidence and matched marketplace listings.</p>
            </div>
            """
        )
        with gr.Row(elem_id="results-grid"):
            with gr.Column(scale=5):
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<p class="section-label">Image Upload</p>')
                    image_input = gr.Image(
                        type="pil",
                        sources=["upload"],
                        label="Upload a car image from your computer",
                        height=360,
                    )
                    predict_button = gr.Button("Predict", variant="primary")
            with gr.Column(scale=7):
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<p class="section-label">Prediction Summary</p>')
                    summary_output = gr.Textbox(label="", show_label=False, lines=3)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<p class="section-label">Marketplace Summary</p>')
                    listing_summary_output = gr.Textbox(label="", show_label=False, lines=5)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<p class="section-label">Top Predictions</p>')
                    predictions_output = gr.Dataframe(label="", show_label=False, interactive=False)
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<p class="section-label">Matched Marketplace Listings</p>')
                    listings_output = gr.Dataframe(
                        label="",
                        show_label=False,
                        interactive=False,
                        wrap=True,
                    )
        predict_button.click(
            fn=run_demo,
            inputs=image_input,
            outputs=[
                summary_output,
                predictions_output,
                listing_summary_output,
                listings_output,
            ],
        )

if __name__ == "__main__":
    demo.launch()