import difflib
import re
import unicodedata
from typing import Dict, Optional, Tuple

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "valentinocc/dog-breed-classifier"
AKC_CSV_PATH = "akc-data-latest.csv"
DOG_LABELS_PATH = "dogmodelbreedlist.json"  # NOTE(review): not referenced in the visible code

# -----------Load model + processor-----------------------
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# ---------------Data Cleaning Helpers--------------------------------
def _normalize_name(s: str) -> str:
    """Return a canonical lookup key for a breed name.

    Steps:
    - fold accented letters to their base letter ("Löwchen" -> "lowchen");
    - lowercase;
    - turn every run of non-alphanumeric characters into a single space;
    - trim the ends.

    Trimming happens *after* punctuation is replaced, so e.g.
    "Poodle (Standard)" becomes "poodle standard" with no trailing space.

    Args:
        s: Any breed name (model label, AKC display name, alias target).

    Returns:
        The normalized key; may be "" if ``s`` has no alphanumerics.
    """
    # NFKD splits "ö" into "o" + a combining mark; dropping only the combining
    # marks keeps other non-ASCII symbols (e.g. dashes) around to become spaces.
    decomposed = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in decomposed if not unicodedata.combining(ch)).lower()
    s = re.sub(r"[^a-z0-9]+", " ", s)
    return s.strip()


