Commit ·
2182d48
1
Parent(s): eb03e62
fix
Browse files
app.py
CHANGED
|
@@ -127,22 +127,19 @@ def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
|
|
| 127 |
return "Error: Failed to preprocess image", {}
|
| 128 |
|
| 129 |
with torch.no_grad():
|
| 130 |
-
# Move input to same device as model
|
| 131 |
input_tensor = input_tensor.to(DEVICE)
|
| 132 |
output = model(input_tensor)
|
| 133 |
-
# Apply softmax to get probabilities
|
| 134 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
| 135 |
|
| 136 |
# Get predictions and confidences
|
| 137 |
top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
|
| 138 |
|
| 139 |
-
#
|
| 140 |
confidences = {
|
| 141 |
-
CLASS_NAMES[idx.item()]: float(
|
| 142 |
for prob, idx in zip(top_5_probs, top_5_indices)
|
| 143 |
}
|
| 144 |
|
| 145 |
-
# Get top prediction
|
| 146 |
predicted_class = CLASS_NAMES[top_5_indices[0].item()]
|
| 147 |
|
| 148 |
return predicted_class, confidences
|
|
@@ -166,23 +163,40 @@ def get_example_list() -> list:
|
|
| 166 |
|
| 167 |
# Create Gradio interface with error handling
|
| 168 |
try:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
gr.
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f"Error creating Gradio interface: {str(e)}")
|
| 188 |
raise
|
|
|
|
| 127 |
return "Error: Failed to preprocess image", {}
|
| 128 |
|
| 129 |
with torch.no_grad():
|
|
|
|
| 130 |
input_tensor = input_tensor.to(DEVICE)
|
| 131 |
output = model(input_tensor)
|
|
|
|
| 132 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
| 133 |
|
| 134 |
# Get predictions and confidences
|
| 135 |
top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
|
| 136 |
|
| 137 |
+
# Format confidences with exactly 2 decimal places
|
| 138 |
confidences = {
|
| 139 |
+
CLASS_NAMES[idx.item()]: "{:.2f}".format(float(prob.item() * 100))
|
| 140 |
for prob, idx in zip(top_5_probs, top_5_indices)
|
| 141 |
}
|
| 142 |
|
|
|
|
| 143 |
predicted_class = CLASS_NAMES[top_5_indices[0].item()]
|
| 144 |
|
| 145 |
return predicted_class, confidences
|
|
|
|
| 163 |
|
| 164 |
# Create Gradio interface with error handling
|
| 165 |
try:
|
| 166 |
+
with gr.Blocks(theme=gr.themes.Base()) as iface:
|
| 167 |
+
gr.Markdown("# Image Classification with ResNet50")
|
| 168 |
+
gr.Markdown("Upload an image to classify. The model will predict the class and show top 5 confidence scores.")
|
| 169 |
+
|
| 170 |
+
with gr.Row():
|
| 171 |
+
with gr.Column(scale=1):
|
| 172 |
+
input_image = gr.Image(type="numpy", label="Upload Image")
|
| 173 |
+
predict_btn = gr.Button("Predict")
|
| 174 |
+
|
| 175 |
+
with gr.Column(scale=1):
|
| 176 |
+
output_label = gr.Label(label="Predicted Class", num_top_classes=1)
|
| 177 |
+
confidence_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
|
| 178 |
+
|
| 179 |
+
# Add examples
|
| 180 |
+
gr.Examples(
|
| 181 |
+
examples=get_example_list(),
|
| 182 |
+
inputs=input_image,
|
| 183 |
+
outputs=[output_label, confidence_label],
|
| 184 |
+
fn=predict,
|
| 185 |
+
cache_examples=True
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Set up prediction event
|
| 189 |
+
predict_btn.click(
|
| 190 |
+
fn=predict,
|
| 191 |
+
inputs=input_image,
|
| 192 |
+
outputs=[output_label, confidence_label]
|
| 193 |
+
)
|
| 194 |
+
input_image.change(
|
| 195 |
+
fn=predict,
|
| 196 |
+
inputs=input_image,
|
| 197 |
+
outputs=[output_label, confidence_label]
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
except Exception as e:
|
| 201 |
logger.error(f"Error creating Gradio interface: {str(e)}")
|
| 202 |
raise
|