Ryan Robson committed · Commit 0661290 · 1 parent: 2f0f602

Fix device mapping for CPU/GPU compatibility


- Remove device_map='auto' to avoid meta tensor issues
- Explicitly handle CPU vs GPU device selection
- Use bfloat16 for CPU, float16 for GPU
- Fix device reference in chat function
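The core of the fix is choosing the device and dtype up front instead of letting accelerate's device_map="auto" dispatch weights, which can leave some of them on the meta device and make a later .to() call fail. A minimal sketch of the selection pattern this commit applies, assuming only that torch is importable:

    import torch

    # Prefer the GPU when one is visible, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # float16 kernels are poorly supported on CPU, so use bfloat16 there.
    dtype = torch.bfloat16 if device == "cpu" else torch.float16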

Files changed (1):
  1. app.py (+9, -3)
app.py CHANGED
@@ -12,15 +12,21 @@ ADAPTER_MODEL = "robworks-software/ccisd-teks-educator-mistral7b"
 
 print(f"📥 Loading base model: {BASE_MODEL}...")
 tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+# Check if GPU is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f" Using device: {device}")
+
+# Load base model (use bfloat16 for CPU compatibility)
 model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
-    dtype=torch.float16,
-    device_map="auto",
+    torch_dtype=torch.bfloat16 if device == "cpu" else torch.float16,
     low_cpu_mem_usage=True
 )
 
 print(f"🔧 Loading LoRA adapter: {ADAPTER_MODEL}...")
 model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
+model = model.to(device)
 
 print("✅ Model loaded successfully!")
 
@@ -47,7 +53,7 @@ def chat(message, history):
     prompt += f"[INST] {system_message}\n\n{message} [/INST]"
 
     # Tokenize
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
     # Generate response
     with torch.no_grad():
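The second hunk ends at the torch.no_grad() context line. For completeness, a hedged sketch of how generation typically proceeds from the relocated inputs; the generate() keyword arguments shown here are assumptions, not values taken from app.py:

    with torch.no_grad():
        # Hypothetical sampling settings; app.py's real kwargs lie outside this diff.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return response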