Upload 2 files
Browse files- app.py +2 -2
- model_utils.py +10 -3
app.py
CHANGED
|
@@ -59,6 +59,6 @@ demo = gr.Interface(
|
|
| 59 |
)
|
| 60 |
|
| 61 |
if __name__ == "__main__":
|
| 62 |
-
demo.launch(share=
|
| 63 |
else:
|
| 64 |
-
app = demo.launch(share=
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
if __name__ == "__main__":
|
| 62 |
+
demo.launch(share=False)
|
| 63 |
else:
|
| 64 |
+
app = demo.launch(share=False)
|
model_utils.py
CHANGED
|
@@ -111,9 +111,16 @@ class GPT(nn.Module):
|
|
| 111 |
def load_model(model_path):
|
| 112 |
"""Load the trained model"""
|
| 113 |
try:
|
| 114 |
-
torch.
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
model = GPT(config)
|
| 118 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 119 |
model.eval()
|
|
|
|
| 111 |
def load_model(model_path):
|
| 112 |
"""Load the trained model"""
|
| 113 |
try:
|
| 114 |
+
checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
| 115 |
+
|
| 116 |
+
# Create config from the saved dictionary
|
| 117 |
+
config_dict = checkpoint['config']
|
| 118 |
+
if isinstance(config_dict, str):
|
| 119 |
+
# If config was saved as string, parse it to dict
|
| 120 |
+
import ast
|
| 121 |
+
config_dict = ast.literal_eval(config_dict)
|
| 122 |
+
config = GPTConfig(**config_dict)
|
| 123 |
+
|
| 124 |
model = GPT(config)
|
| 125 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 126 |
model.eval()
|