# DogBreedID / app.py
# achase25 - "Update app.py" (30a588a, verified)
import difflib
import re
from typing import Dict, Optional, Tuple
import gradio as gr
import torch
import pandas as pd
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Identifier passed to from_pretrained() for both the image processor and the classifier.
MODEL_ID = "valentinocc/dog-breed-classifier"
# CSV file holding the AKC breed table (read by _load_akc_table).
AKC_CSV_PATH = "akc-data-latest.csv"
# Path to a breed-label list; NOTE(review): not referenced anywhere else in the visible file.
DOG_LABELS_PATH = "dogmodelbreedlist.json"
# ----------- Load model + processor -----------------------
# Choose the compute device up front: CUDA when torch reports it, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Both objects come from the same checkpoint so preprocessing matches the model.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.to(device)
model.eval()  # the app only runs inference, so switch to evaluation mode
# --------------- Data Cleaning Helpers --------------------------------
def _normalize_name(s: str) -> str:
    """Return a canonical lookup key for a breed name.

    The name is lowercased, every character other than a-z, 0-9 or
    whitespace is replaced by a space, runs of whitespace are collapsed to a
    single space, and leading/trailing whitespace is removed.

    Args:
        s: Breed name as written in the AKC table or emitted by the model.

    Returns:
        The normalized key, e.g. "Poodle (Standard)" -> "poodle standard".
    """
    lowered = s.lower()
    spaced = re.sub(r"[^a-z0-9\s]", " ", lowered)
    collapsed = re.sub(r"\s+", " ", spaced)
    # Strip last: punctuation at either end (such as a closing parenthesis)
    # has just been turned into a space and would otherwise survive as a
    # stray leading/trailing blank in the key.
    return collapsed.strip()