def _load_akc_table(path: str) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Load the AKC CSV and build a normalized-name -> row-position map.

    The breed-name column is "Unnamed: 0" when present (a CSV written with an
    unnamed index column). Otherwise it is the first text-typed column. That
    column is renamed to "breed" and cast to ``str``.

    Args:
        path: Path to the AKC CSV file.

    Returns:
        ``(df, index_map)``, where ``index_map[_normalize_name(breed)]`` is the
        positional (``iloc``) index of that breed's row. If two names
        normalize to the same key, the later row wins.

    Raises:
        ValueError: If no column can serve as the breed name.
    """
    df = pd.read_csv(path)

    name_col: Optional[str] = "Unnamed: 0"
    if name_col not in df.columns:
        # Accept both classic object columns and pandas' dedicated string dtype.
        name_col = next(
            (
                c
                for c in df.columns
                if pd.api.types.is_object_dtype(df[c]) or pd.api.types.is_string_dtype(df[c])
            ),
            None,
        )
        if name_col is None:
            raise ValueError(f"{path}: no text column found to use as the breed name")

    # Make a clean 'breed' column for display and mapping
    df = df.rename(columns={name_col: "breed"})
    df["breed"] = df["breed"].astype(str)

    # Build normalized name -> row position map (positions match df.iloc).
    index_map: Dict[str, int] = {
        _normalize_name(name): idx for idx, name in enumerate(df["breed"].tolist())
    }
    return df, index_map


akc_df, akc_name_to_idx = _load_akc_table(AKC_CSV_PATH)

# ------------------Alias rules----------------------------------------
# 1) Direct alias corrections.
#    These entries bridge naming differences between the model's predicted
#    labels and the AKC dataset's breed names.
#    Keys are normalized model labels; values are AKC display names.
ALIAS_DIRECT: Dict[str, str] = {
    # Poodles: AKC rows put the size in parentheses.
    **{f"{size} poodle": f"Poodle ({size.title()})" for size in ("standard", "miniature", "toy")},
    # Dachshund sizes appear in both word orders in the wild
    # ("standard" is present in some AKC tables).
    **{f"{size} dachshund": f"Dachshund ({size.title()})" for size in ("miniature", "standard")},
    # Bull Terrier: miniature variety vs. the base breed.
    "miniature bull terrier": "Bull Terrier (Miniature)",
    # American Eskimo Dog varieties, first without and then with a trailing "dog".
    **{
        f"{size} american eskimo{suffix}": f"American Eskimo Dog ({size.title()})"
        for suffix in ("", " dog")
        for size in ("toy", "miniature", "standard")
    },
    # Spelling, hyphenation and word-order differences.
    # NOTE(review): in ImageNet-style label sets "Eskimo dog" usually denotes a
    # husky-type dog rather than the American Eskimo Dog; confirm this mapping
    # against the model's id2label.
    "eskimo dog": "American Eskimo Dog",
    "wire haired fox terrier": "Fox Terrier (Wire)",
    "smooth fox terrier": "Fox Terrier (Smooth)",
    "black and tan coonhound": "Black and Tan Coonhound",
    "german short haired pointer": "German Shorthaired Pointer",
    "german long haired pointer": "German Longhaired Pointer",
    "curly coated retriever": "Curly-Coated Retriever",
    "flat coated retriever": "Flat-Coated Retriever",
    "yorkshire terrier": "Yorkshire Terrier",
    "welsh springer spaniel": "Welsh Springer Spaniel",
    "english springer": "English Springer Spaniel",
}

# 2) Generic flip: "<size> <breed>" -> "<breed> (<size>)".
#    The flip is only attempted for known size words and is accepted only if
#    the flipped name exists in the AKC index.
SIZE_VARIANTS = {"standard", "miniature", "toy", "giant"}


def _try_alias_then_flip(norm_label: str) -> Optional[pd.Series]:
    """Resolve a normalized label via ALIAS_DIRECT, then via a size "flip".

    a) If ``norm_label`` is a key of ``ALIAS_DIRECT``, the aliased AKC display
       name is normalized and looked up in ``akc_name_to_idx``.
    b) Otherwise, or if the alias target is not in the table, try a flip. This
       applies to a label whose first word is in ``SIZE_VARIANTS``: it is
       rewritten from "<size> <breed>" to "<breed> (<size>)" (e.g.
       "toy poodle" -> "Poodle (Toy)") and looked up the same way.

    Args:
        norm_label: A label already passed through ``_normalize_name``.

    Returns:
        The matching AKC row, or ``None`` if neither strategy finds one.
    """
    # a) direct alias table
    alias = ALIAS_DIRECT.get(norm_label)
    if alias is not None:
        idx = akc_name_to_idx.get(_normalize_name(alias))
        if idx is not None:
            return akc_df.iloc[idx]

    # b) generic flip, accepted only IF the flipped name exists in the AKC index
    first, _, rest = norm_label.partition(" ")
    if rest and first in SIZE_VARIANTS:
        flipped_display = f"{rest.title()} ({first.title()})"
        idx = akc_name_to_idx.get(_normalize_name(flipped_display))
        if idx is not None:
            return akc_df.iloc[idx]
    return None


# ------------------Lookup in AKC table------------------------------------------
def _lookup_breed_info(pred_label: str) -> Optional[pd.Series]:
    """Find the best matching AKC row for a model label.

    Strategies, tried in order:
      1) Direct normalized match.
      2) Alias resolution and safe size flip
         ('Standard Poodle' -> 'Poodle (Standard)').
      3) Simple variants: drop a trailing 'dog', 'terrier' or 'hound', or
         append 'dog' ('German Shepherd' -> 'German Shepherd Dog').
      4) Fuzzy match via ``difflib.get_close_matches`` (cutoff 0.75).

    Args:
        pred_label: The raw label from ``model.config.id2label``.

    Returns:
        The AKC row as a Series, or ``None`` if nothing matches.
    """
    norm = _normalize_name(pred_label)

    # 1) direct match
    idx = akc_name_to_idx.get(norm)
    if idx is not None:
        return akc_df.iloc[idx]

    # 2) alias + safe flip
    row = _try_alias_then_flip(norm)
    if row is not None:
        return row

    # 3) simple variants; a tuple (not a set) keeps the try-order deterministic
    variants = (
        re.sub(r"\bdog\b$", "", norm).strip(),
        re.sub(r"\bterrier\b$", "", norm).strip(),
        re.sub(r"\bhound\b$", "", norm).strip(),
        f"{norm.strip()} dog",
    )
    for variant in variants:
        # Skip "" (the label was just "dog", "hound", ...) and the unchanged
        # label, which step 1 already tried.
        if variant and variant != norm and variant in akc_name_to_idx:
            return akc_df.iloc[akc_name_to_idx[variant]]

    # 4) fuzzy match
    candidates = difflib.get_close_matches(norm, akc_name_to_idx.keys(), n=1, cutoff=0.75)
    if candidates:
        return akc_df.iloc[akc_name_to_idx[candidates[0]]]
    return None


def _format_breed_info(row: pd.Series) -> str:
    """Turn a single AKC row into a readable Markdown snippet.

    Sections, each emitted only when the row has data for it:
    - a ``###`` heading with the breed name;
    - the description;
    - a facts block (group, temperament, height, weight, life expectancy);
    - a bulleted "Traits" list.

    Args:
        row: One row of ``akc_df``; columns that are absent or NaN are skipped.

    Returns:
        The Markdown text with surrounding whitespace stripped.
    """

    def value(col: str):
        # row[col], or None when the column is missing or its value is NaN.
        # Returning None (not a placeholder string) keeps the presence checks
        # below meaningful.
        if col in row.index and pd.notna(row[col]):
            return row[col]
        return None

    def show(v) -> str:
        # Display form of a value: "—" for missing, floats without a useless ".0".
        if v is None:
            return "—"
        if isinstance(v, float):
            return f"{v:g}"
        return str(v)

    def span(lo_col: str, hi_col: str, unit: str) -> Optional[str]:
        # "lo–hi unit" with "—" for a missing bound; None if both are missing.
        lo, hi = value(lo_col), value(hi_col)
        if lo is None and hi is None:
            return None
        return f"{show(lo)}–{show(hi)} {unit}"

    name = value("breed")
    lines = [f"### {name if name is not None else 'Unknown Breed'}"]

    description = value("description")
    if description is not None:
        lines.append(str(description))

    # Facts block
    facts = []
    group = value("group")
    if group is not None:
        facts.append(f"**Group:** {group}")
    temperament = value("temperament")
    if temperament is not None:
        facts.append(f"**Temperament:** {temperament}")
    # NOTE(review): the tmfilho/akcdata README appears to give heights in cm and
    # weights in kg; confirm the CSV's units before labelling them "in"/"lb".
    for label, lo_col, hi_col, unit in (
        ("Height", "min_height", "max_height", "in"),
        ("Weight", "min_weight", "max_weight", "lb"),
        ("Life Expectancy", "min_expectancy", "max_expectancy", "yrs"),
    ):
        rng = span(lo_col, hi_col, unit)
        if rng is not None:
            facts.append(f"**{label}:** {rng}")
    if facts:
        # Two trailing spaces make a Markdown hard line break. A bare "\n" is a
        # soft break and typically renders as a space, merging all the facts.
        lines.append("  \n".join(facts))

    # Optional traits if present in our AKC Dataset
    trait_fields = [
        ("grooming_frequency_category", "Grooming"),
        ("shedding_category", "Shedding"),
        ("energy_level_category", "Energy Level"),
        ("trainability_category", "Trainability"),
        ("demeanor_category", "Demeanor"),
    ]
    traits = []
    for col, label in trait_fields:
        val = value(col)
        if val is not None:
            traits.append(f"- **{label}:** {val}")
    if traits:
        lines.append("**Traits**")
        lines.extend(traits)

    return "\n\n".join(lines).strip()


# ---------------------Inference function----------------
def predict_with_info(img: Optional[Image.Image]) -> str:
    """Classify a dog photo and append AKC breed info for the top prediction.

    Args:
        img: The uploaded PIL image, or ``None`` when the user submits without
            uploading one.

    Returns:
        Markdown: a "**Prediction:** <label> (<prob>)" header followed by the
        AKC info block, or a note when the breed is not in the AKC table.
    """
    if img is None:
        return "_Please upload a dog photo._"

    # Force 3-channel RGB: uploads may be RGBA (PNG), grayscale or palette
    # images, which a processor normalizing per RGB channel may reject.
    inputs = processor(images=img.convert("RGB"), return_tensors="pt").to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        top_id = int(torch.argmax(probs, dim=-1).item())
        top_prob = float(probs[0, top_id].item())

    label = model.config.id2label.get(top_id, "Unknown")
    header = f"**Prediction:** {label} ({top_prob:.2%})"

    row = _lookup_breed_info(label)
    if row is None:
        return header + "\n\n_No matching breed found in AKC dataset._"

    info_md = _format_breed_info(row)
    return header + "\n\n" + info_md


# -------------------- UI -------------------------------------
demo = gr.Interface(
    fn=predict_with_info,
    inputs=gr.Image(type="pil", label="Upload a dog photo"),
    outputs=gr.Markdown(label="Prediction + Breed Info"),
    title="Dog Breed Classifier + AKC Info",
    description=(
        f"Upload an image of a dog. The app predicts the breed using '{MODEL_ID}' "
        "and shows breed details from the American Kennel Club dataset. Dataset: https://github.com/tmfilho/akcdata/blob/master/data/akc-data-latest.csv"
    ),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()