Spaces:
Sleeping
Sleeping
scott-ashton-tds commited on
Commit ·
e5f9099
1
Parent(s): 5bb488e
Fix model loading with proper torch_dtype and device_map
Browse files
app.py
CHANGED
|
@@ -26,8 +26,13 @@ MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))
|
|
| 26 |
|
| 27 |
print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
|
| 28 |
|
| 29 |
-
starvector = StarVectorForCausalLM.from_pretrained(
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
starvector.eval()
|
| 32 |
|
| 33 |
|
|
|
|
| 26 |
|
| 27 |
print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
|
| 28 |
|
| 29 |
+
starvector = StarVectorForCausalLM.from_pretrained(
|
| 30 |
+
MODEL_NAME,
|
| 31 |
+
torch_dtype=DTYPE,
|
| 32 |
+
device_map="auto" if DEVICE == "cuda" else None
|
| 33 |
+
)
|
| 34 |
+
if DEVICE != "cuda":
|
| 35 |
+
starvector = starvector.to(device=DEVICE)
|
| 36 |
starvector.eval()
|
| 37 |
|
| 38 |
|