Jn-Huang committed on
Commit
1a77428
·
1 Parent(s): f6fde6f

Fix bugs: use token param, apply Llama 3.1 chat template, decode only new tokens

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -37,8 +37,8 @@ def load_model_and_tokenizer():
37
 
38
  if USE_PEFT:
39
  try:
40
- _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, use_auth_token=HF_TOKEN)
41
- model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, use_auth_token=HF_TOKEN)
42
  print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
43
  return model, tok
44
  except Exception as e:
@@ -51,9 +51,17 @@ DEVICE = model.device
51
 
52
  @spaces.GPU
53
  @torch.inference_mode()
54
- def generate_response(prompt: str, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
 
 
 
 
 
 
55
  enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
56
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
 
 
57
  out = model.generate(
58
  **enc,
59
  max_new_tokens=max_new_tokens,
@@ -62,30 +70,30 @@ def generate_response(prompt: str, max_new_tokens=512, temperature=0.7, top_p=0.
62
  top_p=top_p,
63
  pad_token_id=tokenizer.eos_token_id,
64
  )
65
- return tokenizer.decode(out[0], skip_special_tokens=True)
 
66
 
67
  def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
68
- # Build a simple conversation string
69
- conv = []
70
  if system_prompt:
71
- conv.append(f"system: {system_prompt}")
72
- for u, a in (history or []):
73
- if u:
74
- conv.append(f"user: {u}")
75
- if a:
76
- conv.append(f"assistant: {a}")
 
 
77
  if message:
78
- conv.append(f"user: {message}")
79
- prompt = "\n".join(conv) + "\nassistant:"
80
  reply = generate_response(
81
- prompt,
82
  max_new_tokens=max_new_tokens,
83
  temperature=temperature,
84
  top_p=top_p,
85
  )
86
- # Strip trailing
87
- if "assistant:" in reply:
88
- reply = reply.split("assistant:")[-1].strip()
89
  return reply
90
 
91
  demo = gr.ChatInterface(
 
37
 
38
  if USE_PEFT:
39
  try:
40
+ _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
41
+ model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
42
  print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
43
  return model, tok
44
  except Exception as e:
 
51
 
52
  @spaces.GPU
53
  @torch.inference_mode()
54
+ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
55
+ # Apply Llama 3.1 chat template
56
+ prompt = tokenizer.apply_chat_template(
57
+ messages,
58
+ tokenize=False,
59
+ add_generation_prompt=True
60
+ )
61
  enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
62
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
63
+
64
+ input_length = enc['input_ids'].shape[1]
65
  out = model.generate(
66
  **enc,
67
  max_new_tokens=max_new_tokens,
 
70
  top_p=top_p,
71
  pad_token_id=tokenizer.eos_token_id,
72
  )
73
+ # Decode only the newly generated tokens
74
+ return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
75
 
76
  def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
77
+ # Build conversation in Llama 3.1 chat format
78
+ messages = []
79
  if system_prompt:
80
+ messages.append({"role": "system", "content": system_prompt})
81
+
82
+ for user_msg, assistant_msg in (history or []):
83
+ if user_msg:
84
+ messages.append({"role": "user", "content": user_msg})
85
+ if assistant_msg:
86
+ messages.append({"role": "assistant", "content": assistant_msg})
87
+
88
  if message:
89
+ messages.append({"role": "user", "content": message})
90
+
91
  reply = generate_response(
92
+ messages,
93
  max_new_tokens=max_new_tokens,
94
  temperature=temperature,
95
  top_p=top_p,
96
  )
 
 
 
97
  return reply
98
 
99
  demo = gr.ChatInterface(