File size: 9,215 Bytes
af85b64
 
 
 
04c0574
 
af85b64
04c0574
 
 
 
4f6366e
30a588a
04c0574
4f6366e
 
04c0574
 
 
 
 
 
4f6366e
af85b64
4f6366e
af85b64
 
 
 
04c0574
af85b64
4f6366e
 
af85b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6366e
 
 
4ecb54a
9b5c980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a588a
 
 
 
 
 
 
 
 
 
 
 
 
9b5c980
 
 
 
 
 
 
4f6366e
9b5c980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6366e
 
af85b64
 
 
 
4f6366e
9b5c980
 
af85b64
 
 
4f6366e
9b5c980
 
 
 
4f6366e
9b5c980
 
 
af85b64
4f6366e
af85b64
 
 
 
 
 
 
 
 
04c0574
9b5c980
af85b64
 
 
 
 
 
 
4f6366e
af85b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6366e
af85b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6366e
4ecb54a
 
 
af85b64
4ecb54a
af85b64
04c0574
 
 
 
 
 
 
af85b64
 
 
 
 
04c0574
af85b64
 
04c0574
4ecb54a
 
 
04c0574
af85b64
04c0574
af85b64
 
04c0574
af85b64
4f6366e
04c0574
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import difflib
import re
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
import pandas as pd
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Identifier handed to from_pretrained() for both the image processor and the model.
MODEL_ID = "valentinocc/dog-breed-classifier"
# CSV with the AKC breed table; read once at import time by _load_akc_table().
AKC_CSV_PATH = "akc-data-latest.csv"  
# NOTE(review): not referenced anywhere in the visible code - confirm it is still needed.
DOG_LABELS_PATH = "dogmodelbreedlist.json"


# -----------Load model + processor-----------------------
# Both objects are created once at import time and reused by every prediction.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # evaluation mode: this app only runs inference
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ---------------Data Cleaning Helpers--------------------------------
def _normalize_name(s: str) -> str:
    #Lowercase, strip non-alphanumerics, collapse spaces
    s = s.lower().strip()
    s = re.sub(r"[^a-z0-9\s]", " ", s)
    s = re.sub(r"\s+", " ", s)
    return s

