import io  # kept for callers that still pass image bytes to the handler
import os
import sys

import gradio as gr
import torch
import torchvision.transforms as T
from huggingface_hub import snapshot_download
from PIL import Image

print("🦁 Setting up zoolang model for dog emotion classification...")

# Download the model repository (code + weights) from the Hugging Face Hub.
# NOTE(review): no `revision=` is pinned, so each restart fetches whatever is
# currently on the default branch -- consider pinning a commit hash.
model_path = snapshot_download(repo_id="vsimmer/zoolang")
print(f"Model downloaded to: {model_path}")

# Make the downloaded repo importable so its `model` module can be loaded.
# NOTE(review): this executes Python code fetched from the Hub at import time.
sys.path.insert(0, str(model_path))

from model import get_model  # noqa: E402  (must follow the sys.path insert)

# Class names, assumed to be in the same index order as the model's outputs
# -- TODO confirm against the training label mapping.
CLASS_NAMES = ['aggressive', 'anxious', 'frightened', 'happy', 'inquisitive']

# Weights file shipped in the model repo.
WEIGHTS_FILENAME = "doge_224_sd-03.bin"


class CustomHandler:
    """Load the zoolang weights and run single-image inference on CPU.

    Args:
        path: Directory that contains ``doge_224_sd-03.bin``.
    """

    def __init__(self, path=""):
        self.model = get_model()
        weights_path = os.path.join(path, WEIGHTS_FILENAME)
        print(f"Loading weights from: {weights_path}")
        # NOTE(review): torch.load unpickles the file, and this file comes from
        # a remote repo. Consider `weights_only=True` (torch >= 1.13) once the
        # deployed torch version is confirmed.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("✅ Model weights loaded successfully!")

        # Resize to 224x224, then normalise with the ImageNet mean/std.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data):
        """Classify one image.

        Args:
            data: Dict whose ``"inputs"`` entry is a ``PIL.Image.Image``, a
                file path, or a binary file-like object. The dict is not
                modified.

        Returns:
            ``{"label": int, "scores": list[float]}``:
            - ``label`` is the argmax index over dim 1 of the model output.
            - ``scores`` is the softmax of that output over dim 1 for the
              single image.
        """
        inputs = data.get("inputs", data)
        # Accept an already-decoded PIL image to avoid an encode/decode trip.
        if isinstance(inputs, Image.Image):
            image = inputs
        else:
            image = Image.open(inputs)
        tensor = self.transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(tensor)
        # Assumes `outputs` are unnormalised logits of shape (1, num_classes)
        # -- TODO confirm. Softmax is monotonic, so the argmax is unchanged.
        scores = torch.softmax(outputs, dim=1)[0]
        prediction = int(torch.argmax(outputs, dim=1).item())
        return {"label": prediction, "scores": scores.tolist()}


# Initialise the handler once at import time; it is shared by all requests.
handler = CustomHandler(path=model_path)


def _class_name(index):
    """Map an output index to its class name, with a fallback for extras."""
    return CLASS_NAMES[index] if index < len(CLASS_NAMES) else f"class_{index}"


def classify_image(image):
    """Classify dog emotion using the zoolang model.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input,
            or ``None`` when nothing was uploaded.

    Returns:
        A ``{class_name: probability}`` dict suitable for ``gr.Label``. The
        probabilities are the model's softmax scores. On error, a single
        sentinel entry is returned instead.
    """
    if image is None:
        return {"No image": 1.0}

    try:
        result = handler({"inputs": image})
        # Report the model's real scores instead of fabricated confidences.
        return {
            _class_name(i): float(p) for i, p in enumerate(result["scores"])
        }
    except Exception as e:
        # UI boundary: log the error and keep the app responsive.
        print(f"Classification error: {e}")
        return {"Classification failed": 1.0}


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Dog Image"),
    outputs=gr.Label(num_top_classes=5, label="Dog Emotion Prediction"),
    title="🐕 Zoolang Dog Emotion Classifier",
    description="Classify dog emotions: aggressive, anxious, frightened, happy, or inquisitive"
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)