Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,6 +27,7 @@ model.eval() # Set model to evaluation mode
|
|
| 27 |
try:
|
| 28 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
| 29 |
model.load_state_dict(state_dict)
|
|
|
|
| 30 |
except RuntimeError as e:
|
| 31 |
print("Error loading state_dict:", e)
|
| 32 |
print("Ensure that the saved model architecture matches ResNet50.")
|
|
@@ -43,32 +44,43 @@ preprocess = transforms.Compose([
|
|
| 43 |
])
|
| 44 |
|
| 45 |
# Load labels
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# Function to predict image class
|
| 50 |
def predict(image):
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# Set up the Gradio interface
|
| 74 |
iface = gr.Interface(
|
|
|
|
| 27 |
# Load the fine-tuned checkpoint into the (already constructed) model.
# Catch the two realistic failure modes separately so the Space still
# starts and prints an actionable message instead of crashing at import:
#   - FileNotFoundError: checkpoint file is absent from the repo
#   - RuntimeError: state_dict keys/shapes do not match the architecture
try:
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except FileNotFoundError as e:
    print("Model checkpoint not found:", e)
except RuntimeError as e:
    print("Error loading state_dict:", e)
    print("Ensure that the saved model architecture matches ResNet50.")
|
|
|
|
| 44 |
])
|
| 45 |
|
| 46 |
# Load labels
|
| 47 |
+
# Load the class-index -> label mapping used to translate predictions.
# Bind a fallback empty mapping FIRST: the original left `labels`
# undefined on failure, which turned every later lookup in predict()
# into a NameError. Narrow the except clause to the failures this
# block can actually produce (missing file, malformed JSON).
labels = {}
try:
    with open("config.json") as f:
        labels = json.load(f)
    print("Labels loaded successfully.")
except (FileNotFoundError, json.JSONDecodeError) as e:
    print("Error loading labels:", e)
|
| 53 |
|
| 54 |
# Function to predict image class
|
| 55 |
def predict(image):
    """Classify an uploaded image and return a human-readable result.

    Parameters
    ----------
    image : PIL.Image.Image
        Image supplied by the Gradio interface.

    Returns
    -------
    str
        ``"Predicted class: <label>"`` on success, or a generic error
        message if anything in the pipeline fails.
    """
    try:
        # Normalize to 3-channel RGB (uploads may be RGBA/grayscale).
        input_image = image.convert("RGB")

        # Apply the module-level `preprocess` transforms and add a
        # batch dimension so the tensor is shaped (1, C, H, W).
        input_batch = preprocess(input_image).unsqueeze(0)

        # Pick the device once per call instead of branching with a
        # per-request "GPU not available" print as before; .to() is a
        # no-op when model/tensor are already on the target device.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_batch = input_batch.to(device)
        model.to(device)

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            output = model(input_batch)

        # Index of the highest-scoring class across dim 1 (classes).
        _, predicted_idx = torch.max(output, 1)
        # NOTE(review): assumes `labels` maps *stringified* indices to
        # class names (id2label-style dict) — confirm config.json schema.
        predicted_class = labels[str(predicted_idx.item())]

        return f"Predicted class: {predicted_class}"

    except Exception as e:  # UI boundary: report, never crash the app
        print(f"Error during prediction: {e}")
        return "An error occurred during prediction. Please try again."
|
| 84 |
|
| 85 |
# Set up the Gradio interface
|
| 86 |
iface = gr.Interface(
|