Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import cv2 | |
| import numpy as np | |
| # -------------------- | |
| # Model Definition | |
| # -------------------- | |
class FireCNN(nn.Module):
    """Small four-stage CNN producing ``num_classes`` logits for an RGB image.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)``. The padded convolutions keep the spatial size and every
    pooling halves it (rounding down), so the four stages shrink a square
    input of side ``img_size`` to side ``img_size // 16`` with 128 channels.

    Args:
        num_classes: Number of output logits.
        img_size: Side length, in pixels, of the square input images. The
            default of 128 gives the original ``128 * 8 * 8`` flattened
            feature size, so checkpoints saved with the default still load.

    Raises:
        ValueError: If ``img_size`` is too small to survive four poolings.
    """

    # Channel widths: 3 input channels followed by the output width of each stage.
    _CHANNELS = (3, 16, 32, 64, 128)

    def __init__(self, num_classes=3, img_size=128):
        super().__init__()

        # Four MaxPool2d(2) layers floor-divide the side length by 16 overall.
        feature_side = img_size // 16
        if feature_side < 1:
            raise ValueError(
                f"img_size must be at least 16 for four 2x poolings, got {img_size}"
            )

        # Layers are appended in the same order as the original hand-written
        # Sequential, so module indices (and state_dict keys) are unchanged.
        stages = []
        for in_ch, out_ch in zip(self._CHANNELS, self._CHANNELS[1:]):
            stages += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self._CHANNELS[-1] * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x
| # -------------------- | |
| # Load Model | |
| # -------------------- | |
# Read the saved checkpoint onto the CPU. The script uses two of its entries:
# "model_state_dict" (the trained weights) and "img_size" (used by predict()
# as the resize target).
checkpoint = torch.load("fire_model.pth", map_location="cpu")
IMG_SIZE = checkpoint["img_size"]

# Rebuild the network, restore its weights, and switch to inference mode so
# BatchNorm uses its running statistics and Dropout is disabled.
model = FireCNN()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
| # -------------------- | |
| # Prediction Function | |
| # -------------------- | |
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: NumPy image array from the ``gr.Image(type="numpy")`` input,
            laid out as height x width (x channels). ``None`` when the user
            submits without providing an image.

    Returns:
        dict mapping ``"fire"``, ``"smoke"`` and ``"non_fire"`` to the softmax
        probability (a Python float) of output indices 0, 1 and 2.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an opaque OpenCV failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The network expects exactly 3 channels: expand grayscale inputs and drop
    # an alpha channel if one is present.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = np.ascontiguousarray(image[:, :, :3])

    resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    # Assumes 8-bit pixel values, scaled here to [0, 1] - TODO confirm this
    # matches the preprocessing used at training time.
    scaled = resized / 255.0
    # HWC -> CHW, then add a leading batch dimension of 1.
    chw = np.transpose(scaled, (2, 0, 1))
    batch = torch.tensor(chw, dtype=torch.float32).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Index the single batch element explicitly rather than squeezing.
    probs = torch.softmax(logits, dim=1)[0].numpy()

    # NOTE(review): label order is taken from the original code; confirm it
    # matches the class indices used when the checkpoint was trained.
    return {
        "fire": float(probs[0]),
        "smoke": float(probs[1]),
        "non_fire": float(probs[2]),
    }
| # -------------------- | |
| # Gradio Interface | |
| # -------------------- | |
# Input widget: hands the uploaded picture to predict() as a NumPy array.
image_input = gr.Image(type="numpy")
# Output widget: shows the returned label -> probability mapping (top 3).
label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="🔥 Fire / Smoke Detection",
    description="Upload an image to detect Fire, Smoke, or Non-Fire",
)

# Start the web server for the app.
demo.launch()