ramagururadhakrishnan commited on
Commit
4fcd204
·
verified ·
1 Parent(s): 0019e56

Labels Added

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -2,20 +2,29 @@ import streamlit as st
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
 
5
  from src.model import SwinTransformerMultiLabel # Import from src folder
6
 
7
  # Title and description
8
  st.title("STAR Multi-Label Classifier")
9
- st.write("Upload an image to get predictions.")
 
 
 
 
 
 
 
 
 
10
 
11
  # Load trained model
12
- NUM_CLASSES = 18
13
- model_path = "models/star.pth"
14
  model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
15
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
16
  model.eval()
17
 
18
- # Define transformation
19
  transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
@@ -32,7 +41,8 @@ if uploaded_file is not None:
32
  img_tensor = transform(image).unsqueeze(0)
33
  with torch.no_grad():
34
  output = model(img_tensor)
35
- predicted_labels = (output > 0.5).int().tolist()[0]
 
36
 
37
- # Display Results
38
- st.write("✅ **Predicted Labels:**", predicted_labels)
 
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
5
+ import json
6
  from src.model import SwinTransformerMultiLabel # Import from src folder
7
 
8
  # Title and description
9
  st.title("STAR Multi-Label Classifier")
10
+ st.write("Upload an image to get multi-label predictions.")
11
+
12
+ # Load class labels from JSON
13
+ label_file = "data/labels.json" # Path to labels file
14
+ with open(label_file, "r") as f:
15
+ label_data = json.load(f)
16
+
17
+ # Extract unique class labels
18
+ class_labels = sorted(set(tag for tags in label_data.values() for tag in tags))
19
+ NUM_CLASSES = len(class_labels)
20
 
21
  # Load trained model
22
+ model_path = "models/star.pth"
 
23
  model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
24
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
25
  model.eval()
26
 
27
+ # Define image preprocessing transformations
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
 
41
  img_tensor = transform(image).unsqueeze(0)
42
  with torch.no_grad():
43
  output = model(img_tensor)
44
+ predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5] # Threshold = 0.5
45
+ predicted_labels = [class_labels[i] for i in predicted_indices] # Convert indices to labels
46
 
47
+ # Display results
48
+ st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No labels detected")