Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
# 1. SETUP MODEL ARCHITECTURE
# ResNet-50 built without pretrained weights (matches the ResNet50 training
# logs); the trained parameters are loaded from disk in step 3.
model = models.resnet50(weights=None)

# 2. MATCH THE FINAL LAYER
# Swap torchvision's default classification head for a 2-way linear layer so
# the layer shapes line up with the saved checkpoint.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# 3. LOAD WEIGHTS
# Ensure this matches the EXACT filename uploaded to the Files tab.
model_path = "fire_detection_resnet50.pth"
try:
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered .pth file cannot execute arbitrary code while loading.
    state_dict = torch.load(
        model_path, map_location=torch.device('cpu'), weights_only=True
    )
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully.")
except Exception as e:
    # Deliberately non-fatal so the UI still starts, but without these
    # weights the network is randomly initialised and its outputs are
    # meaningless -- make that impossible to miss in the logs.
    print(f"Error loading model weights: {e}")
    print("WARNING: no trained weights loaded; predictions will be random.")

# Inference mode for layers such as BatchNorm/Dropout.
model.eval()
# 4. DEFINE PREPROCESSING
# Convert a PIL image into the normalized float tensor the network consumes:
# fixed 224x224 resize -> [0, 1] CHW tensor -> per-channel standardization
# with the standard ImageNet mean/std.
# NOTE(review): this must mirror the transform used at training time
# (e.g. Resize vs Resize+CenterCrop) -- confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# 5. PREDICTION FUNCTION
# Output index i of the model corresponds to labels[i].
# NOTE(review): confirm this order against how the training data was indexed;
# torchvision's ImageFolder, for example, assigns indices by sorted folder
# name, which would place a "fire" folder at index 0.
labels = ['Non-Fire', 'Fire']


def predict(image):
    """Classify an image as fire / non-fire.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the input is empty.

    Returns:
        ``None`` for an empty input; otherwise a ``{label: probability}``
        dict in the format expected by ``gr.Label``.

    Raises:
        gr.Error: if preprocessing or inference fails. Gradio shows the
            message to the user instead of rendering it as a bogus class
            label with 0% confidence.
    """
    if image is None:
        return None
    try:
        # Force 3 channels (drops alpha, expands grayscale), then add the
        # batch dimension the model expects.
        image_tensor = transform(image.convert('RGB')).unsqueeze(0)
        with torch.no_grad():
            logits = model(image_tensor)[0]
        # .tolist() yields plain Python floats, which Gradio can serialize.
        probabilities = torch.nn.functional.softmax(logits, dim=0).tolist()
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e
    return dict(zip(labels, probabilities))
# 6. LAUNCH GRADIO UI
# 'examples' is intentionally not passed: it was removed to prevent a crash.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    # Show every class; derived from `labels` so the UI stays in sync if the
    # label set changes.
    outputs=gr.Label(num_top_classes=len(labels), label="Prediction"),
    title="Fire Detection System",
    description="Upload an image to detect if fire is present.",
)

if __name__ == "__main__":
    interface.launch()