# fire-detection / app.py
# Last commit: 2626329 "Update app.py" (verified), by gnikhilchand.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# 1. SETUP MODEL ARCHITECTURE
# ResNet-50 built without pretrained weights; every parameter is expected to
# come from the fine-tuned checkpoint loaded in step 3.
model = models.resnet50(weights=None)

# 2. MATCH THE FINAL LAYER
# Swap the default ImageNet classification head for a 2-way head so the
# module's parameter shapes match the fine-tuned checkpoint.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# 3. LOAD WEIGHTS
# Must match the EXACT filename of the checkpoint uploaded next to this script.
model_path = "fire_detection_resnet50.pth"

# True only once the checkpoint has been applied. If it stays False the model
# keeps its random initialisation and its predictions are meaningless.
weights_loaded = False
try:
    # weights_only=True (torch >= 1.13) restricts unpickling to tensors and
    # plain containers, so a tampered .pth file cannot execute code on load.
    # The file is consumed by load_state_dict, i.e. it is a state dict, which
    # this mode supports.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    weights_loaded = True
    print("Model weights loaded successfully.")
except Exception as e:
    # Deliberately best-effort so the UI still comes up, but say loudly that
    # inference is now running on untrained weights.
    print(f"Error loading model weights: {e}")
    print("WARNING: model is using randomly initialised weights; "
          "predictions will be meaningless.")

# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()
# 4. DEFINE PREPROCESSING
# Per-channel (R, G, B) mean / std used by Normalize below.
# NOTE(review): presumably identical to the training-time preprocessing —
# confirm against the training script.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Resize to exactly 224x224 (aspect ratio is not preserved).
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        # Standardise each channel with the statistics above.
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# 5. PREDICTION FUNCTION
# Class names in the order of the model's output indices.
# NOTE(review): if the model was trained with torchvision's ImageFolder,
# indices follow the alphabetical order of the class folder names — confirm
# that index 0 really corresponds to "Non-Fire".
labels = ['Non-Fire', 'Fire']


def predict(image):
    """Classify an image as fire / non-fire.

    Args:
        image: The image from the Gradio input component (a PIL image, since
            the input is configured with ``type="pil"``), or None when the
            input is empty.

    Returns:
        None when no image is supplied. Otherwise a dict mapping each name in
        ``labels`` to its softmax probability, the format ``gr.Label``
        displays. If preprocessing or inference raises, a single-entry dict
        ``{"Error: <message>": 0.0}`` is returned so the message appears in
        the UI instead of failing the request.
    """
    if image is None:
        return None
    try:
        # Force 3 channels (e.g. RGBA / greyscale uploads) to match the
        # 3-value Normalize step and the network's input stem.
        rgb = image.convert('RGB')
        batch = transform(rgb).unsqueeze(0)  # prepend a batch dimension of 1
        with torch.no_grad():
            logits = model(batch)
        # Softmax over the class scores of the single batch element.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
        # Pair names with probabilities positionally; .tolist() yields floats.
        return dict(zip(labels, probabilities.tolist()))
    except Exception as e:
        return {f"Error: {e}": 0.0}
# 6. LAUNCH GRADIO UI
# No `examples` gallery is configured: per the original author, including
# example inputs crashed the app.
image_input = gr.Image(type="pil", label="Upload Image")  # hands predict() a PIL image
label_output = gr.Label(num_top_classes=2, label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Fire Detection System",
    description="Upload an image to detect if fire is present.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()