Spaces:
Sleeping
Sleeping
Commit
·
9fcf4ac
1
Parent(s):
a2aa96e
mod app v2
Browse files
app.py
CHANGED
|
@@ -207,16 +207,16 @@ class LanguageModel(nn.Module):
|
|
| 207 |
|
| 208 |
# Load the model
|
| 209 |
model = LanguageModel().to(device)
|
| 210 |
-
model_path = "model_v6_flash_attn.pth"
|
| 211 |
|
| 212 |
# Check if model file exists
|
| 213 |
if os.path.exists(model_path):
|
| 214 |
-
model.load_state_dict(torch.load(
|
|
|
|
| 215 |
model.eval()
|
| 216 |
print("Model loaded successfully")
|
| 217 |
else:
|
| 218 |
-
print(
|
| 219 |
-
model_path}. Please train the model first.")
|
| 220 |
|
| 221 |
# Compile model for better performance
|
| 222 |
model = torch.compile(model)
|
|
|
|
| 207 |
|
| 208 |
# Load the model
|
| 209 |
model = LanguageModel().to(device)
|
| 210 |
+
model_path = "./model_v6_flash_attn.pth"
|
| 211 |
|
| 212 |
# Check if model file exists
|
| 213 |
if os.path.exists(model_path):
|
| 214 |
+
model.load_state_dict(torch.load(
|
| 215 |
+
model_path, map_location=device, weights_only=False))
|
| 216 |
model.eval()
|
| 217 |
print("Model loaded successfully")
|
| 218 |
else:
|
| 219 |
+
print("model file not found")
|
|
|
|
| 220 |
|
| 221 |
# Compile model for better performance
|
| 222 |
model = torch.compile(model)
|