Jn-Huang committed
Commit 1a77428 · 1 Parent(s): f6fde6f
Fix bugs: use token param, apply Llama 3.1 chat template, decode only new tokens
app.py
CHANGED
@@ -37,8 +37,8 @@ def load_model_and_tokenizer():
 
     if USE_PEFT:
         try:
-            _ = PeftConfig.from_pretrained(PEFT_MODEL_ID,
-            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID,
+            _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
+            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
             print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
             return model, tok
         except Exception as e:
@@ -51,9 +51,17 @@ DEVICE = model.device
 
 @spaces.GPU
 @torch.inference_mode()
-def generate_response(
+def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
+    # Apply Llama 3.1 chat template
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
     enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
     enc = {k: v.to(DEVICE) for k, v in enc.items()}
+
+    input_length = enc['input_ids'].shape[1]
     out = model.generate(
         **enc,
         max_new_tokens=max_new_tokens,
@@ -62,30 +70,30 @@ def generate_response(prompt: str, max_new_tokens=512, temperature=0.7, top_p=0.
         top_p=top_p,
         pad_token_id=tokenizer.eos_token_id,
     )
-
+    # Decode only the newly generated tokens
+    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
 
 def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
-    # Build
-
+    # Build conversation in Llama 3.1 chat format
+    messages = []
     if system_prompt:
-
-
-
-
-
-
+        messages.append({"role": "system", "content": system_prompt})
+
+    for user_msg, assistant_msg in (history or []):
+        if user_msg:
+            messages.append({"role": "user", "content": user_msg})
+        if assistant_msg:
+            messages.append({"role": "assistant", "content": assistant_msg})
+
     if message:
-
-
+        messages.append({"role": "user", "content": message})
+
     reply = generate_response(
-
+        messages,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         top_p=top_p,
     )
-    # Strip trailing
-    if "assistant:" in reply:
-        reply = reply.split("assistant:")[-1].strip()
     return reply
 
 demo = gr.ChatInterface(
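For context on the first fix, the adapter-loading calls now pass token=HF_TOKEN so a gated or private adapter repo can be fetched. Below is a minimal standalone sketch of that loading path, not the Space's actual app.py: the base model id, the adapter repo id, and reading the token from an environment variable are assumptions for illustration.

# Minimal sketch (assumed names): load a base model and attach a PEFT adapter,
# passing `token=` so gated/private Hub repos can be downloaded.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

HF_TOKEN = os.environ.get("HF_TOKEN")                 # assumption: token stored as a Space secret
BASE_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"    # hypothetical base model id
PEFT_MODEL_ID = "your-org/your-lora-adapter"          # hypothetical adapter repo id

tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)
_ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)   # fails early if the adapter repo is unreachable
model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)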
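The other two fixes live in generate_response: the raw prompt string is replaced by tokenizer.apply_chat_template(...) so the model sees the Llama 3.1 chat format, and only the tokens after the prompt are decoded so the reply no longer echoes the conversation. A minimal sketch of that pattern follows, assuming a model/tokenizer pair loaded as above; the do_sample=True flag is an assumption, since the diff only shows temperature and top_p being passed.

# Minimal sketch of the generation path: chat template in, only new tokens out.
import torch

@torch.inference_mode()
def generate_reply(model, tokenizer, messages, max_new_tokens=512, temperature=0.7, top_p=0.9):
    # messages: [{"role": "system" | "user" | "assistant", "content": "..."}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = enc["input_ids"].shape[1]          # number of prompt tokens
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,                               # assumption: sampling enabled
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt so the decoded text is only the model's new reply.
    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)

# Example call, e.g. from a Gradio chat handler:
# reply = generate_reply(model, tok, [{"role": "user", "content": "Hello!"}])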