Pabloler21 commited on
Commit
cf33c28
·
1 Parent(s): d8d8506

Fix: handle BatchEncoding from apply_chat_template in transformers 5.x

Browse files
Files changed (1) hide show
  1. engine.py +11 -4
engine.py CHANGED
@@ -26,28 +26,35 @@ if IS_SPACE:
26
 
27
  _MODEL_ID = "Qwen/Qwen3-8B"
28
  _tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
29
- _model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, torch_dtype=torch.bfloat16)
30
 
31
  @spaces.GPU(duration=90)
32
  def run_turn(chat_messages, user_msg, gen_max_tokens=300, extract_max_tokens=120):
33
  _model.to("cuda")
34
 
35
  def _generate(messages, max_tokens, temperature):
36
- ids = _tokenizer.apply_chat_template(
 
37
  messages,
38
  add_generation_prompt=True,
39
  enable_thinking=False,
40
  return_tensors="pt",
41
  ).to("cuda")
 
 
 
 
 
 
42
  with torch.no_grad():
43
  out = _model.generate(
44
- ids,
45
  max_new_tokens=max_tokens,
46
  do_sample=temperature > 0,
47
  temperature=temperature if temperature > 0 else 1.0,
48
  pad_token_id=_tokenizer.eos_token_id,
49
  )
50
- return _tokenizer.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
51
 
52
  reply = _generate(chat_messages, gen_max_tokens, temperature=0.8)
53
  raw_json = _generate(_build_extract_messages(user_msg, reply), extract_max_tokens, temperature=0.0)
 
26
 
27
  _MODEL_ID = "Qwen/Qwen3-8B"
28
  _tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
29
+ _model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, dtype=torch.bfloat16)
30
 
31
  @spaces.GPU(duration=90)
32
  def run_turn(chat_messages, user_msg, gen_max_tokens=300, extract_max_tokens=120):
33
  _model.to("cuda")
34
 
35
  def _generate(messages, max_tokens, temperature):
36
+ # transformers 5.x returns BatchEncoding, not a plain tensor
37
+ tokenized = _tokenizer.apply_chat_template(
38
  messages,
39
  add_generation_prompt=True,
40
  enable_thinking=False,
41
  return_tensors="pt",
42
  ).to("cuda")
43
+ if hasattr(tokenized, "input_ids"):
44
+ input_ids = tokenized["input_ids"]
45
+ generate_kwargs = dict(tokenized)
46
+ else:
47
+ input_ids = tokenized
48
+ generate_kwargs = {"input_ids": tokenized}
49
  with torch.no_grad():
50
  out = _model.generate(
51
+ **generate_kwargs,
52
  max_new_tokens=max_tokens,
53
  do_sample=temperature > 0,
54
  temperature=temperature if temperature > 0 else 1.0,
55
  pad_token_id=_tokenizer.eos_token_id,
56
  )
57
+ return _tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
58
 
59
  reply = _generate(chat_messages, gen_max_tokens, temperature=0.8)
60
  raw_json = _generate(_build_extract_messages(user_msg, reply), extract_max_tokens, temperature=0.0)