dipta007 committed on
Commit
df8336a
·
verified ·
1 Parent(s): 257b03c

Update README

Browse files
Files changed (1) hide show
  1. README.md +45 -16
README.md CHANGED
@@ -85,6 +85,8 @@ GRPO is supervised with a sum of seven rewards, grouped into three families:
85
 
86
  ## Quickstart
87
 
 
 
88
  ```python
89
  from transformers import AutoModelForCausalLM, AutoTokenizer
90
 
@@ -97,16 +99,8 @@ model = AutoModelForCausalLM.from_pretrained(
97
  device_map="auto",
98
  )
99
 
100
- evidence_doc = (
101
- "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, "
102
- "France. It is named after the engineer Gustave Eiffel, whose company designed and "
103
- "built the tower from 1887 to 1889. Locally nicknamed 'La dame de fer', it was "
104
- "constructed as the centerpiece of the 1889 World's Fair. The tower is 330 metres "
105
- "(1,083 ft) tall."
106
- )
107
- claim = "The Eiffel Tower was completed in 1887 and stands 330 metres tall."
108
 
109
- user_prompt = f"""You are tasked with systematically verifying the accuracy of a claim. You will be provided with a claim to verify and an evidence document to consult.
110
 
111
  Here is the evidence document you should consult:
112
 
@@ -130,13 +124,37 @@ Stop immediately after the closing </verification> tag.
130
 
131
  Begin your verification process now."""
132
 
133
- messages = [{"role": "user", "content": user_prompt}]
134
- text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
135
- inputs = tokenizer([text], return_tensors="pt").to(model.device)
136
 
137
- # max_new_tokens matches training-time max_completion_length
138
- out = model.generate(**inputs, max_new_tokens=4500, temperature=0.7, do_sample=True)
139
- response = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  print(response)
141
  ```
142
 
@@ -154,9 +172,20 @@ def parse_trace(text: str):
154
  return [(tag, body.strip()) for tag, body in TAG_RE.findall(text)]
155
 
156
  def pretty_print(text: str) -> None:
 
 
 
 
 
 
 
 
 
 
 
157
  cycle_idx = 0
158
  pending_q = None
159
- for tag, body in parse_trace(text):
160
  if tag == "think":
161
  print("─" * 78)
162
  print("🧠 THINK")
 
85
 
86
  ## Quickstart
87
 
88
+ DecomposeRL expects your `claim` and `evidence_doc` to be wrapped in a specific verification prompt. The `build_prompt` helper below does this wrapping for you, so you don't have to construct the full instruction block every time.
89
+
90
  ```python
91
  from transformers import AutoModelForCausalLM, AutoTokenizer
92
 
 
99
  device_map="auto",
100
  )
101
 
 
 
 
 
 
 
 
 
102
 
103
+ PROMPT_TEMPLATE = """You are tasked with systematically verifying the accuracy of a claim. You will be provided with a claim to verify and an evidence document to consult.
104
 
105
  Here is the evidence document you should consult:
106
 
 
124
 
125
  Begin your verification process now."""
126
 
 
 
 
127
 
128
+ def build_prompt(claim: str, evidence_doc: str) -> str:
129
+ """Wrap a claim + evidence document in the DecomposeRL verification prompt."""
130
+ return PROMPT_TEMPLATE.format(claim=claim, evidence_doc=evidence_doc)
131
+
132
+
133
+ def verify(claim: str, evidence_doc: str, max_new_tokens: int = 4500, temperature: float = 0.7) -> str:
134
+ """Run the model end-to-end on a (claim, evidence_doc) pair and return the raw trace."""
135
+ messages = [{"role": "user", "content": build_prompt(claim, evidence_doc)}]
136
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
137
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
138
+ out = model.generate(
139
+ **inputs,
140
+ max_new_tokens=max_new_tokens, # matches training-time max_completion_length
141
+ temperature=temperature,
142
+ do_sample=True,
143
+ )
144
+ return tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
145
+
146
+
147
+ # Usage
148
+ evidence_doc = (
149
+ "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, "
150
+ "France. It is named after the engineer Gustave Eiffel, whose company designed and "
151
+ "built the tower from 1887 to 1889. Locally nicknamed 'La dame de fer', it was "
152
+ "constructed as the centerpiece of the 1889 World's Fair. The tower is 330 metres "
153
+ "(1,083 ft) tall."
154
+ )
155
+ claim = "The Eiffel Tower was completed in 1887 and stands 330 metres tall."
156
+
157
+ response = verify(claim, evidence_doc)
158
  print(response)
159
  ```
160
 
 
172
  return [(tag, body.strip()) for tag, body in TAG_RE.findall(text)]
173
 
174
  def pretty_print(text: str) -> None:
175
+ parsed = parse_trace(text)
176
+ tags = {tag for tag, _ in parsed}
177
+ if not parsed or "verification" not in tags:
178
+ print("⚠️ Could not parse output into the expected "
179
+ "think/question/answer/verification structure.")
180
+ print("Raw output:")
181
+ print("─" * 78)
182
+ print(text)
183
+ print("─" * 78)
184
+ return
185
+
186
  cycle_idx = 0
187
  pending_q = None
188
+ for tag, body in parsed:
189
  if tag == "think":
190
  print("─" * 78)
191
  print("🧠 THINK")