# Author: vsimmer
# Commit: b0cef8d "Update app.py" (verified)
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as T
from huggingface_hub import snapshot_download
import os
import sys
# Module-level setup: runs at import time and needs network access to the Hub.
print("🦁 Setting up zoolang model for dog emotion classification...")
# Fetch the vsimmer/zoolang model repo; snapshot_download returns the local
# directory that holds the repo's files.
model_path = snapshot_download(repo_id="vsimmer/zoolang")
print(f"Model downloaded to: {model_path}")
# Put the downloaded repo first on sys.path so the import below resolves to the
# repo's own `model` module (presumably model.py in the repo — verify) instead of
# any other module named `model`. This must run before that import.
sys.path.insert(0, str(model_path))
# get_model() returns the model object that CustomHandler fills via load_state_dict.
from model import get_model
# Class labels in model output order: output index i is assumed to mean
# CLASS_NAMES[i] — TODO confirm against the training label mapping.
CLASS_NAMES = ['aggressive', 'anxious', 'frightened', 'happy', 'inquisitive']
class CustomHandler:
    """Load the zoolang weights once and classify single images on demand."""

    def __init__(self, path=""):
        """Build the model, load its weights, and set up preprocessing.

        Args:
            path: Directory containing ``doge_224_sd-03.bin``. The default ``""``
                resolves the file relative to the current working directory.

        Errors raised by ``torch.load`` / ``load_state_dict`` (missing file,
        mismatched keys) propagate to the caller.
        """
        self.model = get_model()
        # The shipped weights file is doge_224_sd-03.bin (not doge_223).
        weights_path = os.path.join(path, "doge_224_sd-03.bin")
        print(f"Loading weights from: {weights_path}")
        # NOTE(review): torch.load unpickles the file, which can execute code.
        # Acceptable only because the file comes from our own model repo; if it
        # is a plain state_dict, consider weights_only=True (torch>=1.13).
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Eval mode for inference (affects dropout/batch-norm layers, if present).
        self.model.eval()
        print("✅ Model weights loaded successfully!")
        # Resize to 224x224, convert to a float tensor, then normalize with the
        # standard ImageNet per-channel mean/std.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data):
        """Classify one image.

        Args:
            data: Dict whose ``"inputs"`` entry is a path or binary file-like
                object readable by ``PIL.Image.open``. If the key is absent,
                ``data`` itself is handed to ``Image.open``. ``data`` is not
                modified.

        Returns:
            ``{"label": int, "probabilities": list[float]}``: ``label`` is the
            argmax class index of the model output, and ``probabilities`` is the
            softmax of that output for this image (one value per class).
        """
        # .get rather than .pop so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        # Context manager releases the file handle once the pixels are decoded.
        with Image.open(inputs) as img:
            image = img.convert("RGB")
        tensor = self.transform(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            outputs = self.model(tensor)
        prediction = torch.argmax(outputs, dim=1).item()
        # NOTE(review): softmax assumes the model returns raw logits — confirm
        # against model.get_model. The argmax label above is unaffected either way.
        probabilities = torch.softmax(outputs, dim=1)[0].tolist()
        return {"label": prediction, "probabilities": probabilities}
# Build the classifier once at import time; classify_image reuses this instance.
handler = CustomHandler(model_path)
def to_label_confidences(result, class_names):
    """Convert a ``CustomHandler`` result into a ``{label: confidence}`` dict.

    Args:
        result: Handler output, expected to be a dict with an integer
            ``"label"`` and optionally a ``"probabilities"`` sequence.
        class_names: Class names in model output order.

    Returns:
        The real per-class probabilities when the result carries exactly one
        value per class; otherwise a one-hot dict on the predicted class.
        ``{"Unknown result": 0.5}`` when the result is unusable or the
        predicted index has no matching class name.
    """
    if not isinstance(result, dict) or "label" not in result:
        return {"Unknown result": 0.5}
    probabilities = result.get("probabilities")
    if probabilities is not None and len(probabilities) == len(class_names):
        return {name: float(p) for name, p in zip(class_names, probabilities)}
    class_index = result["label"]
    if not 0 <= class_index < len(class_names):
        # An index with no name would otherwise render as an all-zero distribution.
        return {"Unknown result": 0.5}
    # Only the top-1 index is known, so report it alone instead of inventing
    # confidence values.
    return {name: float(i == class_index) for i, name in enumerate(class_names)}


def classify_image(image):
    """Classify dog emotion in a PIL image for the Gradio ``Label`` output.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or ``None`` when the
            user submitted nothing.

    Returns:
        Dict mapping label text to confidence. Sentinels are returned for a
        missing image (``{"No image": 1.0}``) and for any failure during
        classification (``{"Classification failed": 1.0}``).
    """
    import io

    if image is None:
        return {"No image": 1.0}
    try:
        # The handler expects a file-like object, so round-trip through PNG bytes.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)
        result = handler({"inputs": buffer})
        return to_label_confidences(result, CLASS_NAMES)
    except Exception as e:
        # UI boundary: show the failure in the Label instead of crashing the app.
        print(f"Classification error: {e}")
        return {"Classification failed": 1.0}
# Gradio UI: one uploaded image in, per-emotion confidences out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Dog Image"),
    # Show every class so the whole distribution is visible.
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES), label="Dog Emotion Prediction"),
    title="🐕 Zoolang Dog Emotion Classifier",
    # Built from CLASS_NAMES so the text cannot drift from the model's labels.
    description=(
        "Classify dog emotions: "
        f"{', '.join(CLASS_NAMES[:-1])}, or {CLASS_NAMES[-1]}"
    ),
)
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")