Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| from huggingface_hub import snapshot_download | |
| import os | |
| import sys | |
print("π¦ Setting up zoolang model for dog emotion classification...")

# Download a snapshot of the "vsimmer/zoolang" repository; the call returns
# the local directory that holds the repository files.
model_path = snapshot_download(repo_id="vsimmer/zoolang")
print(f"Model downloaded to: {model_path}")

# The repository presumably ships its own Python source (a `model` module), so
# its directory is placed first on the import search path before importing.
sys.path.insert(0, str(model_path))
from model import get_model  # noqa: E402  -- must follow the sys.path change
# Emotion labels, listed in the order of the model's output class indices.
CLASS_NAMES = [
    "aggressive",
    "anxious",
    "frightened",
    "happy",
    "inquisitive",
]
class CustomHandler:
    """Load the zoolang dog-emotion model and classify one image per call.

    Call it as ``handler({"inputs": <image source>})``; the result is
    ``{"label": <int class index>}``.
    """

    def __init__(self, path=""):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory that contains the weights file
                ``doge_224_sd-03.bin``.
        """
        self.model = get_model()
        weights_path = os.path.join(path, "doge_224_sd-03.bin")
        print(f"Loading weights from: {weights_path}")
        # NOTE(review): torch.load unpickles the file, and these weights come
        # from a downloaded repository. Consider torch.load(...,
        # weights_only=True) (the default since torch 2.6) so arbitrary
        # pickled objects are refused.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()  # inference mode for layers like dropout/batch-norm
        print("β Model weights loaded successfully!")
        # Resize to 224x224, convert to a tensor, then normalise with the
        # standard ImageNet per-channel mean/std values.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data):
        """Classify a single image.

        Args:
            data: Mapping whose ``"inputs"`` entry is something
                ``PIL.Image.open`` accepts (a path or a binary file-like
                object). If the key is absent, ``data`` itself is handed to
                ``Image.open``.

        Returns:
            ``{"label": int}``: index of the highest-scoring output class.
        """
        # .get leaves the caller's mapping untouched.
        inputs = data.get("inputs", data)
        # Open in a context manager so a file opened from a path is closed as
        # soon as the RGB copy exists; convert() returns an independent image.
        with Image.open(inputs) as source:
            image = source.convert("RGB")
        tensor = self.transform(image).unsqueeze(0)  # prepend a batch dim of 1
        with torch.no_grad():
            outputs = self.model(tensor)
        prediction = torch.argmax(outputs, dim=1).item()
        return {"label": prediction}
# One module-level handler, built at import time and reused by classify_image.
handler = CustomHandler(model_path)
def classify_image(image):
    """Classify dog emotion using your zoolang model.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            nothing has been uploaded.

    Returns:
        A label -> confidence mapping for ``gr.Label``. On success the keys
        are ``CLASS_NAMES`` and the values are the softmax of the model's
        outputs. Returns ``{"No image": 1.0}`` for a missing image,
        ``{"Unknown result": 0.5}`` if the model's score vector does not match
        ``CLASS_NAMES`` in length, and ``{"Classification failed": 1.0}`` if
        inference raises.
    """
    if image is None:
        return {"No image": 1.0}
    try:
        # Preprocess the PIL image in memory and call the model directly, so
        # the full score vector (not just the arg-max index) is available for
        # the confidence display.
        tensor = handler.transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            outputs = handler.model(tensor)
        # NOTE(review): assumes the model returns raw logits of shape
        # (1, num_classes) -- confirm. Softmax yields probabilities summing to
        # 1 and keeps the same top class as torch.argmax.
        scores = torch.softmax(outputs, dim=1)[0].tolist()
        if len(scores) != len(CLASS_NAMES):
            print(
                f"Classification error: model returned {len(scores)} scores "
                f"for {len(CLASS_NAMES)} classes"
            )
            return {"Unknown result": 0.5}
        return dict(zip(CLASS_NAMES, scores))
    except Exception as e:  # UI boundary: report, don't crash the app
        print(f"Classification error: {e}")
        return {"Classification failed": 1.0}
# Gradio UI: a PIL image in, a top-5 label/confidence display out.
image_input = gr.Image(type="pil", label="Upload Dog Image")
emotion_output = gr.Label(num_top_classes=5, label="Dog Emotion Prediction")
iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=emotion_output,
    title="π Zoolang Dog Emotion Classifier",
    description=(
        "Classify dog emotions: aggressive, anxious, frightened, happy, "
        "or inquisitive"
    ),
)

if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)