gnikhilchand commited on
Commit
2626329
·
verified ·
1 Parent(s): d73f407

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -5,27 +5,27 @@ from torchvision import models, transforms
5
  from PIL import Image
6
 
7
  # 1. SETUP MODEL ARCHITECTURE
8
- # Based on your logs, you used ResNet50.
9
- # If you actually used ResNet18, change 'resnet50' to 'resnet18' below.
10
  model = models.resnet50(weights=None)
11
 
12
  # 2. MATCH THE FINAL LAYER
13
- # ResNet50 has 2048 input features in the final layer.
14
- # (If you used ResNet18, this number would be 512).
15
  num_ftrs = model.fc.in_features
16
  model.fc = nn.Linear(num_ftrs, 2)
17
 
18
  # 3. LOAD WEIGHTS
19
- # Replace 'fire_detection_resnet18.pth' with the EXACT name of the file you uploaded
20
  model_path = "fire_detection_resnet50.pth"
21
 
22
- # Load weights for CPU (since HF Spaces Free Tier uses CPU)
23
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
24
- model.load_state_dict(state_dict)
 
 
 
 
25
  model.eval()
26
 
27
  # 4. DEFINE PREPROCESSING
28
- # This must match what you used during training
29
  transform = transforms.Compose([
30
  transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
@@ -33,32 +33,35 @@ transform = transforms.Compose([
33
  ])
34
 
35
  # 5. PREDICTION FUNCTION
36
- labels = ['Non-Fire', 'Fire'] # 0 is Non-Fire, 1 is Fire
37
 
38
  def predict(image):
39
  if image is None:
40
  return None
41
 
42
- # Preprocess
43
- image = image.convert('RGB')
44
- image_tensor = transform(image).unsqueeze(0) # Add batch dimension
45
-
46
- # Inference
47
- with torch.no_grad():
48
- outputs = model(image_tensor)
49
- probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
50
-
51
- # Return dictionary for Gradio Label
52
- return {labels[i]: float(probabilities[i]) for i in range(len(labels))}
 
 
 
53
 
54
  # 6. LAUNCH GRADIO UI
 
55
  interface = gr.Interface(
56
  fn=predict,
57
  inputs=gr.Image(type="pil", label="Upload Image"),
58
  outputs=gr.Label(num_top_classes=2, label="Prediction"),
59
- title="Fire Detection System tements",
60
- description="Upload an image to detect if fire is present. (Model: ResNet50)",
61
- examples=["fire.jpg", "forest.jpg"] # Optional: Upload these images to your space for users to click
62
  )
63
 
64
  if __name__ == "__main__":
 
5
  from PIL import Image
6
 
7
  # 1. SETUP MODEL ARCHITECTURE
8
+ # (Matches the ResNet50 logs you provided)
 
9
  model = models.resnet50(weights=None)
10
 
11
  # 2. MATCH THE FINAL LAYER
 
 
12
  num_ftrs = model.fc.in_features
13
  model.fc = nn.Linear(num_ftrs, 2)
14
 
15
  # 3. LOAD WEIGHTS
16
+ # Ensure this matches the EXACT filename you uploaded to the Files tab
17
  model_path = "fire_detection_resnet50.pth"
18
 
19
+ try:
20
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
21
+ model.load_state_dict(state_dict)
22
+ print("Model weights loaded successfully.")
23
+ except Exception as e:
24
+ print(f"Error loading model weights: {e}")
25
+
26
  model.eval()
27
 
28
  # 4. DEFINE PREPROCESSING
 
29
  transform = transforms.Compose([
30
  transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
 
33
  ])
34
 
35
  # 5. PREDICTION FUNCTION
36
+ labels = ['Non-Fire', 'Fire']
37
 
38
  def predict(image):
39
  if image is None:
40
  return None
41
 
42
+ try:
43
+ # Preprocess
44
+ image = image.convert('RGB')
45
+ image_tensor = transform(image).unsqueeze(0)
46
+
47
+ # Inference
48
+ with torch.no_grad():
49
+ outputs = model(image_tensor)
50
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
51
+
52
+ # Return dictionary for Gradio Label
53
+ return {labels[i]: float(probabilities[i]) for i in range(len(labels))}
54
+ except Exception as e:
55
+ return {f"Error: {str(e)}": 0.0}
56
 
57
  # 6. LAUNCH GRADIO UI
58
+ # Removed 'examples' to prevent the crash
59
  interface = gr.Interface(
60
  fn=predict,
61
  inputs=gr.Image(type="pil", label="Upload Image"),
62
  outputs=gr.Label(num_top_classes=2, label="Prediction"),
63
+ title="Fire Detection System",
64
+ description="Upload an image to detect if fire is present."
 
65
  )
66
 
67
  if __name__ == "__main__":