Upload 2 files
Browse files
- app.py +2 -2
- model_utils.py +10 -6
app.py
CHANGED
|
@@ -56,6 +56,6 @@ demo = gr.Interface(
|
|
| 56 |
)
|
| 57 |
|
| 58 |
if __name__ == "__main__":
|
| 59 |
-
demo.launch(
|
| 60 |
else:
|
| 61 |
-
app = demo.launch(
|
|
|
|
| 56 |
)
|
| 57 |
|
| 58 |
# Entry point: when this module is imported by a hosting runtime
# (rather than executed directly), expose the launch result as
# `app` so the host can pick it up; when run as a script, just
# start the Gradio server.
if __name__ != "__main__":
    app = demo.launch()
else:
    demo.launch()
|
model_utils.py
CHANGED
|
@@ -110,12 +110,16 @@ class GPT(nn.Module):
|
|
| 110 |
|
| 111 |
def load_model(model_path):
|
| 112 |
"""Load the trained model"""
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def generate_text(model, prompt, max_new_tokens=50, temperature=0.8, top_k=40):
|
| 121 |
"""Generate text based on a prompt
|
|
|
|
| 110 |
|
| 111 |
def load_model(model_path):
    """Load a trained GPT checkpoint from disk.

    Parameters
    ----------
    model_path : str or Path
        Path to a torch checkpoint containing 'config' (kwargs for
        GPTConfig) and 'model_state_dict' keys.

    Returns
    -------
    GPT or None
        The model in eval mode, or None if loading fails for any
        expected reason (missing file, malformed checkpoint, or a
        state-dict/architecture mismatch).
    """
    try:
        # weights_only=True restricts unpickling to tensors/primitives,
        # avoiding arbitrary-code execution from untrusted checkpoints.
        checkpoint = torch.load(
            model_path,
            map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            weights_only=True,
        )
        config = GPTConfig(**checkpoint['config'])
        model = GPT(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model
    # Original code caught only AttributeError, letting the common
    # failures escape: OSError (file missing/unreadable), KeyError
    # (checkpoint missing 'config'/'model_state_dict'), RuntimeError
    # (torch load/state-dict mismatch), TypeError (bad config kwargs).
    except (AttributeError, OSError, KeyError, RuntimeError, TypeError) as e:
        print(f"Error loading model: {e}")
        return None
|
| 123 |
|
| 124 |
def generate_text(model, prompt, max_new_tokens=50, temperature=0.8, top_k=40):
|
| 125 |
"""Generate text based on a prompt
|