Spaces:
Sleeping
Sleeping
Mustafa Tag Eldeen committed on
Commit ·
80dbe0c
1
Parent(s): 075304e
swap: Qwen2.5-0.5B-Instruct (faster, reliable #### format) + fix dp2 detection bug
Browse files- app.py +23 -18
- learning-records/0003-qwen-swap.md +21 -0
- src/inference/complete_pipeline.py +2 -1
app.py
CHANGED
|
@@ -26,10 +26,10 @@ from src.features.windows import (
|
|
| 26 |
compute_step_boundaries,
|
| 27 |
)
|
| 28 |
from src.features.span_detection import extract_answer_after_hash
|
| 29 |
-
from src.utils import
|
| 30 |
|
| 31 |
-
MODEL_ID = os.environ.get("MODEL_ID", "
|
| 32 |
-
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "
|
| 33 |
|
| 34 |
model = None
|
| 35 |
tokenizer = None
|
|
@@ -70,20 +70,12 @@ def load_model():
|
|
| 70 |
tokenizer.pad_token = tokenizer.eos_token
|
| 71 |
tokenizer.padding_side = "left"
|
| 72 |
|
| 73 |
-
#
|
| 74 |
-
#
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
step_token_ids.add(tid)
|
| 80 |
-
except (IndexError, ValueError):
|
| 81 |
-
pass
|
| 82 |
-
# Also check all tokens in the vocab that decode to "Step"
|
| 83 |
-
for tid in range(tokenizer.vocab_size):
|
| 84 |
-
if tokenizer.decode([tid]).strip() == "Step":
|
| 85 |
-
step_token_ids.add(tid)
|
| 86 |
-
step_token_id = min(step_token_ids)
|
| 87 |
|
| 88 |
model = AutoModelForCausalLM.from_pretrained(
|
| 89 |
MODEL_ID,
|
|
@@ -104,7 +96,20 @@ def generate_reasoning(question: str, gold_answer: str, max_new_tokens: int):
|
|
| 104 |
if not loaded:
|
| 105 |
return None, "Failed to load model."
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
inputs = tokenizer(prompt, return_tensors="pt", padding=False)
|
| 109 |
input_ids = inputs["input_ids"]
|
| 110 |
attention_mask = inputs["attention_mask"]
|
|
|
|
| 26 |
compute_step_boundaries,
|
| 27 |
)
|
| 28 |
from src.features.span_detection import extract_answer_after_hash
|
| 29 |
+
from src.utils import answers_match
|
| 30 |
|
| 31 |
+
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 32 |
+
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "128"))
|
| 33 |
|
| 34 |
model = None
|
| 35 |
tokenizer = None
|
|
|
|
| 70 |
tokenizer.pad_token = tokenizer.eos_token
|
| 71 |
tokenizer.padding_side = "left"
|
| 72 |
|
| 73 |
+
# Qwen's tokenizer has a stable "Step" token ID (8304) across contexts
|
| 74 |
+
# Use the encoded token that results from "Step" at the start of generation
|
| 75 |
+
suffix = "\n<|im_start|>assistant\n"
|
| 76 |
+
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
|
| 77 |
+
step_token_id = tokenizer.encode(suffix + "Step", add_special_tokens=False)[len(suffix_ids)]
|
| 78 |
+
step_token_ids = {step_token_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
model = AutoModelForCausalLM.from_pretrained(
|
| 81 |
MODEL_ID,
|
|
|
|
| 96 |
if not loaded:
|
| 97 |
return None, "Failed to load model."
|
| 98 |
|
| 99 |
+
# Build chat messages with the paper's "cot" template content
|
| 100 |
+
system_msg = (
|
| 101 |
+
'You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ". '
|
| 102 |
+
'Always provide your final answer after #### at the end.'
|
| 103 |
+
)
|
| 104 |
+
user_msg = (
|
| 105 |
+
f'Please solve this step by step, putting each step after "Step [step_number]: " '
|
| 106 |
+
f'and always provide your final answer after ####.\n\nQuestion: {question}'
|
| 107 |
+
)
|
| 108 |
+
messages = [
|
| 109 |
+
{"role": "system", "content": system_msg},
|
| 110 |
+
{"role": "user", "content": user_msg},
|
| 111 |
+
]
|
| 112 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 113 |
inputs = tokenizer(prompt, return_tensors="pt", padding=False)
|
| 114 |
input_ids = inputs["input_ids"]
|
| 115 |
attention_mask = inputs["attention_mask"]
|
learning-records/0003-qwen-swap.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 0003: Swapped to Qwen2.5-0.5B-Instruct
|
| 2 |
+
|
| 3 |
+
## Why
|
| 4 |
+
TinyLlama-1.1B-Chat couldn't follow the `####` format when given the base-model-style prompt (`"Solution:\n\n"`), so answer extraction and dp2 detection failed. Qwen2.5-0.5B-Instruct offers:
|
| 5 |
+
- **494M params** vs 1.1B → ~2x faster on CPU (est. 5-10 tok/s)
|
| 6 |
+
- **Excellent instruction following** → actually outputs `####`
|
| 7 |
+
- **Stable token IDs** — "Step" is always 8304 (no context-dependence)
|
| 8 |
+
|
| 9 |
+
## Changes
|
| 10 |
+
|
| 11 |
+
### `app.py`
|
| 12 |
+
- `MODEL_ID` → `Qwen/Qwen2.5-0.5B-Instruct`
|
| 13 |
+
- `MAX_NEW_TOKENS` → 128 (smaller model needs fewer tokens)
|
| 14 |
+
- Replaced `format_prompt(question, "cot")` with `tokenizer.apply_chat_template()` — paper's prompt content preserved verbatim, just wrapped in `<|im_start|>` format
|
| 15 |
+
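- Simplified `step_token_id` detection — instead of scanning the whole vocabulary for tokens that decode to "Step", it encodes "Step" immediately after the generation-prompt suffix `\n<|im_start|>assistant\n` and takes the first token that follows the suffix (single stable ID)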
|
| 16 |
+
|
| 17 |
+
### `complete_pipeline.py`
|
| 18 |
+
- Fixed dp2 detection: `detect_dp2_index` now runs only when an answer was extracted and `####` is actually present in the generated text — this prevents fallback numbers (e.g. "2" from "$2") from producing a wrong `dp2_idx`
|
| 19 |
+
|
| 20 |
+
## Key insight
|
| 21 |
+
Instruct/chat models need their native chat template format, even when the prompt content comes from a paper written for base models. Wrapping the prompt with `apply_chat_template()` keeps the paper's wording intact while presenting it in the format the model was trained to follow.
|
src/inference/complete_pipeline.py
CHANGED
|
@@ -134,7 +134,8 @@ def process_complete_generation(
|
|
| 134 |
output.produced_answer = produced_answer
|
| 135 |
|
| 136 |
# Detect dp2 (start of extracted answer in token sequence)
|
| 137 |
-
|
|
|
|
| 138 |
output.dp2_idx = detect_dp2_index(
|
| 139 |
output.full_seq_ids,
|
| 140 |
tokenizer,
|
|
|
|
| 134 |
output.produced_answer = produced_answer
|
| 135 |
|
| 136 |
# Detect dp2 (start of extracted answer in token sequence)
|
| 137 |
+
# Only when #### is present — otherwise fallback numbers mislead dp2 detection
|
| 138 |
+
if produced_answer and "####" in (output.produced_text or ""):
|
| 139 |
output.dp2_idx = detect_dp2_index(
|
| 140 |
output.full_seq_ids,
|
| 141 |
tokenizer,
|