gnikhilchand commited on
Commit
1a75b1d
·
verified ·
1 Parent(s): f82794c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+
7
+ # 1. SETUP MODEL ARCHITECTURE
8
+ # Based on your logs, you used ResNet50.
9
+ # If you actually used ResNet18, change 'resnet50' to 'resnet18' below.
10
+ model = models.resnet50(weights=None)
11
+
12
+ # 2. MATCH THE FINAL LAYER
13
+ # ResNet50 has 2048 input features in the final layer.
14
+ # (If you used ResNet18, this number would be 512).
15
+ num_ftrs = model.fc.in_features
16
+ model.fc = nn.Linear(num_ftrs, 2)
17
+
18
+ # 3. LOAD WEIGHTS
19
+ # Replace 'fire_detection_resnet18.pth' with the EXACT name of the file you uploaded
20
+ model_path = "fire_detection_resnet18.pth"
21
+
22
+ # Load weights for CPU (since HF Spaces Free Tier uses CPU)
23
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
24
+ model.load_state_dict(state_dict)
25
+ model.eval()
26
+
27
+ # 4. DEFINE PREPROCESSING
28
+ # This must match what you used during training
29
+ transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33
+ ])
34
+
35
+ # 5. PREDICTION FUNCTION
36
+ labels = ['Non-Fire', 'Fire'] # 0 is Non-Fire, 1 is Fire
37
+
38
+ def predict(image):
39
+ if image is None:
40
+ return None
41
+
42
+ # Preprocess
43
+ image = image.convert('RGB')
44
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
45
+
46
+ # Inference
47
+ with torch.no_grad():
48
+ outputs = model(image_tensor)
49
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
50
+
51
+ # Return dictionary for Gradio Label
52
+ return {labels[i]: float(probabilities[i]) for i in range(len(labels))}
53
+
54
+ # 6. LAUNCH GRADIO UI
55
+ interface = gr.Interface(
56
+ fn=predict,
57
+ inputs=gr.Image(type="pil", label="Upload Image"),
58
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
59
+ title="Fire Detection System tements",
60
+ description="Upload an image to detect if fire is present. (Model: ResNet50)",
61
+ examples=["fire.jpg", "forest.jpg"] # Optional: Upload these images to your space for users to click
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ interface.launch()