Commit ·
6dbc3a8
1
Parent(s): 86e2109
add BOS token cutoff
Browse files
app.py
CHANGED
|
@@ -17,7 +17,16 @@ def generate_response(prompt):
|
|
| 17 |
outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
|
| 18 |
input_length = inputs['input_ids'].shape[1]
|
| 19 |
new_token_ids = outputs[0][input_length:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
|
|
|
|
| 21 |
return new_tokens
|
| 22 |
|
| 23 |
iface = gr.Interface(
|
|
|
|
| 17 |
outputs = model.generate(**inputs, max_new_tokens=5, temperature=1.0)
|
| 18 |
input_length = inputs['input_ids'].shape[1]
|
| 19 |
new_token_ids = outputs[0][input_length:]
|
| 20 |
+
bos_token_id = tokenizer.bos_token_id
|
| 21 |
+
if bos_token_id is not None:
|
| 22 |
+
bos_positions = (new_token_ids == bos_token_id).nonzero(as_tuple=True)[0]
|
| 23 |
+
if len(bos_positions) > 0:
|
| 24 |
+
# Truncate at first BOS token
|
| 25 |
+
first_bos_pos = bos_positions[0].item()
|
| 26 |
+
new_token_ids = new_token_ids[:first_bos_pos]
|
| 27 |
+
|
| 28 |
new_tokens = tokenizer.decode(new_token_ids, skip_special_tokens=False)
|
| 29 |
+
|
| 30 |
return new_tokens
|
| 31 |
|
| 32 |
iface = gr.Interface(
|