Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -20,7 +20,7 @@ class_names = [line.strip() for line in open("classes.txt")]
|
|
| 20 |
# Load the model
|
| 21 |
model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
|
| 22 |
model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
|
| 23 |
-
checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth')
|
| 24 |
model.load_state_dict(checkpoint, strict=False)
|
| 25 |
model = model.to(device)
|
| 26 |
model.eval()
|
|
@@ -69,7 +69,7 @@ iface = gr.Interface(
|
|
| 69 |
fn=predict,
|
| 70 |
inputs=gr.Image(type="numpy"),
|
| 71 |
outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
|
| 72 |
-
examples=[
|
| 73 |
)
|
| 74 |
|
| 75 |
# Launch the Gradio app
|
|
|
|
| 20 |
# Load the model
|
| 21 |
model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
|
| 22 |
model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
|
| 23 |
+
checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth', map_location=torch.device('cpu'))
|
| 24 |
model.load_state_dict(checkpoint, strict=False)
|
| 25 |
model = model.to(device)
|
| 26 |
model.eval()
|
|
|
|
| 69 |
fn=predict,
|
| 70 |
inputs=gr.Image(type="numpy"),
|
| 71 |
outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
|
| 72 |
+
examples=["Image 1.jpg", "Image 2"] # Optional: Add paths to example images
|
| 73 |
)
|
| 74 |
|
| 75 |
# Launch the Gradio app
|