def _load_akc_table(path: str) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Load the AKC CSV and index its rows by normalized breed name.

    Args:
        path: Location of the CSV file, handed to pandas.read_csv.

    Returns:
        A (df, index_map) pair. In df the breed-name column is renamed to
        "breed" and cast to str. index_map maps _normalize_name(breed) to the
        row's position in df; when two rows normalize to the same key, the
        later row wins.

    Raises:
        ValueError: If the table has neither an "Unnamed: 0" column nor any
            text column that could hold the breed names.
    """
    df = pd.read_csv(path)

    # Prefer the "Unnamed: 0" column (the label pandas gives a first column
    # without a header); otherwise fall back to the first text column.
    name_col = "Unnamed: 0"
    if name_col not in df.columns:
        name_col = next(
            (
                col
                for col, dtype in df.dtypes.items()
                # Covers object columns as well as dedicated string dtypes.
                if pd.api.types.is_string_dtype(dtype)
            ),
            None,
        )
        if name_col is None:
            # Fail here with a clear message instead of a KeyError on "breed".
            raise ValueError(
                f"Could not find a breed-name column in {path!r}: expected "
                "'Unnamed: 0' or at least one text column"
            )

    # Make a clean 'breed' column for display and mapping.
    df = df.rename(columns={name_col: "breed"})
    df["breed"] = df["breed"].astype(str)

    # Normalized name -> positional row index (usable with df.iloc).
    index_map: Dict[str, int] = {
        _normalize_name(name): idx for idx, name in enumerate(df["breed"])
    }
    return df, index_map
# Read the AKC table once at import time. akc_df holds the rows and
# akc_name_to_idx maps a normalized breed name to its row position in akc_df.
akc_df, akc_name_to_idx = _load_akc_table(AKC_CSV_PATH)
# ------------------ Alias rules ----------------------------------------
# 1) Direct alias corrections for model labels that do not line up with the
#    breed names in the AKC table.
#    Key:   a label after _normalize_name() (lowercase, punctuation removed).
#    Value: the AKC display name; it is normalized again before being looked
#           up in akc_name_to_idx, so an alias whose target is absent from the
#           table simply does not match.
ALIAS_DIRECT: Dict[str, str] = {
    # Poodles (AKC rows are usually written with parentheses)
    "standard poodle": "Poodle (Standard)",
    "miniature poodle": "Poodle (Miniature)",
    "toy poodle": "Poodle (Toy)",
    # Dachshund sizes often appear both ways in the wild
    "miniature dachshund": "Dachshund (Miniature)",
    "standard dachshund": "Dachshund (Standard)", # present in some AKC tables
    # Bull Terrier miniature vs. base
    "miniature bull terrier": "Bull Terrier (Miniature)",
    # American Eskimo Dog varieties (with and without the trailing "dog")
    "toy american eskimo": "American Eskimo Dog (Toy)",
    "miniature american eskimo": "American Eskimo Dog (Miniature)",
    "standard american eskimo": "American Eskimo Dog (Standard)",
    "toy american eskimo dog": "American Eskimo Dog (Toy)",
    "miniature american eskimo dog": "American Eskimo Dog (Miniature)",
    "standard american eskimo dog": "American Eskimo Dog (Standard)",
    # Others: mostly spelling/hyphenation differences between label and table
    "eskimo dog": "American Eskimo Dog",
    "wire haired fox terrier": "Fox Terrier (Wire)",
    "smooth fox terrier": "Fox Terrier (Smooth)",
    "black and tan coonhound": "Black and Tan Coonhound",
    "german short haired pointer": "German Shorthaired Pointer",
    "german long haired pointer": "German Longhaired Pointer",
    "curly coated retriever": "Curly-Coated Retriever",
    "flat coated retriever": "Flat-Coated Retriever",
    "yorkshire terrier": "Yorkshire Terrier",
    "welsh springer spaniel": "Welsh Springer Spaniel",
    "english springer": "English Springer Spaniel",
}
# 2) Generic flip: "<Variant> <Base>" -> "<Base> (<Variant>)"
# The flip is attempted only when the first word of the normalized label is
# one of these variants, and it is accepted only if the flipped name exists
# in the AKC index.
SIZE_VARIANTS = {"standard", "miniature", "toy", "giant"}
def _try_alias_then_flip(norm_label: str) -> Optional[pd.Series]:
    """Resolve a normalized label through the alias table or a size-variant flip.

    Args:
        norm_label: Model label already passed through _normalize_name.

    Returns:
        The matching row of akc_df, or None when neither the alias nor the
        flipped name is present in akc_name_to_idx.
    """
    # a) Direct alias table: map the label to an AKC display name and look
    #    that name up in the index.
    alias_display = ALIAS_DIRECT.get(norm_label)
    if alias_display is not None:
        alias_idx = akc_name_to_idx.get(_normalize_name(alias_display))
        if alias_idx is not None:
            return akc_df.iloc[alias_idx]

    # b) Generic flip: "<variant> <rest>" -> "<rest> (<variant>)", accepted
    #    only when the flipped name exists in the AKC index.
    variant, separator, base = norm_label.partition(" ")
    if not separator or variant not in SIZE_VARIANTS:
        return None

    flipped_key = _normalize_name(f"{base.title()} ({variant.title()})")
    flipped_idx = akc_name_to_idx.get(flipped_key)
    if flipped_idx is None:
        return None
    return akc_df.iloc[flipped_idx]
# ------------------ Lookup in AKC table ------------------------------------------
def _lookup_breed_info(pred_label: str) -> Optional[pd.Series]:
    """
    Resolve a model label to its AKC row, trying progressively looser matches.

    Attempts, in order:
    1) Exact match on the normalized label
    2) Alias table / size-variant flip ('Standard Poodle' -> 'Poodle (Standard)')
    3) The label with a trailing 'dog', 'terrier' or 'hound' removed
    4) Closest index key according to difflib (cutoff 0.75)

    Returns the matching row of akc_df, or None when every attempt fails.
    """
    key = _normalize_name(pred_label)

    # 1) exact match
    if key in akc_name_to_idx:
        return akc_df.iloc[akc_name_to_idx[key]]

    # 2) alias + safe flip
    aliased = _try_alias_then_flip(key)
    if aliased is not None:
        return aliased

    # 3) drop one trailing generic word and retry. A label without that
    #    suffix is left unchanged and cannot match, since step 1 already failed.
    for suffix_pattern in (r"\bdog\b$", r"\bterrier\b$", r"\bhound\b$"):
        shortened = re.sub(suffix_pattern, "", key).strip()
        hit = akc_name_to_idx.get(shortened)
        if hit is not None:
            return akc_df.iloc[hit]

    # 4) fuzzy match against every known key; keep only the single best one
    nearest = difflib.get_close_matches(key, akc_name_to_idx.keys(), n=1, cutoff=0.75)
    if not nearest:
        return None
    return akc_df.iloc[akc_name_to_idx[nearest[0]]]
def _format_breed_info(row: pd.Series) -> str:
    """Turn a single AKC row into a readable Markdown snippet.

    Fields that are absent from the row or hold NaN are left out of the
    output. For a min/max range with only one known end, the unknown end is
    shown as an em dash.

    Args:
        row: One row of the AKC table, indexed by column name.

    Returns:
        Markdown starting with a "### <breed>" heading ("Unknown Breed" when
        the name is missing), followed by whichever of the description, the
        facts block and the traits list have data.
    """
    missing = "\u2014"  # em dash: placeholder for the unknown end of a range
    range_sep = "\u2013"  # en dash between the low and high end of a range

    def get(col: str):
        """Return row[col], or None when the column is absent or NaN."""
        if col in row and pd.notna(row[col]):
            return row[col]
        return None

    def span(low_col: str, high_col: str, unit: str) -> Optional[str]:
        """Format a min/max pair, or return None when both ends are missing."""
        low, high = get(low_col), get(high_col)
        if low is None and high is None:
            return None
        low_text = missing if low is None else low
        high_text = missing if high is None else high
        return f"{low_text}{range_sep}{high_text} {unit}"

    breed = get("breed")
    heading = "Unknown Breed" if breed is None else breed
    lines = [f"### {heading}"]

    description = get("description")
    if description is not None:
        lines.append(f"{description}\n")

    # Facts block
    facts = []
    for col, label in (("group", "Group"), ("temperament", "Temperament")):
        value = get(col)
        if value is not None:
            facts.append(f"**{label}:** {value}")

    # NOTE(review): the unit labels below are kept from the original code;
    # confirm they match the units actually stored in the CSV.
    ranges = (
        ("min_height", "max_height", "Height", "in"),
        ("min_weight", "max_weight", "Weight", "lb"),
        ("min_expectancy", "max_expectancy", "Life Expectancy", "yrs"),
    )
    for low_col, high_col, label, unit in ranges:
        text = span(low_col, high_col, unit)
        if text is not None:
            facts.append(f"**{label}:** {text}")

    if facts:
        lines.append("\n".join(facts))

    # Optional traits if present in our AKC Dataset
    trait_fields = [
        ("grooming_frequency_category", "Grooming"),
        ("shedding_category", "Shedding"),
        ("energy_level_category", "Energy Level"),
        ("trainability_category", "Trainability"),
        ("demeanor_category", "Demeanor"),
    ]
    traits = []
    for col, label in trait_fields:
        value = get(col)
        if value is not None:
            traits.append(f"- **{label}:** {value}")
    if traits:
        lines.append("\n**Traits**")
        lines.extend(traits)

    return "\n\n".join(lines).strip()
# --------------------- Inference function ----------------
def predict_with_info(img: Optional[Image.Image]) -> str:
    """Classify a dog photo and describe the predicted breed.

    Args:
        img: PIL image supplied by the Gradio input, or None when the user
            submitted without uploading anything.

    Returns:
        Markdown text: the top-1 label with its softmax probability, followed
        by the AKC info block for that breed, or a notice when the label has
        no match in the AKC table. When no image was given, a short prompt to
        upload one.
    """
    # Guard against an empty submission instead of failing inside the processor.
    if img is None:
        return "_Please upload a dog photo first._"

    # Convert grayscale, palette or RGBA uploads so the processor always
    # receives a 3-channel image.
    rgb_img = img.convert("RGB")

    inputs = processor(images=rgb_img, return_tensors="pt").to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)
    top_id = int(torch.argmax(probs, dim=-1).item())
    top_prob = float(probs[0, top_id].item())
    label = model.config.id2label.get(top_id, "Unknown")

    header = f"**Prediction:** {label} ({top_prob:.2%})"
    row = _lookup_breed_info(label)
    if row is None:
        return header + "\n\n_No matching breed found in AKC dataset._"

    info_md = _format_breed_info(row)
    return header + "\n\n" + info_md
# -------------------- UI -------------------------------------
# Single-function Gradio interface: the uploaded picture is handed to
# predict_with_info as a PIL image and the returned string is shown as Markdown.
demo = gr.Interface(
    fn=predict_with_info,
    inputs=gr.Image(type="pil", label="Upload a dog photo"),
    outputs=gr.Markdown(label="Prediction + Breed Info"),
    title="Dog Breed Classifier + AKC Info",
    description=(
        f"Upload an image of a dog. The app predicts the breed using '{MODEL_ID}' "
        "and shows breed details from the American Kennel Club dataset. Dataset: https://github.com/tmfilho/akcdata/blob/master/data/akc-data-latest.csv"
    ),
    # Flagging of results is switched off for this demo.
    allow_flagging="never",
)
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()