Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import numpy as np | |
| from torchvision import transforms | |
| import os | |
| # Import model classes | |
| from model import EfficientNet | |
class DogCatClassifier:
    """Binary cat-vs-dog image classifier wrapping a two-class EfficientNet.

    The weights are loaded once at construction time; ``predict`` then maps a
    single image to a human-readable label string.
    """

    def __init__(self, model_path="efficientnet_b1_dogcat.pth"):
        """Load the model from *model_path* and set up preprocessing.

        Args:
            model_path: Path to a ``state_dict`` checkpoint for the two-class
                EfficientNet.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the model and switch it to inference mode.
        self.model = self._load_model(model_path)
        self.model.eval()

        # Resize to 224x224, convert to a tensor and normalize with the usual
        # ImageNet channel statistics.
        # NOTE(review): assumes training used this same preprocessing — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_model(self, model_path):
        """Build the EfficientNet architecture and load weights from *model_path*.

        Returns:
            The model with weights loaded, moved to ``self.device``.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
        """
        # Fail fast, before constructing the network, if the checkpoint is missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # NOTE(review): "efficient_b1" differs from the "efficientnet_b1" in the
        # checkpoint filename — confirm this is the name `model.EfficientNet` expects.
        model = EfficientNet(model_name="efficient_b1", num_classes=2, pretrained=False)

        # torch.load unpickles the file: only load checkpoints from a trusted
        # source (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        print(f"Model loaded from {model_path}")

        model.to(self.device)
        return model

    @staticmethod
    def _to_rgb_image(image):
        """Return *image* (file path, numpy array, or PIL image) as an RGB PIL image."""
        if isinstance(image, str):
            # Close the underlying file; convert() returns an independent,
            # fully loaded copy.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert('RGB')
        # Anything else is treated as a PIL image. Non-RGB modes (RGBA, L, P, ...)
        # must be converted, otherwise the 3-channel Normalize in self.transform
        # receives the wrong number of channels and fails.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def predict(self, image):
        """Classify *image* and return a label string such as ``"🐱 Cat (97.31%)"``.

        Args:
            image: A file path, a numpy array, a PIL image, or ``None``.

        Returns:
            The predicted label with its softmax probability; a prompt to upload
            an image when *image* is ``None``; or a generic error message if
            preprocessing or inference raises.
        """
        try:
            if image is None:
                return "Please upload an image"

            image = self._to_rgb_image(image)
            # Add a batch dimension of 1 before running the model.
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # NOTE(review): assumes class index 0 is "cat" and 1 is "dog" —
            # confirm against the label mapping used during training.
            cat_prob = probabilities[0][0].item()
            dog_prob = probabilities[0][1].item()

            # Ties resolve to "Dog".
            if cat_prob > dog_prob:
                return f"🐱 Cat ({cat_prob:.2%})"
            return f"🐶 Dog ({dog_prob:.2%})"
        except Exception as e:
            # UI boundary: log the failure and show a friendly message instead
            # of surfacing a traceback to the user.
            print(f"Error during prediction: {e}")
            return "Error - please try again"
# Build one shared classifier at import time so the checkpoint is loaded once
# and reused by every call to `classify_image`. Raises FileNotFoundError at
# import if the checkpoint is missing.
classifier = DogCatClassifier(model_path="efficientnet_b1_dogcat.pth")
def classify_image(image):
    """Gradio callback: return the cat/dog label string for *image*.

    Delegates to the module-level ``classifier`` instance.
    """
    label = classifier.predict(image)
    return label
# Minimal Gradio UI: one PIL image in, one text label out.
iface = gr.Interface(
    title="Cat vs Dog Classifier",
    description="Upload an image to classify if it's a cat or dog.",
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    iface.launch()