"""Gradio demo app that classifies an uploaded image as a cat or a dog."""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Import model classes
from model import EfficientNet


class DogCatClassifier:
    """Wrap a two-class EfficientNet checkpoint behind a simple ``predict`` API.

    The model is loaded once at construction time, moved to CUDA when
    available (CPU otherwise), and switched to evaluation mode.
    """

    def __init__(self, model_path="efficientnet_b1_dogcat.pth"):
        """Load the checkpoint and build the preprocessing pipeline.

        Args:
            model_path: Path to the saved ``state_dict`` file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model
        self.model = self._load_model(model_path)
        self.model.eval()

        # Define transforms: fixed 224x224 resize, tensor conversion, and
        # per-channel normalization.
        # NOTE(review): these are the usual ImageNet mean/std values;
        # presumably they match the preprocessing used in training - confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _load_model(self, model_path):
        """Build the network, load weights from ``model_path``, move to device.

        Args:
            model_path: Path to the saved ``state_dict`` file.

        Returns:
            The model with weights loaded, placed on ``self.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        # Fail fast before paying for model construction.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Create model architecture
        model = EfficientNet(model_name="efficient_b1", num_classes=2, pretrained=False)

        # Load state dict
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where the
        # installed PyTorch version supports it).
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        print(f"Model loaded from {model_path}")

        model.to(self.device)
        return model

    def predict(self, image):
        """Classify one image and return a human-readable label.

        Args:
            image: A file path (``str``), a ``numpy.ndarray``, a PIL image,
                or ``None``.

        Returns:
            ``"🐱 Cat (p)"`` or ``"🐶 Dog (p)"`` with the winning softmax
            probability formatted as a percentage; a prompt string when
            ``image`` is ``None``; a generic error string if anything fails.
        """
        # Handle None input
        if image is None:
            return "Please upload an image"

        try:
            # Preprocess image: normalize every supported input to a PIL image.
            if isinstance(image, str):
                # Context manager closes the underlying file once the pixel
                # data has been copied by convert().
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Always force 3-channel RGB: RGBA / grayscale / palette images
            # would otherwise break the 3-channel Normalize step.
            if isinstance(image, Image.Image):
                image = image.convert('RGB')

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get probabilities for each class (index 0 = cat, index 1 = dog)
            cat_prob = probabilities[0][0].item()
            dog_prob = probabilities[0][1].item()

            if cat_prob > dog_prob:
                return f"🐱 Cat ({cat_prob:.2%})"
            return f"🐶 Dog ({dog_prob:.2%})"

        except Exception as e:
            # UI boundary: report the failure on the console and show a
            # friendly message instead of crashing the app.
            print(f"Error during prediction: {e}")
            return "Error - please try again"


# Initialize classifier
classifier = DogCatClassifier()


def classify_image(image):
    """Classify uploaded image as Cat or Dog"""
    return classifier.predict(image)


# Create minimal Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Cat vs Dog Classifier",
    description="Upload an image to classify if it's a cat or dog.",
)

if __name__ == "__main__":
    iface.launch()