mo35 committed on
Commit
ee678d5
·
1 Parent(s): 1083e33

Fix apply_chat_template return type

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -39,17 +39,27 @@ print("Model ready.")
39
 
40
  # ── Inference ─────────────────────────────────────────────────────────────────
41
  def respond(message: str, history: list) -> str:
42
- messages = [{"role": "user", "content": message}]
 
 
 
43
 
44
- inputs = tokenizer.apply_chat_template(
 
45
  messages,
46
  add_generation_prompt = True,
47
  return_tensors = "pt",
48
- ).to(model.device)
 
 
 
 
 
49
 
50
  with torch.no_grad():
51
  outputs = model.generate(
52
- inputs,
 
53
  max_new_tokens = 1024,
54
  temperature = 0.7,
55
  do_sample = True,
@@ -57,13 +67,14 @@ def respond(message: str, history: list) -> str:
57
  )
58
 
59
  return tokenizer.decode(
60
- outputs[0][inputs.shape[-1]:],
61
  skip_special_tokens = True,
62
  )
63
 
64
  # ── Gradio UI ─────────────────────────────────────────────────────────────────
65
  demo = gr.ChatInterface(
66
  fn = respond,
 
67
  title = "Gemma 4 — Quantitative Finance",
68
  description = (
69
  "A specialized AI assistant fine-tuned on quantitative finance: derivatives pricing, "
 
39
 
40
  # ── Inference ─────────────────────────────────────────────────────────────────
41
  def respond(message: str, history: list) -> str:
42
+ messages = []
43
+ for msg in history:
44
+ messages.append({"role": msg["role"], "content": msg["content"]})
45
+ messages.append({"role": "user", "content": message})
46
 
47
+ # apply_chat_template returns BatchEncoding in newer transformers
48
+ encoded = tokenizer.apply_chat_template(
49
  messages,
50
  add_generation_prompt = True,
51
  return_tensors = "pt",
52
+ return_dict = True,
53
+ )
54
+ input_ids = encoded["input_ids"].to(model.device)
55
+ attention_mask = encoded.get("attention_mask", None)
56
+ if attention_mask is not None:
57
+ attention_mask = attention_mask.to(model.device)
58
 
59
  with torch.no_grad():
60
  outputs = model.generate(
61
+ input_ids,
62
+ attention_mask = attention_mask,
63
  max_new_tokens = 1024,
64
  temperature = 0.7,
65
  do_sample = True,
 
67
  )
68
 
69
  return tokenizer.decode(
70
+ outputs[0][input_ids.shape[-1]:],
71
  skip_special_tokens = True,
72
  )
73
 
74
  # ── Gradio UI ─────────────────────────────────────────────────────────────────
75
  demo = gr.ChatInterface(
76
  fn = respond,
77
+ type = "messages",
78
  title = "Gemma 4 — Quantitative Finance",
79
  description = (
80
  "A specialized AI assistant fine-tuned on quantitative finance: derivatives pricing, "