Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import requests | |
| import gradio as gr | |
| import os | |
# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small two-stage CNN mapping an RGB image to 3 morphology-class probabilities.

    Inputs are expected as (batch, 3, 224, 224): each of the two 2x2 max-pool
    stages halves the spatial size (224 -> 112 -> 56), which is what the
    ``64 * 56 * 56`` input size of the first linear layer is built for.
    The head ends in a softmax over dim 1, so ``forward`` returns per-class
    probabilities (each row sums to 1), not raw logits.

    Submodule names and their positions inside each ``nn.Sequential`` define
    the ``state_dict`` keys (``feature_extractor.0.weight``, ``fc.1.weight``,
    ...), so they must not change or saved weights will no longer load.
    """

    def __init__(self):
        super().__init__()
        # (Conv 3x3 -> ReLU -> MaxPool 2x2) twice: (B, 3, 224, 224) -> (B, 64, 56, 56).
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        flattened_features = 64 * 56 * 56
        # Classifier head; the two Linear layers sit at indices 1 and 4.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_features, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 3),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return class probabilities of shape (batch, 3) for the image batch ``x``."""
        return self.fc(self.feature_extractor(x))
# Load the model
MODEL_PATH = "model.pth"
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
model = BacterialMorphologyClassifier()


def _download_weights(url, dest, timeout=60):
    """Download ``url`` to ``dest`` so that ``dest`` only ever holds a complete file.

    The response is checked for an HTTP error status and streamed into
    ``dest + ".part"``. Only then is it renamed onto ``dest``. A failed or
    interrupted download therefore never leaves an error page or a truncated
    file at ``dest``. Such a file would otherwise be reused on every later
    start, because the caller only checks whether ``dest`` exists.

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error status.
        OSError: if the file cannot be written or renamed.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, dest)  # atomic on the same filesystem
    except BaseException:
        # Do not leave a partial download lying around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


try:
    # Download the weights only once; later starts reuse the local copy.
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        _download_weights(MODEL_URL, MODEL_PATH)
    # The file comes from the network. weights_only=True restricts unpickling
    # to tensors and plain containers, which is all a state_dict needs.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except Exception as e:
    # Best effort: the app still starts, but with randomly initialised weights,
    # so predictions are meaningless until the error printed here is resolved.
    print(f"Error loading the model: {e}")
finally:
    # Always switch to inference mode (disables Dropout) so outputs are deterministic.
    model.eval()
# Image preprocessing; must mirror the preprocessing the weights were trained with.
_preprocessing_steps = [
    # The classifier head is sized for 224x224 inputs (64 * 56 * 56 features).
    transforms.Resize((224, 224)),
    # PIL image -> float tensor (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # (x - 0) / (1/255) multiplies by 255, i.e. maps [0, 1] back to [0, 255].
    # NOTE(review): this does NOT normalise to [0, 1]; presumably training fed
    # raw [0, 255] pixel values. Confirm before changing.
    transforms.Normalize(mean=[0, 0, 0], std=[1 / 255] * 3),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify a bacteria image and return a human-readable result string.

    Args:
        image: PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        On success, "Predicted Class: <name>" and "Confidence: <p>" on two
        lines, where ``p`` is the model's output probability for the winning
        class, formatted to two decimals. On any failure, "Error: <message>".
        Errors are reported in the UI rather than raised.
    """
    if image is None:
        return "Error: no image provided."
    try:
        # The first conv layer expects exactly 3 channels. PNG uploads are often
        # RGBA and grayscale images have 1 channel, so normalise the mode first.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():  # inference only; no autograd graph needed
            probabilities = model(image_tensor)[0]
        class_labels = ("cocci", "bacilli", "spirilla")
        best = int(probabilities.argmax())
        # The model's last layer is a softmax, so this is already a probability.
        confidence = probabilities[best].item()
        return f"Predicted Class: {class_labels[best]}\nConfidence: {confidence:.2f}"
    except Exception as e:
        return f"Error: {str(e)}"
# Example images from the project's dataset repo, shown as clickable samples.
# Filenames are already URL-encoded ("%20" is a space). Each row holds one
# value per interface input, and there is a single image input.
_EXAMPLE_BASE_URL = "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/"
examples = [
    [_EXAMPLE_BASE_URL + filename]
    for filename in ("img%20290.jpg", "img%20565.jpg", "img%208.jpg")
]
# Gradio UI: one PIL image in, one text prediction out, with clickable examples.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Text(label="Prediction")
interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_output,
    examples=examples,
    title="Bacterial Morphology Classification",
    description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()