import json
import logging
import re
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from PIL import Image, ImageFile
from torchvision import models, transforms
|
|
# Let PIL decode image files whose data ends early instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# All artifact paths are resolved relative to this file, not the working directory.
PROJECT_ROOT = Path(__file__).resolve().parent
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Deterministic evaluation preprocessing: resize (shorter side to 256), take the
# central 224x224 crop, convert to a tensor and normalize with the standard
# ImageNet per-channel mean/std values.
EVAL_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Checkpoint file for each supported architecture, expected next to this file.
# Insertion order: the baseline first, then ResNets from shallowest to deepest.
CHECKPOINT_PATHS = {"BaselineCNN": PROJECT_ROOT / "baseline_cnn_best.pt"}
CHECKPOINT_PATHS.update(
    {f"ResNet{depth}": PROJECT_ROOT / f"resnet{depth}_best.pt" for depth in (18, 34, 50, 101, 152)}
)
|
|
|
|
class BaselineCNN(nn.Module):
    """Small four-stage convolutional classifier used as the non-ResNet baseline.

    Each stage is a 3x3 convolution (padding 1), BatchNorm and ReLU. The first three
    stages end with a 2x2 max-pool and the last with global average pooling, so the
    classifier always receives a 256-dim feature vector. The head is Flatten ->
    Dropout(0.3) -> Linear.

    The layers live in flat ``nn.Sequential`` containers whose indices
    (``features.0`` ... ``features.15``, ``classifier.0`` ... ``classifier.2``)
    must stay stable: ``load_model`` loads checkpoints into this module by
    state-dict key.

    Args:
        classes: number of output logits.
    """

    def __init__(self, classes: int):
        super().__init__()
        widths = (32, 64, 128, 256)
        stages = []
        in_channels = 3
        for position, out_channels in enumerate(widths):
            is_last = position == len(widths) - 1
            stages.extend(
                [
                    nn.Conv2d(in_channels, out_channels, 3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                    # Only the final stage collapses the spatial dimensions.
                    nn.AdaptiveAvgPool2d((1, 1)) if is_last else nn.MaxPool2d(2),
                ]
            )
            in_channels = out_channels
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(in_channels, classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        pooled = self.features(x)
        return self.classifier(pooled)
|
|
|
|
def build_resnet(name: str, classes: int):
    """Build an untrained torchvision ResNet whose final layer has ``classes`` outputs.

    Args:
        name: one of "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152".
        classes: number of output logits for the replacement ``fc`` layer.

    Raises:
        ValueError: if ``name`` is not one of the supported architectures.
    """
    constructor_names = {
        "ResNet18": "resnet18",
        "ResNet34": "resnet34",
        "ResNet50": "resnet50",
        "ResNet101": "resnet101",
        "ResNet152": "resnet152",
    }
    if name not in constructor_names:
        raise ValueError(f"Unsupported model name: {name}")

    # weights=None: the trained weights come from our own checkpoint, not torchvision.
    network = getattr(models, constructor_names[name])(weights=None)
    network.fc = nn.Linear(network.fc.in_features, classes)
    return network
|
|
|
|
def load_class_names():
    """Read the class-name list from ``artifacts/class_names.json``.

    The position of each name is presumably the model's output index (it is used
    that way by ``predict_pil_image``).

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        ValueError: if the decoded value is empty/falsy or has fewer than two entries.
    """
    names_file = ARTIFACTS_DIR / "class_names.json"
    if not names_file.exists():
        raise FileNotFoundError(
            "artifacts/class_names.json was not found. Run the notebook first so it can export deployment artifacts."
        )

    names = json.loads(names_file.read_text(encoding="utf-8"))

    # A classifier with a single class is treated as a broken export.
    if not names or len(names) < 2:
        raise ValueError("class_names.json is empty or invalid.")
    return names
|
|
|
|
def resolve_best_model_name():
    """Choose which trained checkpoint the app should serve.

    The optional ``artifacts/best_model_name.txt`` is used when it names a known
    architecture whose checkpoint file exists. Otherwise the first existing
    checkpoint is chosen, preferring deeper ResNets and falling back to the baseline.

    Raises:
        FileNotFoundError: if none of the checkpoint files exist.
    """
    marker = ARTIFACTS_DIR / "best_model_name.txt"
    preferred = marker.read_text(encoding="utf-8").strip() if marker.exists() else None
    if preferred in CHECKPOINT_PATHS and CHECKPOINT_PATHS[preferred].exists():
        return preferred

    fallback_order = ("ResNet152", "ResNet101", "ResNet50", "ResNet34", "ResNet18", "BaselineCNN")
    available = next((name for name in fallback_order if CHECKPOINT_PATHS[name].exists()), None)
    if available is None:
        raise FileNotFoundError("No checkpoint files were found next to app.py.")
    return available
|
|
|
|
def load_model(best_model_name: str, num_classes: int):
    """Instantiate ``best_model_name``, load its checkpoint and return it in eval mode.

    Args:
        best_model_name: a key of ``CHECKPOINT_PATHS``.
        num_classes: size of the classification head; must match the checkpoint.

    Returns:
        The model on ``DEVICE``, switched to evaluation mode.
    """
    network = (
        BaselineCNN(num_classes)
        if best_model_name == "BaselineCNN"
        else build_resnet(best_model_name, num_classes)
    )

    # The checkpoint is loaded as a plain state dict; load_state_dict is strict,
    # so architecture/key mismatches raise here rather than at inference time.
    weights = torch.load(CHECKPOINT_PATHS[best_model_name], map_location=DEVICE)
    network.load_state_dict(weights)
    network = network.to(DEVICE)
    network.eval()
    return network
|
|
|
|
# Load everything inference needs once, at import time: the class names, the name of
# the checkpoint to serve and the model itself. A missing or invalid artifact raises
# here, before the Gradio UI further down is built.
class_names = load_class_names()
best_model_name = resolve_best_model_name()
model = load_model(best_model_name, len(class_names))
|
|
|
|
def predict_pil_image(image, top_k=5):
    """Classify a PIL image with the loaded model and return its most likely classes.

    Args:
        image: a PIL image, or None.
        top_k: how many classes to return (capped at the number of classes).

    Returns:
        DataFrame with columns "Class" and "Probability" (softmax scores), one row per
        class in descending probability order. For ``image is None`` it is an empty
        table that still has those two columns.
    """
    if image is None:
        return pd.DataFrame(columns=["Class", "Probability"])

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = EVAL_TRANSFORM(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        scores = torch.softmax(model(batch), dim=1).squeeze(0)

    k = min(top_k, len(class_names))
    best_scores, best_indices = torch.topk(scores, k=k)

    records = [
        {"Class": class_names[index], "Probability": float(score)}
        for index, score in zip(best_indices.tolist(), best_scores.tolist())
    ]
    return pd.DataFrame(records)
|
|
|
|
# Cache for the marketplace listing table. None means "not loaded yet";
# load_listing_dataset() fills it on first use (with an empty DataFrame if loading fails).
LISTING_TABLE = None
|
|
|
|
def normalize_text(text):
    """Canonicalize text for loose matching.

    The value is converted with ``str`` and lowercased. Every non-alphanumeric
    character becomes a space, and whitespace runs collapse to single spaces with
    no leading or trailing space. ``None`` becomes ``""``.
    """
    if text is None:
        return ""
    # Whitespace is not alphanumeric, so it becomes a space and is absorbed by
    # the final split/join; no separate strip() is needed.
    characters = [ch if ch.isalnum() else " " for ch in str(text).lower()]
    return " ".join("".join(characters).split())
|
|
|
|
def _normalized_column(frame, column):
    """Return ``normalize_text`` applied to each value of ``frame[column]``.

    Each distinct value is normalized once and mapped back onto the rows, which is
    much cheaper than a per-row ``apply`` on a large table. If the column is absent,
    a blank ("") column is returned, so it never equals a non-empty query.
    """
    if column not in frame.columns:
        return pd.Series("", index=frame.index, dtype=object)
    values = frame[column]
    lookup = {value: normalize_text(value) for value in values.unique()}
    return values.map(lookup)


def load_listing_dataset():
    """Return the marketplace listing table, loading and caching it on first use.

    The "car-listings" train split of ``rebrowser/carguruscom-dataset`` is loaded
    with ``datasets.load_dataset`` and trimmed to the columns the app filters on or
    displays. ``year``, ``price`` and ``mileage`` are coerced to numbers
    (unparseable values become NaN). Missing values in the text columns become
    ``""``. Two helper columns, ``make_norm`` and ``model_norm``, hold the
    ``normalize_text`` forms of make and model.

    The result is stored in the module-level ``LISTING_TABLE`` and returned as-is on
    later calls. If loading fails for any reason, a warning with the traceback is
    logged and an empty DataFrame is cached instead. Callers then see "no listings"
    rather than an error, and the download is not retried on every request.
    """
    global LISTING_TABLE
    if LISTING_TABLE is not None:
        return LISTING_TABLE

    try:
        frame = load_dataset("rebrowser/carguruscom-dataset", "car-listings", split="train").to_pandas()
    except Exception:
        # Best-effort: classification still works without listings, but the failure
        # must show up in the server log instead of silently disabling the feature.
        logging.getLogger(__name__).warning(
            "Could not load the marketplace listing dataset; "
            "marketplace matching is disabled until the app restarts.",
            exc_info=True,
        )
        LISTING_TABLE = pd.DataFrame()
        return LISTING_TABLE

    keep_columns = [
        "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
        "transmission", "drivetrain", "fuelType", "dealRatingKey",
        "sellerCity", "sellerState", "listingUrl", "description",
    ]
    frame = frame[[column for column in keep_columns if column in frame.columns]].copy()

    for column in ("year", "price", "mileage"):
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")

    text_columns = (
        "make", "model", "trim", "bodyStyle", "transmission",
        "drivetrain", "fuelType", "dealRatingKey",
        "sellerCity", "sellerState", "listingUrl", "description",
    )
    for column in text_columns:
        if column in frame.columns:
            frame[column] = frame[column].fillna("").astype(str)

    frame["make_norm"] = _normalized_column(frame, "make")
    frame["model_norm"] = _normalized_column(frame, "model")

    LISTING_TABLE = frame
    return LISTING_TABLE
|
|
|
|
def _contains_phrase(text, phrase):
    """Return True if *phrase* occurs in *text* as a run of whole space-separated tokens.

    Both arguments are expected to be ``normalize_text`` output (alphanumeric tokens
    joined by single spaces). Padding both with spaces stops a short name from
    matching inside a longer word, e.g. make "mini" inside "minivan". An empty
    phrase never matches.
    """
    return bool(phrase) and f" {phrase} " in f" {text} "


def _longest_phrase_match(label, values):
    """Return the value whose normalized form is the longest whole-token phrase in *label*.

    ``values`` is a Series of raw dataset spellings; blank and missing entries are
    ignored. Returns the raw spelling, or None if nothing matches. Trying the longest
    candidates first lets e.g. "elantra touring" win over "elantra".
    """
    candidates = [
        (normalize_text(value), value)
        for value in values.dropna().unique()
        if str(value).strip()
    ]
    candidates.sort(key=lambda pair: len(pair[0]), reverse=True)
    return next((raw for norm, raw in candidates if _contains_phrase(label, norm)), None)


def parse_predicted_car_label(class_name: str):
    """Split a predicted class label into year, make, model and body style.

    Make and model are found by matching the label against the makes in the
    listing dataset and then the models listed for that make. Body style is
    the first of a fixed keyword list present in the label. All matching is
    whole-token on ``normalize_text`` forms.

    Returns:
        dict with keys ``year`` (int or None), ``make`` and ``model`` (raw dataset
        spellings or None) and ``body_style`` (a keyword or None). When the listing
        table is unavailable (empty), only ``year`` can be filled in.
    """
    frame = load_listing_dataset()
    label = normalize_text(class_name)

    # Take the last standalone 19xx/20xx token. This assumes labels end with the
    # model year (TODO confirm against class_names.json). Whole-token matching
    # ignores digits embedded in longer tokens such as "2500hd".
    years = re.findall(r"\b(?:19|20)\d{2}\b", label)
    year = int(years[-1]) if years else None

    if frame.empty:
        return {"year": year, "make": None, "model": None, "body_style": None}

    matched_make = None
    matched_model = None

    if "make" in frame.columns:
        matched_make = _longest_phrase_match(label, frame["make"])

    if matched_make is not None and "model" in frame.columns:
        models_for_make = frame.loc[frame["make"] == matched_make, "model"]
        matched_model = _longest_phrase_match(label, models_for_make)

    body_keywords = (
        "sedan", "coupe", "convertible", "suv", "wagon",
        "hatchback", "minivan", "van", "pickup", "truck",
    )
    body_style = next(
        (keyword for keyword in body_keywords if _contains_phrase(label, keyword)),
        None,
    )

    return {"year": year, "make": matched_make, "model": matched_model, "body_style": body_style}
|
|
|
|
def find_matching_listings(make=None, model=None, year=None, body_style=None, max_results=12):
    """Return the listings that best match a parsed car, ranked for display.

    Filters, applied in order and only when the argument is given:

    * ``make``: ``make_norm`` must equal ``normalize_text(make)``.
    * ``model``: ``model_norm`` must equal ``normalize_text(model)``. If no row
      matches exactly, a literal substring match on ``model_norm`` is used.
    * ``year``: listing year within ``year - 1 .. year + 1`` (needs a ``year`` column).
    * ``body_style``: case-insensitive ``str.contains`` on ``bodyStyle``.

    If the filters leave nothing but ``make`` was given, every listing of that make
    is used instead. The remaining rows are sorted ascending by distance from
    ``year``, then by deal rating (``GREAT_PRICE`` first, unknown ratings last),
    then price, then mileage. Sort keys whose columns the dataset lacks are
    skipped rather than raising ``KeyError``.

    Returns:
        At most ``max_results`` rows with the display columns and a fresh 0-based
        index. If the listing table is unavailable, an empty ``DataFrame``. If
        nothing matches, an empty frame.
    """
    frame = load_listing_dataset()
    if frame.empty:
        return pd.DataFrame()

    # Boolean indexing returns new frames, so the (possibly large) cached table is
    # not copied up front. The explicit .copy() below protects it before any
    # column is added.
    filtered = frame

    if make:
        filtered = filtered[filtered["make_norm"] == normalize_text(make)]

    if model:
        model_norm = normalize_text(model)
        exact_match = filtered[filtered["model_norm"] == model_norm]
        if len(exact_match) > 0:
            filtered = exact_match
        else:
            # normalize_text output is literal text, so match it literally.
            filtered = filtered[filtered["model_norm"].str.contains(model_norm, regex=False, na=False)]

    if year is not None and "year" in filtered.columns:
        filtered = filtered[filtered["year"].between(year - 1, year + 1, inclusive="both")]

    if body_style and "bodyStyle" in filtered.columns:
        filtered = filtered[filtered["bodyStyle"].str.contains(body_style, case=False, na=False)]

    if len(filtered) == 0 and make:
        # The filters were too strict: fall back to every listing of the make.
        filtered = frame[frame["make_norm"] == normalize_text(make)]

    if len(filtered) == 0:
        return filtered

    filtered = filtered.copy()
    if year is not None and "year" in filtered.columns:
        filtered["year_distance"] = (filtered["year"] - year).abs()
    else:
        filtered["year_distance"] = 0

    deal_ratings = filtered.get("dealRatingKey", pd.Series("NA", index=filtered.index))
    filtered["deal_rank"] = deal_ratings.map({
        "GREAT_PRICE": 0,
        "GOOD_PRICE": 1,
        "FAIR_PRICE": 2,
        "POOR_PRICE": 3,
        "OVERPRICED": 4,
        "OUTLIER": 5,
        "NA": 6,
    }).fillna(6)

    sort_keys = [
        key for key in ("year_distance", "deal_rank", "price", "mileage")
        if key in filtered.columns
    ]
    filtered = filtered.sort_values(sort_keys)

    show_columns = [
        column for column in [
            "year", "make", "model", "trim", "bodyStyle", "price", "mileage",
            "transmission", "drivetrain", "fuelType", "dealRatingKey",
            "sellerCity", "sellerState", "listingUrl",
        ] if column in filtered.columns
    ]

    return filtered[show_columns].head(max_results).reset_index(drop=True)
|
|
|
|
def build_listing_summary(frame, parsed_car):
    """Describe the matched listings as a short multi-line text.

    Args:
        frame: listings DataFrame (may be None or empty).
        parsed_car: dict whose ``make``, ``model`` and ``year`` entries are echoed
            when truthy.

    Returns:
        Newline-joined lines: the listing count, the parsed make/model/target year,
        and the min-max price and mileage when those columns hold any value.
    """
    if frame is None or len(frame) == 0:
        return "No matching marketplace listings were found."

    def has_values(column):
        return column in frame.columns and frame[column].notna().any()

    parts = [f"Matched listings: {len(frame)}"]
    for key, caption in (("make", "Make"), ("model", "Model"), ("year", "Target year")):
        value = parsed_car.get(key)
        if value:
            parts.append(f"{caption}: {value}")

    if has_values("price"):
        cheapest, priciest = int(frame["price"].min()), int(frame["price"].max())
        parts.append(f"Price range: ${cheapest:,} - ${priciest:,}")
    if has_values("mileage"):
        lowest, highest = int(frame["mileage"].min()), int(frame["mileage"].max())
        parts.append(f"Mileage range: {lowest:,} - {highest:,} miles")

    return "\n".join(parts)
|
|
|
|
def format_listing_table(frame):
    """Turn raw listing rows into the display table shown in the UI.

    Dataset columns are renamed to human-readable headers. Price and mileage are
    rendered as strings (``$12,345`` / ``12,345 mi``, with an em dash for missing
    values). Columns come out in a fixed order, and any column outside that set is
    dropped. A None or empty input yields an empty table that still carries every
    header.
    """
    # Raw column -> display header; the insertion order is also the output order.
    headers = {
        "year": "Year",
        "make": "Make",
        "model": "Model",
        "trim": "Trim",
        "bodyStyle": "Body Style",
        "price": "Price",
        "mileage": "Mileage",
        "transmission": "Transmission",
        "drivetrain": "Drivetrain",
        "fuelType": "Fuel Type",
        "dealRatingKey": "Deal Rating",
        "sellerCity": "City",
        "sellerState": "State",
        "listingUrl": "Listing URL",
    }

    if frame is None or len(frame) == 0:
        return pd.DataFrame(columns=list(headers.values()))

    table = frame.copy().rename(columns=headers)

    def as_money(value):
        return "—" if pd.isna(value) else f"${int(value):,}"

    def as_miles(value):
        return "—" if pd.isna(value) else f"{int(value):,} mi"

    for column, formatter in (("Price", as_money), ("Mileage", as_miles)):
        if column in table.columns:
            table[column] = table[column].apply(formatter)

    ordered = [header for header in headers.values() if header in table.columns]
    return table[ordered].reset_index(drop=True)
|
|
|
|
def run_demo(image):
    """Gradio click handler: classify *image* and look up matching marketplace listings.

    Returns:
        A 4-tuple matching the UI outputs: (prediction summary text, top-k
        predictions table, marketplace summary text, formatted listings table).
        Without an image, placeholder texts and empty tables are returned. Both
        tables keep their usual column headers.
    """
    if image is None:
        return (
            "Please upload a car image.",
            # Reuse predict_pil_image's empty Class/Probability table so the
            # predictions grid keeps its headers instead of showing a column-less frame.
            predict_pil_image(None),
            "Marketplace summary will appear here.",
            format_listing_table(pd.DataFrame()),
        )

    predictions = predict_pil_image(image)
    top_row = predictions.iloc[0]
    top_class = top_row["Class"]
    top_probability = top_row["Probability"]

    parsed_car = parse_predicted_car_label(top_class)
    listings = find_matching_listings(
        make=parsed_car.get("make"),
        model=parsed_car.get("model"),
        year=parsed_car.get("year"),
        body_style=parsed_car.get("body_style"),
        max_results=12,
    )

    summary = (
        f"Best model: {best_model_name}\n"
        f"Top prediction: {top_class}\n"
        f"Confidence: {top_probability:.4f}"
    )

    listing_summary = build_listing_summary(listings, parsed_car)
    return summary, predictions, listing_summary, format_listing_table(listings)
|
|
|
|
# Stylesheet for the Gradio UI below (passed as gr.Blocks(css=...)). It defines a
# dark graphite header band, a red accent colour, square corners, white panels and
# zebra-striped dataframe rows. Colours are defined once as --sc-* custom properties.
# NOTE(review): ".wrap.svelte-1ipelgc" targets a Gradio-generated hashed class name;
# such names presumably change between Gradio releases, so re-check after upgrades.
simple_css = """
:root {
--sc-accent: #C0504D;
--sc-accent-dark: #A2413F;
--sc-accent-soft: #F6E7E6;
--sc-white: #FFFFFF;
--sc-bg: #F4F5F7;
--sc-surface: #FAFAFB;
--sc-graphite: #2C2C31;
--sc-graphite-2: #3A3A40;
--sc-border: #D9DDE2;
--sc-muted: #6D7278;
--sc-text: #202327;
}
.gradio-container {
background:
linear-gradient(180deg, var(--sc-graphite) 0 118px, var(--sc-bg) 118px 100%);
font-family: Arial, Helvetica, sans-serif !important;
color: var(--sc-text);
}
#page {
max-width: 1220px;
margin: 0 auto;
padding: 24px 18px 40px 18px;
}
#topbar {
display: flex;
align-items: center;
justify-content: space-between;
gap: 24px;
color: white;
margin-bottom: 20px;
}
#brand {
display: flex;
flex-direction: column;
gap: 4px;
}
#brand h1 {
margin: 0;
font-size: 24px;
line-height: 1;
letter-spacing: 0.01em;
font-style: italic;
text-transform: uppercase;
font-weight: 900;
color: white;
}
#brand h1 span {
color: var(--sc-accent);
}
#brand p {
margin: 0;
font-size: 12px;
letter-spacing: 0.04em;
color: #C9CDD2;
text-transform: uppercase;
}
#hero {
background:
linear-gradient(135deg, rgba(255,255,255,0.06), rgba(255,255,255,0.02)),
linear-gradient(180deg, var(--sc-graphite-2), var(--sc-graphite));
color: white;
padding: 22px 26px;
margin-bottom: 18px;
border-left: 4px solid var(--sc-accent);
box-shadow: 0 12px 30px rgba(20, 20, 24, 0.15);
}
#hero p {
margin: 0;
max-width: 860px;
font-size: 15px;
color: #D5D8DC;
line-height: 1.55;
}
.panel {
background: var(--sc-white);
border: 1px solid var(--sc-border);
box-shadow: 0 10px 24px rgba(32, 35, 39, 0.06);
padding: 14px;
}
.section-label {
margin: 0 0 8px 0;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--sc-muted);
font-weight: 700;
}
.section-label-tight {
margin: 0 0 4px 0;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--sc-muted);
font-weight: 700;
}
.highlight-card {
background: linear-gradient(180deg, #34363B, #25272B);
border: 1px solid #43464D;
color: white;
padding: 12px;
}
.highlight-card textarea,
.highlight-card input {
background: transparent !important;
color: white !important;
}
button.primary {
background: var(--sc-accent) !important;
color: white !important;
border: 1px solid var(--sc-accent-dark) !important;
border-radius: 0 !important;
font-weight: 700 !important;
text-transform: uppercase;
letter-spacing: 0.04em;
box-shadow: none !important;
}
button.primary:hover {
background: #AA4542 !important;
}
.gradio-container button.secondary,
.gradio-container .block,
.gradio-container .gr-box,
.gradio-image,
.gradio-dataframe,
textarea,
input {
border-radius: 0 !important;
}
.gradio-image {
border: 1px solid var(--sc-border) !important;
background: white !important;
}
.gradio-dataframe table thead tr th {
background: var(--sc-graphite) !important;
color: white !important;
border-color: var(--sc-graphite) !important;
font-weight: 700 !important;
}
.gradio-dataframe table tbody tr:nth-child(even) td {
background: var(--sc-surface) !important;
}
.gradio-dataframe table tbody tr td {
border-color: var(--sc-border) !important;
}
.gradio-container .wrap.svelte-1ipelgc,
.gradio-container .contain {
background: transparent !important;
}
#results-grid {
gap: 18px;
align-items: start;
}
.compact-box {
margin-top: 0 !important;
padding-top: 0 !important;
}
@media (max-width: 900px) {
#topbar {
flex-direction: column;
align-items: flex-start;
}
}
"""
|
|
|
|
# Page layout: a branded header and hero text, then a row with the image-upload card
# on the left (scale 5) and the result cards on the right (scale 7). Clicking
# "Predict" runs run_demo and fills the four output components in order.
# NOTE(review): confirm the nesting. In particular, check whether all four result
# cards belong in the right-hand column and whether predict_button.click sits at
# the Blocks level (any position inside the Blocks context works for Gradio events).
with gr.Blocks(css=simple_css) as demo:
    with gr.Column(elem_id="page"):
        gr.HTML(
            """
<div id="topbar">
<div id="brand">
<h1>Stanford<span>Cars</span></h1>
<p>Image Classification Capstone Project</p>
</div>
</div>
<div id="hero">
<p>
Upload an image from a computer and get a prediction view with model confidence and matched marketplace listings.
</p>
</div>
"""
        )

        with gr.Row(elem_id="results-grid"):
            # Left column: image upload and the Predict button.
            with gr.Column(scale=5):
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Image Upload</div>')
                    image_input = gr.Image(
                        type="pil",
                        sources=["upload"],
                        label="Upload a car image from your computer",
                        height=360,
                    )
                    predict_button = gr.Button("Predict", variant="primary")

            # Right column: text summaries and result tables filled by run_demo.
            with gr.Column(scale=7):
                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Prediction Summary</div>')
                    summary_output = gr.Textbox(label="", show_label=False, lines=3)

                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Marketplace Summary</div>')
                    listing_summary_output = gr.Textbox(label="", show_label=False, lines=5)

                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label-tight" style="color:#CFD3D8;">Top Predictions</div>')
                    predictions_output = gr.Dataframe(label="", show_label=False, interactive=False)

                with gr.Column(elem_classes=["panel", "highlight-card"]):
                    gr.HTML('<div class="section-label" style="color:#CFD3D8;">Matched Marketplace Listings</div>')
                    listings_output = gr.Dataframe(
                        label="",
                        show_label=False,
                        interactive=False,
                        wrap=True,
                    )

    # Output order must match the 4-tuple returned by run_demo.
    predict_button.click(
        fn=run_demo,
        inputs=image_input,
        outputs=[
            summary_output,
            predictions_output,
            listing_summary_output,
            listings_output,
        ],
    )
|
|
|
|
# Start the Gradio server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()