"""Gradio demo that classifies an uploaded image as fire / non-fire.

A ResNet-50 with its final layer replaced by a 2-way linear head is built,
fine-tuned weights are loaded from ``model_path`` at import time, and a
Gradio interface exposes ``predict`` to the user.
"""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Class names, in the index order of the model's output logits.
labels = ['Non-Fire', 'Fire']

# 1. SETUP MODEL ARCHITECTURE
# ResNet-50 backbone created without pretrained weights; the fine-tuned
# weights are loaded in step 3.
model = models.resnet50(weights=None)

# 2. MATCH THE FINAL LAYER
# Replace the ImageNet classification head with one output per label so the
# layer shapes line up with the fine-tuned checkpoint.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(labels))

# 3. LOAD WEIGHTS
# Must match the exact filename of the uploaded checkpoint.
model_path = "fire_detection_resnet50.pth"

# Whether the fine-tuned weights were actually loaded. Without them the
# classification head is randomly initialised and its outputs are meaningless,
# so predict() reports an error instead of returning bogus confidences.
weights_loaded = False
try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust. On torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
except Exception as e:
    print(f"Error loading model weights: {e}")
else:
    weights_loaded = True
    print("Model weights loaded successfully.")

# Inference mode for layers such as BatchNorm/Dropout.
model.eval()

# 4. DEFINE PREPROCESSING
# Resize to the network's 224x224 input and normalise with the ImageNet
# channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# 5. PREDICTION FUNCTION
def predict(image):
    """Classify a PIL image as fire / non-fire.

    Args:
        image: PIL image supplied by ``gr.Image(type="pil")``, or ``None``
            when nothing has been uploaded.

    Returns:
        ``None`` if no image was given. Otherwise a ``{label: probability}``
        dict suitable for ``gr.Label``. On failure (including weights that
        could not be loaded) a single-entry dict whose key carries the error
        message and whose value is ``0.0``.
    """
    if image is None:
        return None
    if not weights_loaded:
        # Refuse to answer with a randomly initialised head.
        return {f"Error: model weights not loaded from {model_path}": 0.0}
    try:
        # Force 3 channels so grayscale/RGBA/palette uploads match the
        # 3-channel Normalize and the first conv layer.
        rgb_image = image.convert('RGB')
        image_tensor = transform(rgb_image).unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Pair each label with its probability as a plain Python float.
        return dict(zip(labels, probabilities.tolist()))
    except Exception as e:
        # Surface the error in the UI rather than crashing the request.
        return {f"Error: {str(e)}": 0.0}


# 6. BUILD GRADIO UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=len(labels), label="Prediction"),
    title="Fire Detection System",
    description="Upload an image to detect if fire is present.",
)

if __name__ == "__main__":
    interface.launch()