def _load_akc_table(path: str) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Load the AKC CSV and index its rows by normalized breed name.

    The breed-name column is ``"Unnamed: 0"`` when the CSV has one; otherwise
    the first text column is used. That column is renamed to ``"breed"`` and
    cast to ``str``.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A ``(df, index_map)`` tuple: the table with a ``"breed"`` column, and
        a dict mapping ``_normalize_name(breed)`` to the row position for use
        with ``df.iloc``. Rows whose name is missing or normalizes to an empty
        string are not indexed. When several rows share a key, the last one
        wins.

    Raises:
        KeyError: If the CSV has neither an ``"Unnamed: 0"`` column nor any
            text column that could hold the breed names.
    """
    df = pd.read_csv(path)

    name_col = "Unnamed: 0"
    if name_col not in df.columns:
        # Fall back to the first text column. Text may be stored as plain
        # ``object`` or as pandas' dedicated string dtype.
        name_col = next(
            (
                c
                for c in df.columns
                if df[c].dtype == "object" or isinstance(df[c].dtype, pd.StringDtype)
            ),
            None,
        )
        if name_col is None:
            # Fail with a message that names the real problem instead of the
            # bare KeyError('breed') the column access below would produce.
            raise KeyError(
                f"no breed-name column found in {path!r}: expected 'Unnamed: 0' "
                "or at least one text column"
            )

    # Make a clean 'breed' column for display and mapping
    df = df.rename(columns={name_col: "breed"})
    # Record missing names before the cast: astype(str) turns NaN into "nan",
    # which must not become a matchable breed key.
    has_name = df["breed"].notna().tolist()
    df["breed"] = df["breed"].astype(str)

    # Build normalized name -> row index map
    index_map: Dict[str, int] = {}
    for idx, (name, named) in enumerate(zip(df["breed"].tolist(), has_name)):
        key = _normalize_name(name) if named else ""
        if key:
            index_map[key] = idx

    return df, index_map

# Loaded once at import time; both names are module-level state read by the lookup helpers below.
akc_df, akc_name_to_idx = _load_akc_table(AKC_CSV_PATH)


# ------------------Alias rules----------------------------------------

# 1) Direct alias corrections, used when a predicted label does not match an
#    AKC breed name as-is. Keys are labels in normalized form (lowercase,
#    punctuation replaced by spaces); values are AKC display names, which are
#    normalized again before being looked up in the AKC index.
ALIAS_DIRECT: Dict[str, str] = {
    # Poodles (AKC rows are usually written with parentheses)
    "standard poodle": "Poodle (Standard)",
    "miniature poodle": "Poodle (Miniature)",
    "toy poodle": "Poodle (Toy)",

    # Dachshund sizes often appear both ways in the wild
    "miniature dachshund": "Dachshund (Miniature)",
    "standard dachshund": "Dachshund (Standard)",  # present in some AKC tables

    # Bull Terrier miniature vs. base
    "miniature bull terrier": "Bull Terrier (Miniature)",

    # American Eskimo Dog varieties (with and without the trailing "dog")
    "toy american eskimo": "American Eskimo Dog (Toy)",
    "miniature american eskimo": "American Eskimo Dog (Miniature)",
    "standard american eskimo": "American Eskimo Dog (Standard)",
    "toy american eskimo dog": "American Eskimo Dog (Toy)",
    "miniature american eskimo dog": "American Eskimo Dog (Miniature)",
    "standard american eskimo dog": "American Eskimo Dog (Standard)",

    # Others: spelling, hyphenation and word-order differences
    "eskimo dog": "American Eskimo Dog",
    "wire haired fox terrier": "Fox Terrier (Wire)",
    "smooth fox terrier": "Fox Terrier (Smooth)",
    "black and tan coonhound": "Black and Tan Coonhound",
    "german short haired pointer": "German Shorthaired Pointer",
    "german long haired pointer": "German Longhaired Pointer",
    "curly coated retriever": "Curly-Coated Retriever",
    "flat coated retriever": "Flat-Coated Retriever",
    "yorkshire terrier": "Yorkshire Terrier",
    "welsh springer spaniel": "Welsh Springer Spaniel",
    "english springer": "English Springer Spaniel",
}

# 2) Generic flip:  "<Variant> <Base>"  ->  "<Base> (<Variant>)"
#    The flip is attempted only when the label's first word is one of these
#    variants, and accepted only if the flipped name exists in the AKC index.
SIZE_VARIANTS = {"standard", "miniature", "toy", "giant"}

def _try_alias_then_flip(norm_label: str) -> Optional[pd.Series]:
    """Resolve a normalized label via the alias table, then via a size flip.

    Args:
        norm_label: Label already passed through ``_normalize_name``.

    Returns:
        The matching row of ``akc_df``, or ``None`` when neither the alias
        nor the flipped name is present in ``akc_name_to_idx``.
    """
    # a) direct alias table: accepted only when the alias target is an AKC row
    alias_display = ALIAS_DIRECT.get(norm_label)
    if alias_display is not None:
        alias_idx = akc_name_to_idx.get(_normalize_name(alias_display))
        if alias_idx is not None:
            return akc_df.iloc[alias_idx]

    # b) generic flip: "<variant> <rest>" -> "<rest> (<variant>)" IF that exists in AKC
    variant, space, base = norm_label.partition(" ")
    if not space or variant not in SIZE_VARIANTS:
        return None

    flipped_key = _normalize_name(f"{base.title()} ({variant.title()})")
    flipped_idx = akc_name_to_idx.get(flipped_key)
    return None if flipped_idx is None else akc_df.iloc[flipped_idx]

# ------------------Lookup in AKC table------------------------------------------

def _lookup_breed_info(pred_label: str) -> Optional[pd.Series]:
    """
    Find the best matching AKC row for a model label.

    Strategies are tried in this order and the first hit is returned:
    1) Direct normalized match
    2) Alias resolution and safe variant flip ('Standard Poodle' -> 'Poodle (Standard)')
    3) Label with a trailing 'dog', 'terrier' or 'hound' removed (in that order)
    4) Fuzzy match via difflib (single closest key, similarity cutoff 0.75)

    Args:
        pred_label: Label as produced by the classifier.

    Returns:
        The matching row of ``akc_df``, or ``None`` when nothing matches or
        the label normalizes to an empty string.
    """
    norm = _normalize_name(pred_label)
    if not norm:
        # Nothing to match on; also keeps an empty key from ever matching.
        return None

    # 1) direct match
    idx = akc_name_to_idx.get(norm)
    if idx is not None:
        return akc_df.iloc[idx]

    # 2) alias + safe flip
    row = _try_alias_then_flip(norm)
    if row is not None:
        return row

    # 3) simple stripped variants
    # A tuple gives a fixed priority. Iterating a set of strings can visit
    # the candidates in a different order on every interpreter run, which
    # made the winning row unpredictable when more than one candidate matched.
    for suffix in ("dog", "terrier", "hound"):
        stripped = re.sub(rf"\b{suffix}\b\s*$", "", norm).strip()
        if not stripped:
            continue  # the label was only the suffix itself
        idx = akc_name_to_idx.get(stripped)
        if idx is not None:
            return akc_df.iloc[idx]

    # 4) fuzzy match
    candidates = difflib.get_close_matches(norm, akc_name_to_idx.keys(), n=1, cutoff=0.75)
    if candidates:
        return akc_df.iloc[akc_name_to_idx[candidates[0]]]

    return None

def _format_breed_info(row: pd.Series) -> str:
    #Turn a single AKC row into a readable markdown snippet.
    def get(col, fallback="β€”"):
        return row[col] if col in row and pd.notna(row[col]) else fallback

    lines = []
    lines.append(f"### {get('breed', 'Unknown Breed')}")
    if pd.notna(get('description')):
        lines.append(f"{get('description')}\n")

    # Facts block
    facts = []
    if pd.notna(get('group')):
        facts.append(f"**Group:** {get('group')}")
    if pd.notna(get('temperament')):
        facts.append(f"**Temperament:** {get('temperament')}")
    # Height (inches)
    hmin, hmax = get('min_height'), get('max_height')
    if pd.notna(hmin) or pd.notna(hmax):
        facts.append(f"**Height:** {hmin if pd.notna(hmin) else 'β€”'}–{hmax if pd.notna(hmax) else 'β€”'} in")
    # Weight (pounds)
    wmin, wmax = get('min_weight'), get('max_weight')
    if pd.notna(wmin) or pd.notna(wmax):
        facts.append(f"**Weight:** {wmin if pd.notna(wmin) else 'β€”'}–{wmax if pd.notna(wmax) else 'β€”'} lb")
    # Life Expectancy (years)
    emin, emax = get('min_expectancy'), get('max_expectancy')
    if pd.notna(emin) or pd.notna(emax):
        facts.append(f"**Life Expectancy:** {emin if pd.notna(emin) else 'β€”'}–{emax if pd.notna(emax) else 'β€”'} yrs")

    if facts:
        lines.append("\n".join(facts))

    # Optional traits if present in our AKC Dataset
    trait_fields = [
        ("grooming_frequency_category", "Grooming"),
        ("shedding_category", "Shedding"),
        ("energy_level_category", "Energy Level"),
        ("trainability_category", "Trainability"),
        ("demeanor_category", "Demeanor"),
    ]
    traits = []
    for col, label in trait_fields:
        val = get(col)
        if pd.notna(val):
            traits.append(f"- **{label}:** {val}")
    if traits:
        lines.append("\n**Traits**")
        lines.extend(traits)

    return "\n\n".join(lines).strip()

# ---------------------Inference function----------------
 #    Accepts a PIL image and returns:
 #    - Top-1 predicted breed with confidence
 #    - A markdown block of AKC info for that breed (if found)
def predict_with_info(img: Image.Image) -> str:
    """Classify a dog photo and describe the predicted breed.

    Args:
        img: Image to classify. ``None`` (nothing uploaded) is accepted and
            answered with a prompt instead of an error.

    Returns:
        Markdown text: the top-1 label with its softmax probability, followed
        by the AKC details for that breed, or by a note when no AKC row
        matches the label.
    """
    if img is None:
        # Submitting the form without an upload delivers None; answer with a
        # hint rather than failing inside the processor.
        return "_Please upload an image first._"

    # Feed the processor RGB regardless of the input's mode (e.g. RGBA or
    # grayscale images passed by a direct caller).
    if img.mode != "RGB":
        img = img.convert("RGB")

    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        # A single image is passed in, so the batch holds one row (index 0).
        top_id = int(torch.argmax(probs, dim=-1).item())
        top_prob = float(probs[0, top_id].item())

    label = model.config.id2label.get(top_id, "Unknown")
    header = f"**Prediction:** {label} ({top_prob:.2%})"

    row = _lookup_breed_info(label)
    if row is None:
        return header + "\n\n_No matching breed found in AKC dataset._"

    return header + "\n\n" + _format_breed_info(row)


# -------------------- UI -------------------------------------

# Single-function Gradio app: one image input, one Markdown output.
demo = gr.Interface(
    fn=predict_with_info,
    # type="pil" matches the PIL image that predict_with_info expects.
    inputs=gr.Image(type="pil", label="Upload a dog photo"),
    outputs=gr.Markdown(label="Prediction + Breed Info"),
    title="Dog Breed Classifier + AKC Info",
    description=(
        f"Upload an image of a dog. The app predicts the breed using '{MODEL_ID}' "
        "and shows breed details from the American Kennel Club dataset. Dataset: https://github.com/tmfilho/akcdata/blob/master/data/akc-data-latest.csv"
    ),
    # NOTE(review): newer Gradio releases rename this option to `flagging_mode` -
    # confirm against the Gradio version this app is pinned to.
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()