Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from torchvision.models import resnet50 | |
| import os | |
| import logging | |
| from typing import Optional, Union | |
| import numpy as np | |
| from pathlib import Path | |
# Logging: configure the root logger at INFO at import time and expose a
# module-scoped logger for the rest of the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory layout, anchored on this file's own location so paths do not
# depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
MODELS_DIR = BASE_DIR.joinpath("models")
EXAMPLES_DIR = BASE_DIR.joinpath("examples")
STATIC_DIR = BASE_DIR.joinpath("static", "uploaded")

# Make sure STATIC_DIR exists (creating parents; no-op if already present).
STATIC_DIR.mkdir(parents=True, exist_ok=True)

# Artefact locations and the device inference runs on (GPU when available).
MODEL_PATH = MODELS_DIR.joinpath("resnet_50.pth")
CLASSES_PATH = BASE_DIR.joinpath("classes.txt")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_class_labels(path: Optional[Union[str, Path]] = None) -> Optional[list]:
    """Load class names from a text file, one name per line.

    Args:
        path: Label file to read. Defaults to ``CLASSES_PATH``.

    Returns:
        The stripped, non-blank lines in file order, or ``None`` when the
        file is missing, unreadable, undecodable, or contains no names.
        Errors are logged instead of raised.
    """
    labels_path = CLASSES_PATH if path is None else Path(path)
    try:
        if not labels_path.exists():
            raise FileNotFoundError(f"Classes file not found at {labels_path}")
        # utf-8-sig drops a leading BOM that would otherwise stay attached to
        # the first class name (str.strip() does not remove U+FEFF).
        with open(labels_path, "r", encoding="utf-8-sig") as f:
            # Skip blank lines: they would otherwise become "" class names and
            # change the number of classes the classifier head is sized for.
            labels = [line.strip() for line in f if line.strip()]
        if not labels:
            raise ValueError(f"No class labels found in {labels_path}")
        return labels
    except (OSError, ValueError) as e:
        logger.error(f"Error loading class labels: {str(e)}")
        return None
# Load class labels once at import; the app cannot run without them.
CLASS_NAMES = load_class_labels()
# `not` rejects an empty list as well as None: zero labels would otherwise
# size the classifier head to zero outputs.
if not CLASS_NAMES:
    raise RuntimeError("Failed to load class labels from classes.txt")

# Cache for the loaded network. load_model() fills it and returns the cached
# object on later calls instead of reloading the checkpoint.
model: Optional[torch.nn.Module] = None
def load_model() -> Optional[torch.nn.Module]:
    """Return the ResNet-50 classifier, loading and caching it on first use.

    The network gets a fresh ``fc`` head sized to ``len(CLASS_NAMES)``. It
    then receives the weights stored at ``MODEL_PATH``, either a bare state
    dict or a dict wrapping one under ``"state_dict"``. Finally it is moved
    to ``DEVICE`` and switched to eval mode.

    The module-level ``model`` cache is written only after every step has
    succeeded, so a failed load is retried on the next call rather than
    leaving a randomly initialised network cached.

    Returns:
        The ready model, or ``None`` if loading failed (the error is logged).
    """
    global model
    if model is not None:
        return model
    try:
        if not MODEL_PATH.exists():
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
        logger.info(f"Loading model on {DEVICE}")
        net = resnet50(pretrained=False)
        # Replace the classifier head so its width matches the label file.
        net.fc = torch.nn.Linear(net.fc.in_features, len(CLASS_NAMES))
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        # Unwrap training checkpoints. The isinstance guard keeps the `in`
        # test from raising TypeError on non-dict objects.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        net.load_state_dict(checkpoint)
        net.to(DEVICE)
        net.eval()
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None
    # Publish to the cache only now that the weights are fully loaded.
    model = net
    logger.info("Model loaded successfully")
    return model
