Spaces:
Sleeping
Sleeping
| import torch | |
| import time | |
| from PIL import Image | |
| from torchvision import models, transforms | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| import urllib.request | |
| import gradio as gr | |
# ============================================================
# LOAD IMAGENET RESNET50
# ============================================================
# Species-level classifier: ImageNet ResNet-50 (V2 weights). In the
# ImageNet-1k label space, indices 151-268 are dog breeds and 281-285
# are cats — we use those ranges to route to the breed models below.
resnet = models.resnet50(weights="IMAGENET1K_V2")
resnet.eval()

# Standard ImageNet preprocessing (mean/std of the original dataset).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

labels_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Context manager closes the HTTP connection promptly; splitlines()
# (unlike split("\n")) does not leave an empty trailing label when the
# file ends with a newline.
with urllib.request.urlopen(labels_url) as resp:
    imagenet_labels = resp.read().decode("utf-8").splitlines()

# Sets give O(1) membership tests in detect_animal_type().
dog_indices = set(range(151, 269))       # ImageNet dog-breed classes
cat_indices = {281, 282, 283, 284, 285}  # ImageNet cat classes
# ============================================================
# BREED MODELS (Hugging Face)
# ============================================================
def _load_hf_classifier(model_id):
    """Return (processor, model) for a Hugging Face image-classification checkpoint."""
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    return processor, model


# Fine-grained dog-breed classifier (120 classes).
dog_model_name = "prithivMLmods/Dog-Breed-120"
dog_processor, dog_model = _load_hf_classifier(dog_model_name)

# Fine-grained cat-breed classifier (67 classes).
cat_model_name = "dima806/67_cat_breeds_image_detection"
cat_processor, cat_model = _load_hf_classifier(cat_model_name)

