Update Geo/GeochatP-main/app.py
Browse files- Geo/GeochatP-main/app.py +2 -2
Geo/GeochatP-main/app.py
CHANGED
|
@@ -14,7 +14,7 @@ class MyModel(torch.nn.Module):
|
|
| 14 |
return x
|
| 15 |
|
| 16 |
model = MyModel()
|
| 17 |
-
model.load_state_dict(torch.load("model.pth"))
|
| 18 |
model.eval()
|
| 19 |
|
| 20 |
# Define image preprocessing
|
|
@@ -31,5 +31,5 @@ def predict(image):
|
|
| 31 |
return output.numpy().tolist()
|
| 32 |
|
| 33 |
# Create Gradio interface
|
| 34 |
-
iface = gr.Interface(fn=predict, inputs=gr.Image(), outputs="json")
|
| 35 |
iface.launch()
|
|
|
|
| 14 |
return x
|
| 15 |
|
| 16 |
model = MyModel()
|
| 17 |
+
model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
|
| 18 |
model.eval()
|
| 19 |
|
| 20 |
# Define image preprocessing
|
|
|
|
| 31 |
return output.numpy().tolist()
|
| 32 |
|
| 33 |
# Create Gradio interface
|
| 34 |
+
iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")
|
| 35 |
iface.launch()
|