mathminakshi commited on
Commit
63c41aa
·
verified ·
1 Parent(s): 7743aec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -91
app.py CHANGED
@@ -5,10 +5,7 @@ import torchvision.transforms as transforms
5
  import json
6
  import sys
7
  import os
8
-
9
- # Add the parent directory to the Python path to access 'model' from 'HF_app'
10
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
11
-
12
  # Now import from model.py
13
  from model.model import ResNet50
14
 
@@ -17,21 +14,12 @@ from model.model import ResNet50
17
  def load_class_names():
18
  try:
19
  with open("imagenet_classes.json", 'r', encoding='utf-8') as f:
20
- # Read the file content first
21
  content = f.read()
22
- # Try to clean the content of any control characters
23
  content = ''.join(char for char in content if ord(char) >= 32 or char in '\n\r\t')
24
- # Parse the cleaned content
25
  class_names = json.loads(content)
26
  return class_names
27
- except json.JSONDecodeError as e:
28
- st.error(f"Error loading class names: Invalid JSON format. {str(e)}")
29
- return {}
30
- except FileNotFoundError:
31
- st.error("Error: Class names file not found. Please check the file path.")
32
- return {}
33
  except Exception as e:
34
- st.error(f"Unexpected error loading class names: {str(e)}")
35
  return {}
36
 
37
  # Load model
@@ -40,7 +28,6 @@ def load_model():
40
  try:
41
  model = ResNet50(num_classes=1000)
42
  checkpoint = torch.load("./checkpoints/model_best.pth", map_location=torch.device("cpu"))
43
- # Extract just the model state dict from the checkpoint
44
  if "model_state_dict" in checkpoint:
45
  model.load_state_dict(checkpoint["model_state_dict"])
46
  else:
@@ -71,84 +58,53 @@ if model is None:
71
  st.error("Failed to load the model. Please check the model file.")
72
  st.stop()
73
 
74
- # Create a container for the file uploader
75
- upload_container = st.empty()
 
76
 
77
- # File uploader inside the container
78
- uploaded_file = upload_container.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
79
 
80
- if uploaded_file:
81
- # Clear the upload container
82
- upload_container.empty()
83
-
84
- # Load the image first
85
- image = Image.open(uploaded_file).convert("RGB")
86
-
87
- # Create a container for the entire content
88
- with st.container():
89
- # First row - Headers with reduced width
90
- st.markdown("""
91
- <style>
92
- div[data-testid="column"] {
93
- display: flex;
94
- flex-direction: column;
95
- height: 400px;
96
- justify-content: center;
97
- }
98
- div[data-testid="stImage"] {
99
- height: 400px;
100
- display: flex;
101
- align-items: center;
102
- }
103
- div[data-testid="stTable"] {
104
- height: 400px;
105
- display: flex;
106
- flex-direction: column;
107
- justify-content: center;
108
- }
109
- .header-row {
110
- max-width: 800px;
111
- margin: auto;
112
- }
113
- </style>
114
- """, unsafe_allow_html=True)
115
 
116
- # Wrap headers in a div with the header-row class
117
- st.markdown('<div class="header-row">', unsafe_allow_html=True)
118
- header_col1, header_col2 = st.columns(2)
119
- with header_col1:
120
- st.markdown("### Uploaded Image")
121
- with header_col2:
122
- st.markdown("### Predictions")
123
- st.markdown('</div>', unsafe_allow_html=True)
124
 
125
- # Add a small space between headers and content
126
- st.markdown("<br>", unsafe_allow_html=True)
127
-
128
- # Second row - Content (full width)
129
- content_col1, content_col2 = st.columns(2)
130
-
131
- # Column 1 - Image
132
- with content_col1:
133
- st.image(image, use_container_width=True)
134
-
135
- # Column 2 - Predictions
136
- with content_col2:
137
- # Process image and get predictions
138
- input_tensor = preprocess_image(image)
139
- with torch.no_grad():
140
- outputs = model(input_tensor)
141
- probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
142
- top5_prob, top5_idx = torch.topk(probabilities, 5)
143
-
144
- results = []
145
- for i in range(5):
146
- class_id = top5_idx[i].item()
147
- prob = top5_prob[i].item() * 100
148
- class_name = class_names[str(class_id)]
149
- results.append({
150
- "Rank": i + 1,
151
- "Class": class_name,
152
- "Confidence": f"{prob:.2f}%"
153
- })
154
- st.table(results)
 
 
 
 
 
 
5
  import json
6
  import sys
7
  import os
 
 
8
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
9
  # Now import from model.py
10
  from model.model import ResNet50
11
 
 
14
  def load_class_names():
15
  try:
16
  with open("imagenet_classes.json", 'r', encoding='utf-8') as f:
 
17
  content = f.read()
 
18
  content = ''.join(char for char in content if ord(char) >= 32 or char in '\n\r\t')
 
19
  class_names = json.loads(content)
20
  return class_names
 
 
 
 
 
 
21
  except Exception as e:
22
+ st.error(f"Error loading class names: {str(e)}")
23
  return {}
24
 
25
  # Load model
 
28
  try:
29
  model = ResNet50(num_classes=1000)
30
  checkpoint = torch.load("./checkpoints/model_best.pth", map_location=torch.device("cpu"))
 
31
  if "model_state_dict" in checkpoint:
32
  model.load_state_dict(checkpoint["model_state_dict"])
33
  else:
 
58
  st.error("Failed to load the model. Please check the model file.")
59
  st.stop()
60
 
61
+ # Initialize session state
62
+ if 'show_upload' not in st.session_state:
63
+ st.session_state.show_upload = True
64
 
65
+ # Main content container
66
+ main_container = st.empty()
67
 
68
+ with main_container.container():
69
+ if st.session_state.show_upload:
70
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ if uploaded_file:
73
+ # Load and display image
74
+ image = Image.open(uploaded_file).convert("RGB")
 
 
 
 
 
75
 
76
+ col1, col2 = st.columns(2)
77
+
78
+ with col1:
79
+ st.markdown("### Uploaded Image")
80
+ st.image(image, use_container_width=True)
81
+
82
+ with col2:
83
+ st.markdown("### Predictions")
84
+ # Process image and get predictions
85
+ input_tensor = preprocess_image(image)
86
+ with torch.no_grad():
87
+ outputs = model(input_tensor)
88
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
89
+ top5_prob, top5_idx = torch.topk(probabilities, 5)
90
+
91
+ results = []
92
+ for i in range(5):
93
+ class_id = top5_idx[i].item()
94
+ prob = top5_prob[i].item() * 100
95
+ class_name = class_names[str(class_id)]
96
+ results.append({
97
+ "Rank": i + 1,
98
+ "Class": class_name,
99
+ "Confidence": f"{prob:.2f}%"
100
+ })
101
+ st.table(results)
102
+
103
+ # Add the New Image button
104
+ st.markdown("<br>", unsafe_allow_html=True)
105
+ col1, col2, col3 = st.columns([2, 1, 2])
106
+ with col2:
107
+ if st.button("↻ New Image"):
108
+ main_container.empty()
109
+ st.session_state.show_upload = True
110
+ st.rerun()