nph4rd commited on
Commit
4e98a92
·
1 Parent(s): 83466a6

use GPU for inference

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -467,7 +467,7 @@ def load_model():
467
  model = AutoModelForCausalLM.from_pretrained(
468
  MODEL_ID,
469
  torch_dtype=torch.float16,
470
- device_map="cpu",
471
  )
472
  model.eval()
473
  return model, tokenizer
@@ -488,7 +488,7 @@ def get_ai_action(game: TinyHanabiGame) -> str:
488
 
489
  # Generate
490
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
491
- inputs = tokenizer(text, return_tensors="pt")
492
 
493
  with torch.no_grad():
494
  outputs = model.generate(
 
467
  model = AutoModelForCausalLM.from_pretrained(
468
  MODEL_ID,
469
  torch_dtype=torch.float16,
470
+ device_map="auto",
471
  )
472
  model.eval()
473
  return model, tokenizer
 
488
 
489
  # Generate
490
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
491
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
492
 
493
  with torch.no_grad():
494
  outputs = model.generate(