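# NOTE: the original first lines ("Spaces: / Sleeping / Sleeping") appear to be
# the hosting page's status banner captured with the file; not program code.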
import difflib
import re
import unicodedata
from typing import Dict, Optional, Tuple

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Model repository id and the data files this app refers to.
# NOTE(review): DOG_LABELS_PATH is not referenced by any code visible here.
MODEL_ID = "valentinocc/dog-breed-classifier"
AKC_CSV_PATH = "akc-data-latest.csv"
DOG_LABELS_PATH = "dogmodelbreedlist.json"

# ----------- Load model + processor -----------------------
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # the app only runs inference
model.to(device)
| # ---------------Data Cleaning Helpers-------------------------------- | |
| def _normalize_name(s: str) -> str: | |
| #Lowercase, strip non-alphanumerics, collapse spaces | |
| s = s.lower().strip() | |
| s = re.sub(r"[^a-z0-9\s]", " ", s) | |
| s = re.sub(r"\s+", " ", s) | |
| return s | |
def _load_akc_table(path: str) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Load the AKC CSV and build a normalized-name -> row-position map.

    The breed-name column is ``"Unnamed: 0"`` when present; otherwise the
    first text-like column is used. That column is renamed to ``"breed"`` and
    converted to strings.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A ``(df, index_map)`` tuple. ``index_map`` maps ``_normalize_name(name)``
        to the positional row index (usable with ``df.iloc``). When two rows
        normalize to the same key, the later row wins.

    Raises:
        ValueError: If no column can serve as the breed-name column.
    """
    df = pd.read_csv(path)

    name_col = "Unnamed: 0"
    if name_col not in df.columns:
        # Fall back to the first text-like column. Checking via pandas' dtype
        # helpers also recognizes dedicated string dtypes, not just "object".
        name_col = next(
            (
                c
                for c in df.columns
                if pd.api.types.is_object_dtype(df[c])
                or pd.api.types.is_string_dtype(df[c])
            ),
            None,
        )
        if name_col is None:
            raise ValueError(
                f"Could not find a breed-name column in {path!r}; "
                f"columns are {list(df.columns)}"
            )

    # Make a clean 'breed' column for display and mapping.
    df = df.rename(columns={name_col: "breed"})
    # Remember which rows actually have a name before string conversion, so a
    # missing value is not indexed under the literal text "nan".
    has_name = df["breed"].notna().tolist()
    df["breed"] = df["breed"].astype(str)

    # Build normalized name -> positional row index.
    index_map: Dict[str, int] = {}
    for idx, (name, named) in enumerate(zip(df["breed"].tolist(), has_name)):
        if not named or not isinstance(name, str):
            continue
        key = _normalize_name(name)
        if key.strip():  # skip names with nothing alphanumeric in them
            index_map[key] = idx
    return df, index_map


# Loaded once at import time; the lookup helpers below read these globals.
akc_df, akc_name_to_idx = _load_akc_table(AKC_CSV_PATH)
# ------------------ Alias rules ----------------------------------------
# 1) Direct alias corrections: normalized model label -> AKC display name.
#    Used when a prediction's wording differs from the dataset's breed name.
ALIAS_DIRECT: Dict[str, str] = {
    # Poodles (AKC rows are usually written with parentheses)
    **{
        f"{size} poodle": f"Poodle ({size.title()})"
        for size in ("standard", "miniature", "toy")
    },
    # Dachshund sizes often appear both ways in the wild
    "miniature dachshund": "Dachshund (Miniature)",
    "standard dachshund": "Dachshund (Standard)",  # present in some AKC tables
    # Bull Terrier miniature vs. base
    "miniature bull terrier": "Bull Terrier (Miniature)",
    # American Eskimo Dog varieties, with and without the trailing "dog"
    **{
        f"{size} american eskimo{suffix}": f"American Eskimo Dog ({size.title()})"
        for suffix in ("", " dog")
        for size in ("toy", "miniature", "standard")
    },
    # Others
    "eskimo dog": "American Eskimo Dog",
    "wire haired fox terrier": "Fox Terrier (Wire)",
    "smooth fox terrier": "Fox Terrier (Smooth)",
    "black and tan coonhound": "Black and Tan Coonhound",
    "german short haired pointer": "German Shorthaired Pointer",
    "german long haired pointer": "German Longhaired Pointer",
    "curly coated retriever": "Curly-Coated Retriever",
    "flat coated retriever": "Flat-Coated Retriever",
    "yorkshire terrier": "Yorkshire Terrier",
    "welsh springer spaniel": "Welsh Springer Spaniel",
    "english springer": "English Springer Spaniel",
}

# 2) Generic flip: "<Variant> <Base>" -> "<Base> (<Variant>)".
#    The flip is only attempted for these leading words, and only accepted
#    when the flipped name exists in the AKC index.
SIZE_VARIANTS = {"giant", "miniature", "standard", "toy"}
def _try_alias_then_flip(norm_label: str) -> Optional[pd.Series]:
    """Resolve a normalized label via the alias table, then via a size flip.

    Args:
        norm_label: Model label already passed through ``_normalize_name``.

    Returns:
        The matching row of ``akc_df``, or ``None`` when neither the alias
        table nor the flipped name leads to a row in the AKC index.
    """
    # a) Direct alias table.
    alias_display = ALIAS_DIRECT.get(norm_label)
    if alias_display is not None:
        row_pos = akc_name_to_idx.get(_normalize_name(alias_display))
        if row_pos is not None:
            return akc_df.iloc[row_pos]

    # b) Generic flip: "<variant> <rest>" -> "<rest> (<variant>)", accepted
    #    only if that name exists in the AKC index.
    variant, sep, base = norm_label.partition(" ")
    if not sep or variant not in SIZE_VARIANTS:
        return None
    flipped_key = _normalize_name(f"{base.title()} ({variant.title()})")
    row_pos = akc_name_to_idx.get(flipped_key)
    return None if row_pos is None else akc_df.iloc[row_pos]
# ------------------ Lookup in AKC table ------------------------------------------
def _lookup_breed_info(pred_label: str) -> Optional[pd.Series]:
    """Find the best matching AKC row for a model label.

    Strategies are tried in order and the first hit is returned:
      1) Direct normalized match
      2) Alias resolution and safe variant flip
         ('Standard Poodle' -> 'Poodle (Standard)')
      3) The label with a trailing 'dog', 'terrier' or 'hound' removed,
         tried in that fixed order
      4) Fuzzy match via difflib (cutoff 0.75)

    Args:
        pred_label: Label string produced by the classifier.

    Returns:
        The matching row of ``akc_df``, or ``None`` if nothing matches.
    """
    norm = _normalize_name(pred_label)
    if not norm.strip():
        # Nothing alphanumeric to match on.
        return None

    # 1) direct match
    idx = akc_name_to_idx.get(norm)
    if idx is not None:
        return akc_df.iloc[idx]

    # 2) alias + safe flip
    row = _try_alias_then_flip(norm)
    if row is not None:
        return row

    # 3) simple stripped variants. A tuple (not a set) keeps the order fixed,
    #    so the result is the same on every run when several variants match.
    for suffix in ("dog", "terrier", "hound"):
        candidate = re.sub(rf"\b{suffix}\b$", "", norm).strip()
        if not candidate or candidate == norm:
            continue  # suffix absent, or nothing left after removing it
        idx = akc_name_to_idx.get(candidate)
        if idx is not None:
            return akc_df.iloc[idx]

    # 4) fuzzy match
    candidates = difflib.get_close_matches(
        norm, list(akc_name_to_idx), n=1, cutoff=0.75
    )
    if candidates:
        return akc_df.iloc[akc_name_to_idx[candidates[0]]]
    return None
| def _format_breed_info(row: pd.Series) -> str: | |
| #Turn a single AKC row into a readable markdown snippet. | |
| def get(col, fallback="β"): | |
| return row[col] if col in row and pd.notna(row[col]) else fallback | |
| lines = [] | |
| lines.append(f"### {get('breed', 'Unknown Breed')}") | |
| if pd.notna(get('description')): | |
| lines.append(f"{get('description')}\n") | |
| # Facts block | |
| facts = [] | |
| if pd.notna(get('group')): | |
| facts.append(f"**Group:** {get('group')}") | |
| if pd.notna(get('temperament')): | |
| facts.append(f"**Temperament:** {get('temperament')}") | |
| # Height (inches) | |
| hmin, hmax = get('min_height'), get('max_height') | |
| if pd.notna(hmin) or pd.notna(hmax): | |
| facts.append(f"**Height:** {hmin if pd.notna(hmin) else 'β'}β{hmax if pd.notna(hmax) else 'β'} in") | |
| # Weight (pounds) | |
| wmin, wmax = get('min_weight'), get('max_weight') | |
| if pd.notna(wmin) or pd.notna(wmax): | |
| facts.append(f"**Weight:** {wmin if pd.notna(wmin) else 'β'}β{wmax if pd.notna(wmax) else 'β'} lb") | |
| # Life Expectancy (years) | |
| emin, emax = get('min_expectancy'), get('max_expectancy') | |
| if pd.notna(emin) or pd.notna(emax): | |
| facts.append(f"**Life Expectancy:** {emin if pd.notna(emin) else 'β'}β{emax if pd.notna(emax) else 'β'} yrs") | |
| if facts: | |
| lines.append("\n".join(facts)) | |
| # Optional traits if present in our AKC Dataset | |
| trait_fields = [ | |
| ("grooming_frequency_category", "Grooming"), | |
| ("shedding_category", "Shedding"), | |
| ("energy_level_category", "Energy Level"), | |
| ("trainability_category", "Trainability"), | |
| ("demeanor_category", "Demeanor"), | |
| ] | |
| traits = [] | |
| for col, label in trait_fields: | |
| val = get(col) | |
| if pd.notna(val): | |
| traits.append(f"- **{label}:** {val}") | |
| if traits: | |
| lines.append("\n**Traits**") | |
| lines.extend(traits) | |
| return "\n\n".join(lines).strip() | |
# --------------------- Inference function ----------------
def predict_with_info(img: Optional[Image.Image]) -> str:
    """Classify a dog photo and return markdown with the result.

    Args:
        img: The uploaded picture as a PIL image, or ``None`` when the
            request was submitted without one.

    Returns:
        Markdown containing the top-1 predicted breed with its confidence,
        followed by the AKC info block for that breed when a matching row is
        found. A short prompt is returned when no image was provided.
    """
    if img is None:
        # Submitting the form with no upload would otherwise raise inside the
        # image processor.
        return "_Please upload an image first._"

    # Defensive: make sure the processor sees three channels even when the
    # function is called with a grayscale/RGBA/palette image.
    if img.mode != "RGB":
        img = img.convert("RGB")

    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)
    top_id = int(torch.argmax(probs, dim=-1).item())
    top_prob = float(probs[0, top_id].item())
    label = model.config.id2label.get(top_id, "Unknown")

    header = f"**Prediction:** {label} ({top_prob:.2%})"
    row = _lookup_breed_info(label)
    if row is None:
        return header + "\n\n_No matching breed found in AKC dataset._"
    return header + "\n\n" + _format_breed_info(row)
# -------------------- UI -------------------------------------
# Single-function Gradio app: one image in, one markdown block out.
demo = gr.Interface(
    title="Dog Breed Classifier + AKC Info",
    description=(
        "Upload an image of a dog. The app predicts the breed using "
        f"'{MODEL_ID}' and shows breed details from the American Kennel Club "
        "dataset. Dataset: "
        "https://github.com/tmfilho/akcdata/blob/master/data/akc-data-latest.csv"
    ),
    fn=predict_with_info,
    inputs=gr.Image(type="pil", label="Upload a dog photo"),
    outputs=gr.Markdown(label="Prediction + Breed Info"),
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()