# ============================================================
# PIPELINE FUNCTIONS
# ============================================================
def detect_animal_type(image):
    """Classify an image as "dog", "cat", or "other" via ImageNet ResNet-50.

    Args:
        image: PIL image (RGB — caller converts before passing in).

    Returns:
        Tuple ``(species, imagenet_label, confidence, latency_seconds)``
        where ``species`` is ``"dog"``, ``"cat"``, or ``"other"``.
    """
    # perf_counter() is monotonic and high-resolution; time.time() is a
    # wall clock and can jump, so it is the wrong tool for latency.
    start = time.perf_counter()
    batch = transform(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    with torch.no_grad():
        logits = resnet(batch)
    probs = torch.softmax(logits, dim=1)[0]
    idx = probs.argmax().item()
    conf = float(probs[idx])
    latency = time.perf_counter() - start
    # NOTE(review): only the top-1 class is consulted, so a dog photo whose
    # top-1 label is e.g. "tennis ball" routes to "other" (handled by the
    # dual-model fallback in run_pipeline).
    if idx in dog_indices:
        species = "dog"
    elif idx in cat_indices:
        species = "cat"
    else:
        species = "other"
    return species, imagenet_labels[idx], conf, latency
def predict_dog_breed(image):
    """Classify a dog image into one of the Dog-Breed-120 classes.

    Args:
        image: PIL image (RGB).

    Returns:
        Tuple ``(breed_label, confidence, latency_seconds)``.
    """
    start = time.perf_counter()  # monotonic clock for latency measurement
    inputs = dog_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        out = dog_model(**inputs)
    probs = torch.softmax(out.logits, dim=1)[0]
    idx = probs.argmax().item()
    latency = time.perf_counter() - start
    return dog_model.config.id2label[idx], float(probs[idx]), latency
def predict_cat_breed(image):
    """Classify a cat image into one of the 67 cat-breed classes.

    Args:
        image: PIL image (RGB).

    Returns:
        Tuple ``(breed_label, confidence, latency_seconds)``.
    """
    start = time.perf_counter()  # monotonic clock for latency measurement
    inputs = cat_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        out = cat_model(**inputs)
    probs = torch.softmax(out.logits, dim=1)[0]
    idx = probs.argmax().item()
    latency = time.perf_counter() - start
    return cat_model.config.id2label[idx], float(probs[idx]), latency
# ============================================================
# MAIN PIPELINE
# ============================================================
def run_pipeline(input_image):
    """Full species-then-breed pipeline behind the Gradio button.

    Args:
        input_image: PIL image from the Gradio widget, or ``None`` when
            nothing has been uploaded.

    Returns:
        7-tuple of strings: (animal type, primary breed, primary confidence,
        total latency, dog-model output, cat-model output, log text).
    """
    if input_image is None:
        # Nothing uploaded — placeholder for every output box.
        return "β", "β", "β", "β", "β", "β", ""

    total_start = time.perf_counter()  # monotonic clock for latency
    image = input_image.convert("RGB")
    logs = []

    # STEP 1: species detection.
    animal, base_label, base_conf, t1 = detect_animal_type(image)
    logs.append(f"[Species Detection] {animal.upper()} | {t1:.4f} s")

    # STEP 2: breed prediction. Each branch only decides the primary
    # breed/confidence and the per-model output boxes; the total-latency
    # bookkeeping and return tuple are shared below instead of being
    # duplicated three times.
    if animal == "dog":
        breed, conf, t2 = predict_dog_breed(image)
        logs.append(f"[Dog Breed Model] {breed} ({conf:.4f}) | {t2:.4f} s")
        dog_out = f"{breed} ({conf:.4f})"
        cat_out = "β"
    elif animal == "cat":
        breed, conf, t2 = predict_cat_breed(image)
        logs.append(f"[Cat Breed Model] {breed} ({conf:.4f}) | {t2:.4f} s")
        dog_out = "β"
        cat_out = f"{breed} ({conf:.4f})"
    else:
        # Species unclear — run both breed models and report the more
        # confident one as the primary prediction.
        d_breed, d_conf, d_t = predict_dog_breed(image)
        c_breed, c_conf, c_t = predict_cat_breed(image)
        logs.append(f"[Fallback] Dog Model β {d_breed} ({d_conf:.4f}) | {d_t:.4f} s")
        logs.append(f"[Fallback] Cat Model β {c_breed} ({c_conf:.4f}) | {c_t:.4f} s")
        breed = d_breed if d_conf > c_conf else c_breed
        conf = max(d_conf, c_conf)
        dog_out = f"{d_breed} ({d_conf:.4f})"
        cat_out = f"{c_breed} ({c_conf:.4f})"

    total_latency = time.perf_counter() - total_start
    logs.append(f"[Total Pipeline Latency] {total_latency:.4f} s")
    return (
        animal.title(),
        breed,
        f"{conf:.4f}",
        f"{total_latency:.4f} s",
        dog_out,
        cat_out,
        "\n".join(logs),
    )
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(theme=gr.themes.Soft(), title="PawCare AI - Pet Identification") as demo:
    gr.Markdown("# πΎ PawCare AI β Pet Type & Breed Classifier")

    with gr.Row():
        # Left column: upload widget and the trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Pet Image", height=350)
            analyze_btn = gr.Button("Run Analysis", variant="primary")

        # Right column: read-only result boxes fed by run_pipeline.
        with gr.Column(scale=1):
            gr.Markdown("### π Prediction Summary")
            with gr.Row():
                species_box = gr.Textbox(label="Animal Type", interactive=False)
                primary_breed_box = gr.Textbox(label="Primary Predicted Breed", interactive=False)
            with gr.Row():
                confidence_box = gr.Textbox(label="Primary Confidence", interactive=False)
                total_latency_box = gr.Textbox(label="Total Latency", interactive=False)
            with gr.Row():
                dog_output_box = gr.Textbox(label="Dog Model Output", interactive=False)
                cat_output_box = gr.Textbox(label="Cat Model Output", interactive=False)

    # Collapsed-by-default technical log (per-stage latencies etc.).
    with gr.Accordion("Detailed Logs (Technical)", open=False):
        log_box = gr.Textbox(lines=12, interactive=False)

    # Output order must match run_pipeline's 7-tuple.
    analyze_btn.click(
        fn=run_pipeline,
        inputs=image_input,
        outputs=[
            species_box,
            primary_breed_box,
            confidence_box,
            total_latency_box,
            dog_output_box,
            cat_output_box,
            log_box,
        ],
    )

demo.launch(share=True)