scott-ashton-tds commited on
Commit
e5f9099
·
1 Parent(s): 5bb488e

Fix model loading with proper torch_dtype and device_map

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -26,8 +26,13 @@ MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))
26
 
27
  print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
28
 
29
- starvector = StarVectorForCausalLM.from_pretrained(MODEL_NAME)
30
- starvector = starvector.to(device=DEVICE)
 
 
 
 
 
31
  starvector.eval()
32
 
33
 
 
26
 
27
  print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
28
 
29
+ starvector = StarVectorForCausalLM.from_pretrained(
30
+ MODEL_NAME,
31
+ torch_dtype=DTYPE,
32
+ device_map="auto" if DEVICE == "cuda" else None
33
+ )
34
+ if DEVICE != "cuda":
35
+ starvector = starvector.to(device=DEVICE)
36
  starvector.eval()
37
 
38