def preprocess_image(image: Union[np.ndarray, Image.Image]) -> Optional[torch.Tensor]:
    """Convert an input image into a normalized 1x3x224x224 tensor on DEVICE.

    Args:
        image: A NumPy image array or a PIL image.

    Returns:
        The batched tensor, or ``None`` if conversion failed (the error is
        logged).
    """
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Normalize() below carries exactly three channel statistics, so
        # grayscale, palette, or RGBA inputs must be coerced to RGB first or
        # the channel count will not match.
        image = image.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # NOTE: must match the normalization used when resnet_50.pth was
            # trained (presumably the usual ImageNet statistics; confirm).
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # unsqueeze(0) adds the leading batch dimension the model expects.
        return transform(image).unsqueeze(0).to(DEVICE)
    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        return None
def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
    """Classify one image and report the most likely classes.

    Args:
        image: Image array from the Gradio ``Image`` input, or ``None`` when
            nothing was submitted.

    Returns:
        ``(predicted_class, confidences)``. ``predicted_class`` is the name of
        the highest-probability class. ``confidences`` maps up to five class
        names to their softmax probabilities. On failure, the first element is
        an error message and the dict is empty; exceptions are caught and
        logged rather than propagated to the UI.
    """
    if image is None:
        return "Error: No image provided", {}
    try:
        net = load_model()
        if net is None:
            return "Error: Failed to load model", {}
        input_tensor = preprocess_image(image)
        if input_tensor is None:
            return "Error: Failed to preprocess image", {}
        with torch.no_grad():
            logits = net(input_tensor)
            # Only the first sample of the batch is used.
            probabilities = torch.nn.functional.softmax(logits[0], dim=0)
            predicted_class = CLASS_NAMES[int(torch.argmax(probabilities))]
            # torch.topk raises when k exceeds the number of elements, so cap
            # it for label sets with fewer than five classes.
            k = min(5, probabilities.numel())
            top_probs, top_indices = torch.topk(probabilities, k=k)
        confidences = {
            CLASS_NAMES[idx]: prob
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }
        return predicted_class, confidences
    except Exception as e:
        logger.exception("Prediction error")
        return f"Error during prediction: {str(e)}", {}
def get_example_list(directory: Optional[Union[str, Path]] = None) -> list:
    """Collect example images for the Gradio ``examples`` table.

    Args:
        directory: Folder to scan (non-recursively). Defaults to
            ``EXAMPLES_DIR``.

    Returns:
        One ``[path_str]`` row per ``.jpg``/``.jpeg``/``.png`` file, with the
        extension matched case-insensitively, sorted by path. Returns an empty
        list when the folder does not exist or cannot be read.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png"}
    folder = EXAMPLES_DIR if directory is None else Path(directory)
    try:
        if not folder.is_dir():
            return []
        # Match on the suffix instead of glob patterns. The previous
        # f'*.{ext}' with ext='.jpg' produced '*..jpg' and matched nothing.
        images = sorted(
            entry
            for entry in folder.iterdir()
            if entry.is_file() and entry.suffix.lower() in allowed_suffixes
        )
    except OSError as e:
        logger.error(f"Error loading examples: {str(e)}")
        return []
    return [[str(image)] for image in images]
def _build_interface():
    """Assemble the Gradio UI around predict(): one image in, two labels out."""
    image_input = gr.Image(type="numpy", label="Upload Image")
    label_outputs = [
        gr.Label(label="Predicted Class", num_top_classes=1),
        gr.Label(label="Top 5 Predictions", num_top_classes=5),
    ]
    blurb = (
        "Upload an image to classify:\n"
        "The model will predict the class and show top 5 confidence scores."
    )
    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_outputs,
        title="Image Classification with ResNet50",
        description=blurb,
        examples=get_example_list(),
        cache_examples=True,
        theme=gr.themes.Base(),
    )


# Build the interface at import; log any construction failure, then re-raise.
try:
    iface = _build_interface()
except Exception as e:
    logger.error(f"Error creating Gradio interface: {str(e)}")
    raise
if __name__ == "__main__":
    # Pre-load the model before serving. load_model() logs its own error and
    # returns None on failure; the UI still starts in that case.
    if load_model() is None:
        logger.warning("Model could not be preloaded; see the error logged above")
    try:
        iface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
        )
    except Exception:
        # Log with traceback, then re-raise so the process exits non-zero
        # instead of looking like a clean shutdown.
        logger.exception("Error launching application")
        raise