Update app.py
Browse files
app.py
CHANGED
|
@@ -5,11 +5,14 @@ import torch
|
|
| 5 |
import torchvision.transforms as T
|
| 6 |
import gradio as gr
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
# Load parameters and model
|
| 9 |
with open("parameters.json", "r") as f:
|
| 10 |
parameters = json.load(f)
|
| 11 |
|
| 12 |
-
model = create_model(parameters)
|
| 13 |
weights = torch.load("cxp_projection_rotation.pt", map_location="cpu")
|
| 14 |
model.load_state_dict(weights)
|
| 15 |
model.eval()
|
|
|
|
| 5 |
import torchvision.transforms as T
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
+
# Import your model creation function
|
| 9 |
+
from model import create_model # ここで create_model を定義しているファイルを指定してください
|
| 10 |
+
|
| 11 |
# Load parameters and model
|
| 12 |
with open("parameters.json", "r") as f:
|
| 13 |
parameters = json.load(f)
|
| 14 |
|
| 15 |
+
model = create_model(parameters)
|
| 16 |
weights = torch.load("cxp_projection_rotation.pt", map_location="cpu")
|
| 17 |
model.load_state_dict(weights)
|
| 18 |
model.eval()
|