ayjays132 committed on
Commit 828f04e · verified · 1 Parent(s): 1c116cd

Upload 2 files

Files changed (2)
  1. chain_of_thought_gui.py +748 -99
  2. chain_of_thought_wrapper.py +570 -93
chain_of_thought_gui.py CHANGED
@@ -1,119 +1,768 @@
1
  #!/usr/bin/env python3
2
  """
3
- NeuroReasoner 1 Chain-of-Thought GUI
4
  -------------------------------------------------------------
5
- A futuristic, user-friendly Streamlit app for step-by-step reasoning
6
- using any Hugging Face causal LM.
7
-
8
- Features:
9
- • Load any model by repo name or local path
10
- • Full control of generation params (Temp, top‑k/p, etc.)
11
- • Self‑Consistency sampling
12
- • ASCII telemetry panels
13
- • Progress indicators and collapsible reasoning details
14
  """
15
  import os
16
  import time
17
- import torch
18
- import pynvml
19
  import streamlit as st
20
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
21
- from chain_of_thought_wrapper import ChainOfThoughtWrapper
22
 
23
- # Initialize GPU telemetry
24
  try:
25
  pynvml.nvmlInit()
26
  GPU_AVAILABLE = True
27
  except Exception:
28
  GPU_AVAILABLE = False
29
 
30
- @st.cache_data(show_spinner=False)
31
- def get_telemetry():
32
  if not GPU_AVAILABLE or not torch.cuda.is_available():
33
- return "[No GPU telemetry]"
34
- handle = pynvml.nvmlDeviceGetHandleByIndex(0)
35
- util = pynvml.nvmlDeviceGetUtilizationRates(handle)
36
- mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
37
- return f"GPU: {util.gpu}% | Mem: {mem.used//1024**2}/{mem.total//1024**2} MB"
38
-
39
- # Sidebar configuration
40
- st.sidebar.title("⚙️ Configuration")
41
- model_name = st.sidebar.text_input(
42
- "Model (HuggingFace repo or local path)", value="ayjays132/NeuroReasoner-1-NR-1"
43
- )
44
- device = st.sidebar.selectbox("Device", options=["cuda" if torch.cuda.is_available() else "cpu", "cpu"] )
45
- num_sequences = st.sidebar.slider("# Chains", min_value=1, max_value=10, value=3)
46
- self_consistency = st.sidebar.checkbox("Self‑Consistency", value=False)
47
- max_new_tokens = st.sidebar.slider("Max New Tokens", 50, 1024, 256)
48
- temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
49
- top_k = st.sidebar.slider("Top-k", 0, 200, 50)
50
- top_p = st.sidebar.slider("Top-p", 0.0, 1.0, 0.9)
51
- no_repeat_ngram = st.sidebar.slider("No‑repeat ngram", 0, 10, 3)
52
-
53
- # Main interface
54
- st.markdown("# 🌀 NeuroReasoner CoT GUI")
55
- col1, col2 = st.columns([3,1])
56
- with col1:
57
- prompt = st.text_area("🚀 Enter your prompt", value="Explain why the sky is blue.", height=120)
58
- with col2:
59
- st.metric("Telemetry", get_telemetry())
60
-
61
- if st.button("🪄 Generate Reasoning", type="primary"):
62
- if not prompt.strip():
63
- st.error("Please enter a prompt.")
64
  st.stop()
65
- # Load model & tokenizer
 
66
  try:
67
- with st.spinner("🌐 Loading model and tokenizer..."):
68
- tokenizer = AutoTokenizer.from_pretrained(model_name)
69
- model = AutoModelForCausalLM.from_pretrained(model_name)
70
- model.to(device)
71
- st.success("✅ Model loaded.")
72
  except Exception as e:
73
- st.error(f"❌ Load error: {e}")
 
74
  st.stop()
75
- # Setup CoT
76
- cfg = GenerationConfig(
77
- max_new_tokens=max_new_tokens,
78
- temperature=temperature,
79
- top_k=top_k,
80
- top_p=top_p,
81
- do_sample=True,
82
- num_return_sequences=(num_sequences if self_consistency else 1),
83
- no_repeat_ngram_size=no_repeat_ngram,
84
- eos_token_id=tokenizer.eos_token_id,
85
- pad_token_id=tokenizer.pad_token_id
86
- )
87
- cot = ChainOfThoughtWrapper(
88
- model=model,
89
- tokenizer=tokenizer,
90
- generation_config=cfg,
91
- device=device,
92
- self_consistency=self_consistency,
93
- consistency_rounds=(num_sequences if self_consistency else 1)
94
- )
95
- # Tokenize & generate
96
- inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
97
- start = time.time()
98
- output = cot.generate(
99
- input_ids=inputs['input_ids'],
100
- attention_mask=inputs['attention_mask'],
101
- num_return_sequences=(num_sequences if self_consistency else 1)
102
- )
103
- elapsed = time.time() - start
104
- st.success(f"✨ Done in {elapsed:.2f}s")
105
- # Display results
106
- for idx, (full, steps, ans) in enumerate(zip(output['full_texts'], output['reasoning_steps'], output['final_answers']), 1):
107
- with st.expander(f"Chain {idx}"):
108
- st.text_area("Full Text", value=full, height=200)
109
- if steps:
110
- st.write("**Steps:**")
111
- for i, s in enumerate(steps, 1): st.write(f"{i}. {s}")
112
- else:
113
- st.warning("No parsed steps.")
114
- st.markdown(f"**Final Answer:** {ans}")
115
- st.markdown("---")
116
- st.write(f"Telemetry: {get_telemetry()}")
117
 
118
- # Footer
119
- st.markdown("<sub>Built for a futuristic, seamless reasoning experience.</sub>", unsafe_allow_html=True)
1
  #!/usr/bin/env python3
2
  """
3
+ NeuroReasoner Chain-of-Thought GUI (Dark Theme Enhanced)
4
  -------------------------------------------------------------
5
+ A premium Streamlit app for step-by-step reasoning
6
+ across any Hugging Face model (causal or seq2seq).
7
+ Featuring a dark theme, model-type detection, self-consistency
8
+ sampling, and robust handling.
9
  """
10
  import os
11
  import time
12
  import streamlit as st
13
+ import torch
14
+ import pynvml # For GPU telemetry
15
+ import numpy as np
16
+ from transformers import (
17
+ AutoConfig,
18
+ AutoTokenizer,
19
+ AutoModelForCausalLM,
20
+ AutoModelForSeq2SeqLM,
21
+ GenerationConfig,
22
+ PretrainedConfig
23
+ )
24
+ from collections import Counter # For self-consistency voting
25
+ import gc # Import garbage collector
26
 
27
+ # Assuming chain_of_thought_wrapper.py is in the same directory
28
+ # and is designed to work with standard Hugging Face models and GenerationConfig.
29
+ # Make sure the wrapper correctly handles num_return_sequences for CoT and SC,
30
+ # and returns the expected dictionary structure:
31
+ # {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'consensus_answer': '...'}
32
+ try:
33
+ from chain_of_thought_wrapper import ChainOfThoughtWrapper
34
+ except ImportError:
35
+ st.error("Error: chain_of_thought_wrapper.py not found. Please ensure it's in the same directory.")
36
+ st.stop()
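# For reference, a sketch of the output shape this GUI assumes from the wrapper,
# mirroring the comment above (illustrative only, not an authoritative spec):
#   {
#       "full_texts": ["...full decoded text of chain 1...", "..."],
#       "reasoning_steps": [["Step 1 ...", "Step 2 ..."], [...]],
#       "final_answers": ["42", "42"],
#       "consensus_answer": "42",
#   }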
37
+
38
+
39
+ # --- Page Configuration ---
40
+ st.set_page_config(
41
+ page_title="🧠 NeuroReasoner CoT GUI",
42
+ page_icon="🧠",
43
+ layout="wide",
44
+ initial_sidebar_state="expanded",
45
+ menu_items={
46
+ 'Get Help': 'https://github.com/your_repo_link_here', # Replace or remove
47
+ 'Report a bug': "https://github.com/your_repo_link_here/issues", # Replace or remove
48
+ 'About': """
49
+ **NeuroReasoner Chain-of-Thought GUI**
50
+ An open-source interface powered by Hugging Face models and the NeuroReasoner wrapper.
51
+ Explore step-by-step reasoning with various language models.
52
+ """
53
+ }
54
+ )
55
+
56
+ # --- Dark Theme CSS ---
57
+ st.markdown("""
58
+ <style>
59
+ /* Overall Page Background & Text (Dark Theme) */
60
+ body {
61
+ background-color: #1E1E1E; /* Dark grey background */
62
+ color: #D4D4D4; /* Light grey text */
63
+ font-family: 'Segoe UI', Roboto, Arial, sans-serif;
64
+ }
65
+ .stApp {
66
+ background-color: #1E1E1E;
67
+ color: #D4D4D4;
68
+ }
69
+
70
+ /* Sidebar Styling */
71
+ .stSidebar {
72
+ background-color: #2D2D2D; /* Slightly lighter dark grey for sidebar */
73
+ padding: 2rem 1rem;
74
+ border-right: 1px solid #3E3E3E; /* Subtle border */
75
+ }
76
+ .stSidebar h1, .stSidebar h2, .stSidebar h3 {
77
+ color: #569CD6; /* Visual Studio Code blue for sidebar headers */
78
+ }
79
+ .stSidebar label {
80
+ color: #D4D4D4 !important; /* Ensure sidebar labels are visible */
81
+ }
82
+
83
+
84
+ /* Main Content Area */
85
+ .stContainer {
86
+ padding: 2rem;
87
+ }
88
+
89
+ /* Titles and Headers */
90
+ h1, h2, h3, h4, h5, h6 {
91
+ color: #569CD6; /* VS Code blue headings */
92
+ margin-top: 1rem;
93
+ margin-bottom: 0.8rem;
94
+ }
95
+ h1 { font-size: 2.5rem; color: #4EC9B0; } /* Teal for main title */
96
+ h2 { font-size: 2rem; border-bottom: 2px solid #569CD6; padding-bottom: 0.5rem; margin-bottom: 1rem;}
97
+
98
+
99
+ /* Buttons */
100
+ .stButton>button {
101
+ background-color: #1E4D2B; /* Dark green */
102
+ color: #4EC9B0; /* Teal text */
103
+ border: none;
104
+ border-radius: 0.5rem;
105
+ padding: 0.75rem 1.5rem;
106
+ font-size: 1rem;
107
+ font-weight: bold;
108
+ transition: background-color 0.2s ease, transform 0.1s ease;
109
+ box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
110
+ }
111
+ .stButton>button:hover {
112
+ background-color: #27633A; /* Lighter green on hover */
113
+ transform: translateY(-1px);
114
+ }
115
+ .stButton>button:active {
116
+ background-color: #1A3C23; /* Darker green on click */
117
+ transform: translateY(0);
118
+ box-shadow: 1px 1px 3px rgba(0, 0, 0, 0.4);
119
+ }
120
+
121
+
122
+ /* Text areas and inputs */
123
+ .stTextArea textarea, .stTextInput input {
124
+ border: 1px solid #3E3E3E; /* Dark border */
125
+ border-radius: 0.4rem;
126
+ padding: 0.75rem;
127
+ font-size: 1rem;
128
+ background-color: #252526; /* VS Code background */
129
+ color: #D4D4D4; /* Light text */
130
+ box-shadow: inset 1px 1px 3px rgba(0, 0, 0, 0.2);
131
+ }
132
+ .stTextArea label, .stTextInput label {
133
+ font-weight: bold;
134
+ color: #9CDCFE !important; /* Light blue labels */
135
+ margin-bottom: 0.5rem;
136
+ display: block;
137
+ }
138
+ /* Streamlit status box styling */
139
+ .st-emotion-cache-vj1l9j { /* Target the status box content div */
140
+ background-color: #2D2D2D; /* Match sidebar background */
141
+ border: 1px solid #3E3E3E;
142
+ border-radius: 0.5rem;
143
+ padding: 1rem;
144
+ margin-bottom: 1rem;
145
+ }
146
+ .st-emotion-cache-vj1l9j .stMarkdown p { /* Style text inside status */
147
+ color: #D4D4D4 !important;
148
+ }
149
+ /* Status box icons/text (might need to target specific internal classes) */
150
+ .st-emotion-cache-vj1l9j .stAlert {
151
+ background-color: transparent !important; /* Don't want alert backgrounds inside status */
152
+ }
153
+
154
+
155
+ /* Info/Success/Error/Warning boxes */
156
+ .stAlert {
157
+ border-radius: 0.5rem;
158
+ margin-bottom: 1rem;
159
+ padding: 1rem;
160
+ font-size: 1rem;
161
+ border-left: 5px solid transparent; /* Base style */
162
+ }
163
+ .stAlert.stAlert-info { border-left-color: #569CD6; background-color: #2A3E52; color: #9CDCFE; } /* Dark blue info */
164
+ .stAlert.stAlert-success { border-left-color: #4EC9B0; background-color: #28403A; color: #7AC7A3; } /* Dark teal success */
165
+ .stAlert.stAlert-warning { border-left-color: #DCDCAA; background-color: #454032; color: #FFDAA6; } /* Dark yellow warning */
166
+ .stAlert.stAlert-error { border-left-color: #F44747; background-color: #4A3030; color: #F48787; } /* Dark red error */
167
+
168
+
169
+ /* Expander styling */
170
+ .streamlit-expanderHeader {
171
+ background-color: #3E3E3E; /* Dark grey header */
172
+ color: #D4D4D4; /* Light grey text */
173
+ border-radius: 0.5rem;
174
+ padding: 0.75rem 1.2rem;
175
+ margin-top: 0.8rem;
176
+ margin-bottom: 0.5rem;
177
+ font-weight: bold;
178
+ font-size: 1.1rem;
179
+ cursor: pointer;
180
+ transition: background-color 0.2s ease;
181
+ }
182
+ .streamlit-expanderHeader:hover {
183
+ background-color: #4E4E4E; /* Slightly lighter on hover */
184
+ }
185
+ .streamlit-expanderContent {
186
+ background-color: #252526; /* VS Code background */
187
+ border: 1px solid #3E3E3E;
188
+ border-top: none;
189
+ border-bottom-left-radius: 0.5rem;
190
+ border-bottom-right-radius: 0.5rem;
191
+ padding: 1.5rem;
192
+ margin-top: 0;
193
+ color: #D4D4D4;
194
+ }
195
+
196
+ /* Labels for the output text areas */
197
+ .output-label {
198
+ font-weight: bold !important;
199
+ color: #9CDCFE !important; /* Light blue */
200
+ margin-top: 1rem;
201
+ margin-bottom: 0.5rem;
202
+ display: block;
203
+ font-size: 1.1rem;
204
+ }
205
+
206
+ /* Custom class for output text areas to differentiate from input */
207
+ .output-text-area textarea {
208
+ background-color: #1E1E1E; /* Even darker background for outputs */
209
+ border: 1px solid #3E3E3E;
210
+ border-radius: 0.4rem;
211
+ padding: 0.75rem;
212
+ font-size: 1rem;
213
+ color: #D4D4D4;
214
+ }
215
+
216
+ /* Telemetry box styling */
217
+ .telemetry-box {
218
+ background-color: #2D2D2D; /* Match sidebar */
219
+ border: 1px solid #3E3E3E;
220
+ border-radius: 0.5rem;
221
+ padding: 0.75rem;
222
+ margin-top: 1rem;
223
+ font-size: 0.9rem;
224
+ color: #D4D4D4;
225
+ text-align: center;
226
+ }
227
+
228
+ /* Self-Consistency Consensus Styling */
229
+ .consensus-answer {
230
+ background-color: #28403A; /* Dark green */
231
+ color: #7AC7A3; /* Light green text */
232
+ border: 1px solid #3A5048;
233
+ border-radius: 0.5rem;
234
+ padding: 1rem;
235
+ margin-top: 1rem;
236
+ margin-bottom: 1rem;
237
+ font-size: 1.2rem;
238
+ font-weight: bold;
239
+ }
240
+ .consensus-answer strong {
241
+ color: #4EC9B0; /* Teal for "Consensus Answer" label */
242
+ }
243
+ .consensus-answer div {
244
+ color: #D4D4D4; /* Ensure the answer text is light */
245
+ }
246
+
247
+
248
+ </style>
249
+ """, unsafe_allow_html=True)
250
+
251
+
252
+ # --- GPU Telemetry Setup ---
253
  try:
254
  pynvml.nvmlInit()
255
  GPU_AVAILABLE = True
256
  except Exception:
257
  GPU_AVAILABLE = False
258
 
259
+ # Use st.empty to hold the telemetry status text, defined *outside* cached functions
260
+ telemetry_placeholder = st.empty()
261
+
262
+ def update_telemetry():
263
+ """Updates the telemetry display in the dedicated placeholder."""
264
+ telemetry_text = "[Checking System Status...]"
265
  if not GPU_AVAILABLE or not torch.cuda.is_available():
266
+ telemetry_text = "📊 System Status: [No GPU Available]"
267
+ else:
268
+ try:
269
+ h = pynvml.nvmlDeviceGetHandleByIndex(0)
270
+ u = pynvml.nvmlDeviceGetUtilizationRates(h)
271
+ m = pynvml.nvmlDeviceGetMemoryInfo(h)
272
+ mem_used_mb = m.used // 1024**2
273
+ mem_total_mb = m.total // 1024**2
274
+ telemetry_text = f"📊 System Status: GPU {u.gpu}% | Mem {mem_used_mb}/{mem_total_mb} MB"
275
+ except Exception:
276
+ telemetry_text = "📊 System Status: [Telemetry Error]"
277
+
278
+ # Use markdown with a custom class for styling
279
+ telemetry_placeholder.markdown(f'<div class="telemetry-box">{telemetry_text}</div>', unsafe_allow_html=True)
280
+
281
+
282
+ # Initial telemetry update when the script starts
283
+ update_telemetry()
284
+
285
+
286
+ # --- Caching Model Loading (Core Logic Only) ---
287
+ # Use st.cache_resource for heavy objects like models and tokenizers.
288
+ # This function MUST NOT call Streamlit elements that affect the layout
289
+ # or state outside of its own scope.
290
+ @st.cache_resource(show_spinner=False) # Spinner handled manually
291
+ def _load_model_and_tokenizer_cached(model_name: str, device: str, forced_model_type: str = None):
292
+ """
293
+ Loads the model and tokenizer. This function is cached and should
294
+ contain minimal Streamlit calls to avoid caching issues.
295
+ """
296
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
297
+ is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
298
+ detected_type = "Seq2Seq" if is_encoder_decoder else "Causal"
299
+
300
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
301
+ # Ensure padding token is set for generation robustness
302
+ if tokenizer.pad_token is None:
303
+ if tokenizer.eos_token is not None:
304
+ tokenizer.pad_token = tokenizer.eos_token
305
+ else:
306
+ # Fallback - adding tokens might require resizing model embeddings
307
+ # which is complex and model-dependent. This is a basic attempt.
308
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
309
+ tokenizer.pad_token = '[PAD]' # Set the attribute
310
+ # Attempt to get the new pad token ID - may not work for all tokenizers
311
+ try:
312
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
313
+ except Exception:
314
+ tokenizer.pad_token_id = None # Indicate failure to get ID
315
+
316
+ # Determine the model class based on detection or forced selection
317
+ actual_model_type = forced_model_type if forced_model_type != "Auto" else detected_type
318
+
319
+ if actual_model_type == "Seq2Seq":
320
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config, trust_remote_code=True)
321
+ elif actual_model_type == "Causal":
322
+ model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
323
+ else:
324
+ raise ValueError(f"Unsupported model type selected: {actual_model_type}. Please select 'Auto', 'Causal', or 'Seq2Seq'.")
325
+
326
+ model.to(device)
327
+ model.eval() # Crucial for consistent inference behavior and disabling dropout etc.
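# Hedged addition (not part of the original commit): if a brand-new '[PAD]' token was
# added to the tokenizer above, the embedding matrix is likely one row short of the
# enlarged vocabulary; the resize mentioned in the earlier comment would look like this.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: model.resize_token_embeddings(len(tokenizer))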
328
+
329
+ # Ensure return_dict_in_generate is True for structured outputs
330
+ if not getattr(model.config, 'return_dict_in_generate', False):
331
+ model.config.return_dict_in_generate = True
332
+
333
+ return model, tokenizer, actual_model_type
334
+
335
+ # --- Wrapper function to handle status reporting for cached loading ---
336
+ def safe_load_model_with_status(model_name: str, device: str, forced_model_type: str = None):
337
+ """
338
+ Calls the cached loading function and handles Streamlit status updates.
339
+ """
340
+ status_text = f"🌐 Loading model '{model_name}' on device '{device}'..."
341
+ # Use st.status here, defined outside the cached function
342
+ with st.status(status_text, expanded=True) as status_box:
343
+ status_box.write("Checking system status...")
344
+ update_telemetry() # Update the separate telemetry box
345
+
346
+ try:
347
+ status_box.write("Loading configuration and tokenizer...")
348
+ # Call the actual cached loading function
349
+ model, tokenizer, actual_model_type = _load_model_and_tokenizer_cached(
350
+ model_name=model_name,
351
+ device=device,
352
+ forced_model_type=forced_model_type
353
+ )
354
+
355
+ # Report padding token status if available
356
+ if tokenizer and tokenizer.pad_token_id is None:
357
+ status_box.warning(f"Tokenizer has no pad_token_id. Generation might fail for models requiring padding (e.g., batching).")
358
+ elif tokenizer:
359
+ status_box.write(f"Tokenizer pad_token_id set to {tokenizer.pad_token_id}.")
360
+
361
+
362
+ status_box.success(f"✅ Model '{model_name}' ({actual_model_type}) loaded successfully on '{device}'.")
363
+ update_telemetry() # Final telemetry update after success
364
+ return model, tokenizer, actual_model_type
365
+
366
+ except Exception as e:
367
+ status_box.error(f"❌ Model loading failed.")
368
+ update_telemetry() # Final telemetry update after error
369
+ st.exception(e) # Display the full exception traceback
370
+ # Clean up resources in case of failure before returning None
371
+ # These are manual attempts; cache handles cleanup on its own state changes
372
+ # but explicit cleanup is good practice on error paths.
373
+ try:
374
+ if 'model' in locals() and model is not None: del model
375
+ except NameError: pass
376
+ try:
377
+ if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
378
+ except NameError: pass
379
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
380
+ gc.collect()
381
+ return None, None, None # Return None on failure
382
+
383
+
384
+ # --- Sidebar Configuration ---
385
+ with st.sidebar:
386
+ st.header("⚙️ Core Settings")
387
+ st.markdown("Configure the foundational aspects of the NeuroReasoner.")
388
+
389
+ with st.expander("🧠 Model Configuration", expanded=True):
390
+ model_name = st.text_input(
391
+ "Hugging Face Model ID or Path",
392
+ "ayjays132/NeuroReasoner-1-NR-1",
393
+ help="Enter the model ID from huggingface.co or a local path."
394
+ )
395
+
396
+ # --- Dynamic Model Type Detection ---
397
+ detected_type = "Unknown (Enter Model ID)"
398
+ # Options match the strings used in the loading function
399
+ model_type_options = ["Auto", "Causal", "Seq2Seq"]
400
+ default_model_type_index = model_type_options.index("Auto")
401
+
402
+ # Attempt to load config to detect type without caching (lightweight check)
403
+ try:
404
+ if model_name and model_name.strip(): # Only attempt if input is not empty
405
+ initial_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
406
+ is_encoder_decoder_initial = getattr(initial_config, "is_encoder_decoder", False)
407
+ detected_type = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
408
+ else:
409
+ detected_type = "Unknown (Enter Model ID)"
410
+ except Exception:
411
+ detected_type = "Unknown (Config Load Error)" # Indicate config load itself failed
412
+
413
+ forced_model_type = st.selectbox(
414
+ "Architecture Type",
415
+ model_type_options,
416
+ index=default_model_type_index,
417
+ help=f"Detected: {detected_type}. 'Auto' uses the detected type. Select manually if detection is incorrect or overridden."
418
+ )
419
+
420
+ # --- Device Selection ---
421
+ available_devices = ["cpu"]
422
+ if torch.cuda.is_available():
423
+ available_devices.insert(0, "cuda") # Put cuda first if available
424
+
425
+ device = st.selectbox(
426
+ "Device",
427
+ available_devices,
428
+ help="Select the hardware device for computation (GPU recommended)."
429
+ )
430
+
431
+ st.markdown("""
432
+ <small>💡 Changing model settings requires reloading the model.</small>
433
+ """, unsafe_allow_html=True)
434
+
435
+
436
+ st.markdown("---") # Visual separator
437
+
438
+ st.header("✨ Generation Parameters")
439
+ st.markdown("Define how the AI generates reasoning steps and answers.")
440
+
441
+ with st.expander("Basic Parameters", expanded=True):
442
+ # Finalized 'Number of Reasoning Chains' parameter
443
+ num_chains = st.slider(
444
+ "Number of Reasoning Chains",
445
+ min_value=1,
446
+ max_value=15, # Kept the higher max for more robustness
447
+ value=5, # Kept the default of 5
448
+ help="How many independent reasoning chains to generate for analyzing the problem. More chains can improve Self-Consistency but take longer."
449
+ )
450
+
451
+ # Finalized 'No-repeat Ngram Size' parameter
452
+ no_repeat_ngram_size = st.slider( # Using the standard name for GenerationConfig
453
+ "No-repeat Ngram Size",
454
+ min_value=0,
455
+ max_value=10,
456
+ value=3,
457
+ help="Avoids generating repeating sequences of N tokens. Set to 0 to disable."
458
+ )
459
+
460
+ # Self-Consistency checkbox remains
461
+ self_consistency = st.checkbox(
462
+ "Enable Self-Consistency Voting",
463
+ value=True,
464
+ help="When enabled, the system generates multiple chains and identifies the most common final answer as the consensus. Requires 'Number of Reasoning Chains' > 1."
465
+ )
466
+
467
+ # Conditional warning if Self-Consistency is on but num_chains is 1
468
+ if self_consistency and num_chains <= 1:
469
+ st.warning("Self-Consistency is most effective with 2 or more chains.") # Slightly rephrased warning
470
+
471
+
472
+ # Advanced Parameters
473
+ with st.expander("🧪 Advanced Sampling Parameters"):
474
+ max_new_tokens = st.slider(
475
+ "Max Tokens per Chain",
476
+ 50, 2048, 768,
477
+ help="Maximum number of new tokens to generate for *each* individual reasoning chain. Adjust based on complexity expected."
478
+ )
479
+ temperature = st.slider(
480
+ "Temperature",
481
+ 0.0, 2.0, 0.8,
482
+ help="Controls the randomness of sampling. 0.0 is deterministic (greedy). Higher values increase diversity."
483
+ )
484
+ top_k = st.slider(
485
+ "Top-k",
486
+ 0, 100, 50,
487
+ help="Filter to consider only the top_k most likely tokens at each step (0 disables). Used with sampling."
488
+ )
489
+ top_p = st.slider(
490
+ "Top-p (Nucleus Sampling)",
491
+ 0.0, 1.0, 0.95,
492
+ help="Filter to consider tokens with cumulative probability below top_p (0.0 disables). Used with sampling."
493
+ )
494
+ do_sample = st.checkbox(
495
+ "Enable Sampling",
496
+ value=True,
497
+ help="If checked, uses probabilistic sampling (controlled by Temperature, Top-k, Top-p). If unchecked, uses greedy decoding."
498
+ )
499
+ if not do_sample:
500
+ st.info("Sampling disabled. Temperature, Top-k, and Top-p will be ignored.")
501
+
502
507
+ # Optional: Add a seed for reproducibility if desired
508
+ # generation_seed = st.number_input("Generation Seed (Optional)", value=-1, help="Set a positive integer for reproducible generation.")
509
+
510
+
511
+ st.markdown("---") # Visual separator
512
+
513
+ # Update the persistent telemetry box in the sidebar footer area
514
+ update_telemetry()
515
+
516
+
517
+ # --- Main Content Layout ---
518
+ st.title("🧠 NeuroReasoner: Chain-of-Thought Explorer")
519
+ st.markdown("Unpack complex problems with step-by-step AI reasoning.")
520
+
521
+ # Container for input and primary controls
522
+ input_container = st.container()
523
+
524
+ with input_container:
525
+ # Use columns for prompt input and action button
526
+ prompt_col, button_col = st.columns([3, 1])
527
+
528
+ with prompt_col:
529
+ prompt = st.text_area(
530
+ "📝 Enter your query or problem:",
531
+ height=150,
532
+ placeholder="Example: If a train travels at 60 mph and a car at 40 mph, starting at the same time from cities 300 miles apart, how long until they meet? Think step-by-step.",
533
+ key="user_prompt" # Added key for stability
534
+ )
535
+
536
+ with button_col:
537
+ # Add some vertical space to align the button nicely
538
+ st.markdown("<div style='height: 3.5rem;'></div>", unsafe_allow_html=True)
539
+ run_button = st.button("✨ Generate Reasoning", use_container_width=True, key="generate_button") # Added key
540
+
541
+ # Container for status updates and results
542
+ results_container = st.container()
543
+
544
+
545
+ # --- Generation Logic Trigger ---
546
+ if run_button:
547
+ if not prompt or not prompt.strip():
548
+ results_container.warning("Please enter a prompt to begin generation.")
549
+ st.stop() # Stop execution until prompt is entered
550
+
551
+ # --- Prepare for Generation ---
552
+ # Load model and tokenizer (handles caching internally with st.cache_resource
553
+ # via safe_load_model_with_status which also reports status)
554
+ # This happens only when the button is clicked and parameters might have changed
555
+ model, tokenizer, loaded_model_type = safe_load_model_with_status(model_name, device, forced_model_type)
556
+
557
+ if model is None or tokenizer is None:
558
+ # Error was already shown by safe_load_model_with_status
559
+ st.error("Model or tokenizer failed to load. Please check settings and traceback above.")
560
+ st.stop() # Stop if loading failed
561
+
562
+
563
+ # --- Configure Generation ---
564
+ # Use a status box for ongoing generation process
565
+ with results_container:
566
+ st.markdown("---") # Separator before results
567
+ generation_status = st.status("Preparing generation config...", expanded=True)
568
+ update_telemetry() # Update telemetry while status is active
569
+
570
+
571
+ try:
572
+ # Build GenerationConfig based on sidebar parameters
573
+ # num_return_sequences should match num_chains for the wrapper to process them
574
+ gen_cfg = GenerationConfig(
575
+ max_new_tokens=max_new_tokens,
576
+ temperature=temperature,
577
+ top_k=top_k,
578
+ top_p=top_p,
579
+ do_sample=do_sample,
580
+ num_return_sequences=num_chains, # one generated sequence per reasoning chain
581
+ no_repeat_ngram_size=no_repeat_ngram_size,
582
+ eos_token_id=tokenizer.eos_token_id,
583
+ pad_token_id=tokenizer.pad_token_id,
584
+ return_dict_in_generate=True,
585
+ output_scores=False,
586
+ output_attentions=False,
587
+ output_hidden_states=False,
588
+ use_cache=True,
589
+ )
590
+ generation_status.write(f"Generation parameters set: {gen_cfg.to_dict()}")
591
+ update_telemetry()
592
+
593
+ cfg = GenerationConfig(
594
+ max_new_tokens=max_new_tokens,
595
+ temperature=temperature,
596
+ top_k=top_k,
597
+ top_p=top_p,
598
+ do_sample=True,
599
+ num_return_sequences=num_chains,
600
+ no_repeat_ngram_size=no_repeat_ngram_size,
601
+ eos_token_id=tokenizer.eos_token_id,
602
+ pad_token_id=tokenizer.pad_token_id
603
+ )
604
+ except Exception as e:
605
+ generation_status.error(f"❌ Failed to create GenerationConfig: {e}")
606
+ st.exception(e)
607
  st.stop()
608
+
609
+ # --- Instantiate Wrapper ---
610
  try:
611
+ generation_status.write("Initializing Chain-of-Thought wrapper...")
612
+ # Pass the configured generation config to the wrapper
613
+ # The wrapper should internally use num_return_sequences from gen_cfg
614
+ cot_wrapper = ChainOfThoughtWrapper(
615
+ model=model,
616
+ tokenizer=tokenizer,
617
+ generation_config=gen_cfg,
618
+ device=device,
619
+ self_consistency=self_consistency,
620
+ consistency_rounds=(num_chains if self_consistency else 1)
621
+ )
622
+ generation_status.write("Wrapper initialized.")
623
+ update_telemetry()
624
+
625
  except Exception as e:
626
+ generation_status.error(f"❌ Failed to initialize CoT wrapper: {e}")
627
+ st.exception(e)
628
  st.stop()
629
 
630
+ # --- Tokenize Input ---
631
+ try:
632
+ generation_status.write("Tokenizing input prompt...")
633
+ # Use model_max_length or a reasonable cap for input length
634
+ max_input_length = tokenizer.model_max_length
635
+ if max_input_length is None or max_input_length > 4096: # Cap input length if tokenizer reports None or very large
636
+ max_input_length = 4096
637
+ if tokenizer.model_max_length is None:
638
+ generation_status.warning(f"Tokenizer has no model_max_length, capping input to {max_input_length}.")
639
+
640
+
641
+ enc = tokenizer(
642
+ prompt,
643
+ return_tensors='pt',
644
+ padding='longest', # Pad to the longest sequence in the batch (batch size is 1 here)
645
+ truncation=True,
646
+ max_length=max_input_length, # Use a proper max length for the input
647
+ ).to(device)
648
+ generation_status.write(f"Input token length: {enc['input_ids'].shape[1]}")
649
+ update_telemetry()
650
+
651
+ except Exception as e:
652
+ generation_status.error(f"❌ Tokenization failed: {e}")
653
+ st.exception(e)
654
+ st.stop()
655
+
656
+ # --- Generate ---
657
+ generation_status.update(label=f"⏳ Generating {num_chains} reasoning chains...", state="running")
658
+ start_time = time.time()
659
+
660
+ try:
661
+ # Call the wrapper's generate method
662
+ # It should handle the loop for multiple chains and self-consistency internally
663
+ outputs = cot_wrapper.generate(
664
+ input_ids=enc['input_ids'],
665
+ attention_mask=enc['attention_mask'],
666
+ # Pass any other necessary arguments to your wrapper's generate method
667
+ )
668
+ # Expected `outputs` dict structure: {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'consensus_answer': '...'}
669
+ # The wrapper should handle extracting steps/answers if needed.
670
+
671
+ except Exception as e:
672
+ generation_status.error(f"❌ Generation failed: {e}")
673
+ st.exception(e)
674
+ # Clean up resources after potential OOM or other errors
675
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
676
+ gc.collect() # Python garbage collection
677
+ st.stop()
678
+
679
+ elapsed_time = time.time() - start_time
680
+ generation_status.update(label=f"✨ Generation complete in {elapsed_time:.2f}s", state="complete")
681
+ update_telemetry() # Final telemetry update after successful generation
682
+
683
+ # --- Display Results ---
684
+ with results_container:
685
+ st.markdown("## 📚 Reasoning Output")
686
+
687
+ # Display Self-Consistency Consensus first if enabled and results are available
688
+ if self_consistency and outputs and 'consensus_answer' in outputs and outputs.get('final_answers'):
689
+ consensus = outputs.get('consensus_answer')
690
+ answers = outputs.get('final_answers', [])
691
+
692
+ st.markdown('<div class="consensus-answer">', unsafe_allow_html=True)
693
+ st.write("💡 **Consensus Answer (Self-Consistency):**")
694
+ st.write(consensus if consensus else "[Could not determine consensus]")
695
+ st.markdown('</div>', unsafe_allow_html=True)
696
+
697
+ if answers and len(answers) > 1: # Only show distribution if more than one answer was found
698
+ st.markdown("###### Answer Distribution:")
699
+ answer_counts = Counter(answers)
700
+ # Display sorted distribution
701
+ for ans, count in answer_counts.most_common():
702
+ st.write(f"- '{ans}' ({count} {'vote' if count == 1 else 'votes'})")
703
+ st.markdown("---") # Separator
704
+
705
+
706
+ # Display individual chains
707
+ full_texts = outputs.get('full_texts', [])
708
+ reasoning_steps = outputs.get('reasoning_steps', [])
709
+ final_answers = outputs.get('final_answers', [])
710
+
711
+ if not full_texts:
712
+ st.warning("No reasoning chains were generated.")
713
+ else:
714
+ st.markdown(f"### Individual Chains ({len(full_texts)} generated)")
715
+ # Iterate and display each chain in an expander
716
+ # Ensure lists are iterable, even if empty
717
+ full_texts = full_texts if isinstance(full_texts, list) else []
718
+ reasoning_steps = reasoning_steps if isinstance(reasoning_steps, list) else []
719
+ final_answers = final_answers if isinstance(final_answers, list) else []
720
+
721
+ # Pad lists to the same length in case the wrapper returned inconsistent outputs
722
+ max_len_outputs = max(len(full_texts), len(reasoning_steps), len(final_answers))
723
+ full_texts.extend(["[N/A - Generation Failed for this chain]"] * (max_len_outputs - len(full_texts)))
724
+ reasoning_steps.extend([[]] * (max_len_outputs - len(reasoning_steps)))
725
+ final_answers.extend(["[N/A]"] * (max_len_outputs - len(final_answers)))
726
+
727
+
728
+ for idx, (text, steps, ans) in enumerate(zip(full_texts, reasoning_steps, final_answers), 1):
729
+ # Use try-except just in case a single chain output is malformed
730
+ try:
731
+ # Expander for each chain, starting collapsed
732
+ with st.expander(f"Chain {idx}", expanded=False):
733
+ # Use custom class for styling the label
734
+ st.markdown('<div class="output-label">Full Generated Text:</div>', unsafe_allow_html=True)
735
+ # Use custom class for styling the text area background
736
+ st.text_area(f"chain_text_area_{idx}", text, height=250, label_visibility="collapsed", help="The complete generated output for this chain.")
737
+
738
+ if steps and isinstance(steps, list):
739
+ st.markdown('<div class="output-label">Reasoning Steps:</div>', unsafe_allow_html=True)
740
+ # Display steps as a list
741
+ if steps:
742
+ for i, step in enumerate(steps, 1):
743
+ if isinstance(step, str) and step.strip():
744
+ st.write(f"**Step {i}:** {step.strip()}")
745
+ elif not isinstance(step, str):
746
+ st.warning(f"Step {i} has invalid format.")
747
+ else:
748
+ st.info("No specific steps were extracted for this chain.")
749
+
750
+
751
+ st.markdown('<div class="output-label">Final Answer:</div>', unsafe_allow_html=True)
752
+ st.write(f"**{ans if ans else '[No answer extracted]'}**")
753
+
754
+ # Optional: Add a separator between chain sections
755
+ st.markdown("---", help="End of Chain details.")
756
+
757
+ except Exception as chain_e:
758
+ st.error(f"Error displaying Chain {idx}: {chain_e}")
759
+ st.exception(chain_e)
760
+
761
+
762
+ st.markdown("---") # Final separator
763
+ st.info("Generation process concluded. Review the chains above.")
764
+
765
+ # Clean up GPU memory after generation is complete and results are displayed
766
+ if torch.cuda.is_available():
767
+ torch.cuda.empty_cache()
768
+ gc.collect() # Python garbage collection
chain_of_thought_wrapper.py CHANGED
@@ -1,27 +1,67 @@
1
  import re
2
  import torch
3
  import logging
4
  from transformers import PreTrainedModel, AutoTokenizer, GenerationConfig, GenerationMixin
 
5
  from typing import Optional, List, Tuple, Dict, Union, Any
 
 
6
 
 
 
 
7
  logger = logging.getLogger(__name__)
8
 
9
- # Default configuration values
10
- DEFAULT_MAX_LENGTH = 1024
11
- DEFAULT_REASONING_LIMIT = 10
12
- DEFAULT_CONSISTENCY_ROUNDS = 3
13
- DEFAULT_COMPLEXITY_KEYWORDS = ["explain", "step by step", "plan", "analyze", "reasoning", "logic"]
14
- DEFAULT_FINAL_ANSWER_TAG = "Final_Answer:"
 
15
 
16
- # **Expanded** step‐pattern to catch both "Step 1:" and bare "1."
  DEFAULT_STEP_PATTERN = re.compile(
18
  r"^(?:Step\s*\d+[:.)-]|\d+[:.)-])\s*(.*)", re.IGNORECASE
19
  )
20
 
 
21
  class ChainOfThoughtWrapper:
22
  """
23
- A robust, SOTA Chain-of-Thought wrapper for Hugging Face models or custom wrappers.
24
- ALWAYS uses Chain‑of‑Thought now, with stricter injection and cleaning.
25
  """
26
 
27
  def __init__(
@@ -31,80 +71,201 @@ class ChainOfThoughtWrapper:
31
  generation_config: Optional[GenerationConfig] = None,
32
  device: Optional[str] = None,
33
  max_length: int = DEFAULT_MAX_LENGTH,
34
- reasoning_steps_limit: int = DEFAULT_REASONING_LIMIT,
35
- self_consistency: bool = False,
36
- consistency_rounds: int = DEFAULT_CONSISTENCY_ROUNDS,
37
- complexity_keywords: Optional[List[str]] = None,
38
  final_answer_tag: str = DEFAULT_FINAL_ANSWER_TAG,
 
39
  ):
40
  """
41
- model: HF model or wrapper implementing `.generate()`
42
- tokenizer: corresponding tokenizer
43
- generation_config: overrides defaults
44
- device: 'cpu'/'cuda'
45
  """
 
46
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
47
- self.model = model.to(self.device)
48
  self.tokenizer = tokenizer
 
 
49
  self.max_length = max_length
50
  self.reasoning_steps_limit = reasoning_steps_limit
51
- self.self_consistency = self_consistency
52
- self.consistency_rounds = max(1, consistency_rounds) if self_consistency else 1
53
- self.complexity_keywords = complexity_keywords or DEFAULT_COMPLEXITY_KEYWORDS
54
  self.final_answer_tag = final_answer_tag
 
55
  self.final_answer_pattern = re.compile(
56
  re.escape(final_answer_tag) + r"\s*(.*)", re.IGNORECASE | re.DOTALL
57
  )
 
 
58
 
59
- # Try to locate HF config; fallback to tokenizer if missing
 
60
  self._hf_model, self._hf_config = self._find_hf_model_and_config(self.model)
 
 
61
  if self._hf_config is None:
62
- logger.warning("HF config not found, falling back to tokenizer settings.")
 
63
  class PseudoConfig:
64
  def __init__(self, tok):
65
  self.eos_token_id = tok.eos_token_id
66
- self.pad_token_id = tok.pad_token_id or tok.eos_token_id
67
- self.vocab_size = len(tok)
68
  self._hf_config = PseudoConfig(self.tokenizer)
69
 
70
- # Setup generation config
 
 
71
  if generation_config:
 
72
  self.generation_config = GenerationConfig.from_dict(generation_config.to_dict())
 
73
  else:
 
74
  self.generation_config = GenerationConfig(
75
  eos_token_id=self._hf_config.eos_token_id,
76
  pad_token_id=self._hf_config.pad_token_id,
77
- max_length=self.max_length,
78
  )
 
79
 
80
- # Ensure HF model returns dict outputs
81
- try:
82
- setattr(self._hf_config, 'return_dict_in_generate', True)
83
- except Exception:
84
- pass
85
 
86
- logger.info("ChainOfThoughtWrapper ready on %s", self.device)
87
 
88
  def _find_hf_model_and_config(self, obj: Any) -> Tuple[Optional[PreTrainedModel], Optional[Any]]:
89
- """Search for underlying PreTrainedModel and its config."""
90
  if isinstance(obj, PreTrainedModel) and hasattr(obj, 'config'):
 
91
  return obj, obj.config
92
- for attr in ('model','base_model','transformer'):
93
- m = getattr(obj, attr, None)
94
- if isinstance(m, PreTrainedModel) and hasattr(m, 'config'):
95
- return m, m.config
96
- return None, getattr(obj, 'config', None)
97
 
98
  def _inject_cot(self, prompt: str) -> str:
99
- # **More prescriptive CoT template**
100
- return (
101
- f"{prompt}\n\n"
102
- "Let's analyze step by step exactly like this:\n\n"
103
- "Step 1: \n"
104
- "Step 2: \n"
105
- "Step 3: \n\n"
106
- "Final Answer:\n\n"
107
  )
 
 
 
108
 
109
  @torch.no_grad()
110
  def generate(
@@ -112,75 +273,391 @@ class ChainOfThoughtWrapper:
112
  input_ids: torch.LongTensor,
113
  attention_mask: Optional[torch.LongTensor] = None,
114
  generation_config: Optional[GenerationConfig] = None,
115
- num_return_sequences: int = 1,
116
- **kwargs
117
  ) -> Dict[str, Any]:
118
  """
119
- Returns dict with keys: sequences, full_texts, reasoning_steps, final_answers
120
- ALWAYS uses CoT path.
121
  """
122
  prompt_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
123
 
124
- # **ALWAYS** do CoT, ignore complexity check
 
125
  cot_prompt = self._inject_cot(prompt_text)
 
126
 
127
- # Merge configs
128
- cfg = GenerationConfig.from_dict(self.generation_config.to_dict())
 
 
129
  if generation_config:
130
- cfg.update(**generation_config.to_dict())
 
 
 
131
  cfg.num_return_sequences = num_return_sequences
132
- for k,v in kwargs.items(): setattr(cfg, k, v)
133
 
134
- # Encode with injected template
135
- enc = self.tokenizer(
136
- cot_prompt, return_tensors='pt', truncation=True,
137
- max_length=self.max_length - cfg.max_new_tokens
138
- ).to(self.device)
139
 
140
- out = self.model.generate(
141
- input_ids=enc['input_ids'], attention_mask=enc['attention_mask'], generation_config=cfg
142
- )
143
 
144
- decoded = self.tokenizer.batch_decode(out, skip_special_tokens=True)
145
- results = [self._parse(text, cot_prompt) for text in decoded]
146
- seqs = out
147
- steps = [r[0] for r in results]
148
- finals = [r[1] for r in results]
149
- full = [r[2] for r in results]
150
- return {'sequences': seqs, 'full_texts': full, 'reasoning_steps': steps, 'final_answers': finals}
151
 
152
  def _parse(self, text: str, cot_prompt: str) -> Tuple[List[str], str, str]:
153
- # Remove the injected prompt
154
- body = text[len(cot_prompt):].strip() if text.startswith(cot_prompt) else text
155
 
156
- # **Clean out any stray tags or JSON fragments**
157
  body = re.sub(r"<init>.*?</init>", "", body, flags=re.DOTALL)
158
  body = re.sub(r"<final_output>.*?</final_output>", "", body, flags=re.DOTALL)
 
 
 
159
  body = re.sub(r"\{.*?\}", "", body, flags=re.DOTALL)
160
 
161
- lines = [l.strip() for l in body.splitlines() if l.strip()]
162
- steps = []
163
- final = ""
164
 
165
- for l in lines:
166
- m = DEFAULT_STEP_PATTERN.match(l)
167
- if m:
168
- steps.append(m.group(1).strip())
169
  else:
170
- fa = self.final_answer_pattern.search(l)
171
- if fa:
172
- final = fa.group(1).strip()
173
- break
174
 
175
- if not final:
176
- # assume last non‑step line is the final answer
177
- final = lines[-1] if lines else ""
178
 
179
- return steps, final, body
180
 
181
  def resize_token_embeddings(self, new_size: int):
182
- if hasattr(self._hf_model, 'resize_token_embeddings'):
183
- self._hf_model.resize_token_embeddings(new_size)
184
- logger.info("Resized embeddings to %d", new_size)
185
  else:
186
- logger.error("Cannot resize: no underlying HF model method.")
1
+ # chain_of_thought_wrapper.py
2
+
3
  import re
4
  import torch
5
  import logging
6
  from transformers import PreTrainedModel, AutoTokenizer, GenerationConfig, GenerationMixin
7
+ from transformers.utils import is_accelerate_available, is_bitsandbytes_available
8
  from typing import Optional, List, Tuple, Dict, Union, Any
9
+ import gc # Import garbage collector for cleanup
10
+ import time
11
 
12
+ # --- Logging Setup ---
13
+ # Configure logging for the module
14
+ logging.basicConfig(level=logging.INFO) # Default logging level
15
  logger = logging.getLogger(__name__)
16
+ # Prevent duplicate handlers if imported multiple times
17
+ if not logger.handlers:
18
+ handler = logging.StreamHandler()
19
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
20
+ handler.setFormatter(formatter)
21
+ logger.addHandler(handler)
22
+ logger.propagate = False # Prevent logs from going to root logger multiple times
23
+
24
 
25
+ # --- Default Configuration Values ---
26
+ # These defaults provide sensible starting points for the wrapper's behavior.
27
+ DEFAULT_MAX_LENGTH = 1024 # Default maximum length of the generated output sequence.
28
+ DEFAULT_REASONING_LIMIT = 10 # Limit on the number of steps to extract during parsing (currently unused in parse logic, but good to keep as a concept).
29
+ DEFAULT_CONSISTENCY_ROUNDS = 3 # Default number of chains to generate for self-consistency (used in __init__, passed via GUI num_chains).
30
+ DEFAULT_COMPLEXITY_KEYWORDS = ["explain", "step by step", "plan", "analyze", "reasoning", "logic"] # Keywords to potentially trigger CoT (currently unused, CoT is always on).
31
+ DEFAULT_FINAL_ANSWER_TAG = "Final_Answer:" # The specific tag expected before the final answer.
32
 
33
+ # --- Regex Pattern for Parsing Steps ---
34
+ # This pattern is used to identify and extract individual reasoning steps from
35
+ # the generated text. It's designed to be flexible, capturing:
36
+ # - "Step N:"
37
+ # - "Step N."
38
+ # - "Step N-"
39
+ # - "N:"
40
+ # - "N."
41
+ # - "N-"
42
+ # Where N is one or more digits, case-insensitive for "Step".
43
  DEFAULT_STEP_PATTERN = re.compile(
44
  r"^(?:Step\s*\d+[:.)-]|\d+[:.)-])\s*(.*)", re.IGNORECASE
45
  )
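# A quick illustration of the pattern above (a sketch, not part of the original file):
#   DEFAULT_STEP_PATTERN.match("Step 2: add the numbers").group(1)  -> "add the numbers"
#   DEFAULT_STEP_PATTERN.match("3. divide by two").group(1)         -> "divide by two"
#   DEFAULT_STEP_PATTERN.match("The answer is 5")                   -> None (not a step line)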
46
 
47
+
48
  class ChainOfThoughtWrapper:
49
  """
50
+ A robust Chain-of-Thought (CoT) wrapper for Hugging Face models.
51
+
52
+ This wrapper enforces a Chain-of-Thought process by injecting a specific
53
+ template into the prompt. It handles model generation and parses the
54
+ output to extract reasoning steps and a final answer. It is designed
55
+ to generate multiple sequences for potential Self-Consistency voting
56
+ (voting logic is expected to be handled by the calling application,
57
+ like the Streamlit GUI).
58
+
59
+ Key Features:
60
+ - Forces CoT via prompt injection.
61
+ - Parses structured reasoning steps and final answer from output.
62
+ - Supports generating multiple chains for Self-Consistency analysis.
63
+ - Compatible with Hugging Face PreTrainedModels or objects implementing `.generate()`.
64
+ - Handles device placement and merges GenerationConfig.
65
  """
66
 
67
  def __init__(
 
71
  generation_config: Optional[GenerationConfig] = None,
72
  device: Optional[str] = None,
73
  max_length: int = DEFAULT_MAX_LENGTH,
74
+ reasoning_steps_limit: int = DEFAULT_REASONING_LIMIT, # Conceptual cap on parsed steps (not yet enforced in _parse)
+ self_consistency: bool = False, # Stored as an attribute; the generated chain count is driven by num_return_sequences
+ consistency_rounds: int = DEFAULT_CONSISTENCY_ROUNDS, # Stored as an attribute
+ complexity_keywords: Optional[List[str]] = None, # Currently unused; CoT is always injected
78
  final_answer_tag: str = DEFAULT_FINAL_ANSWER_TAG,
79
+ # (A separate self_consistency_enabled flag was dropped; self-consistency is controlled by the GUI via num_return_sequences.)
80
  ):
81
  """
82
+ Initializes the ChainOfThoughtWrapper.
83
+
84
+ Args:
85
+ model (Union[PreTrainedModel, GenerationMixin, Any]): The language model.
86
+ Must have a `.generate()` method.
87
+ tokenizer (AutoTokenizer): The corresponding tokenizer.
88
+ generation_config (Optional[GenerationConfig]): A default generation configuration.
89
+ Values here can be overridden by `generate()` call.
90
+ device (Optional[str]): The device to load the model onto ('cpu' or 'cuda').
91
+ Defaults to 'cuda' if available, otherwise 'cpu'.
92
+ max_length (int): The maximum total length of the input + generated sequence.
93
+ reasoning_steps_limit (int): Conceptual limit for parsed steps (currently not enforced in _parse).
94
+ self_consistency (bool): Flag indicating if self-consistency is intended (Informs `consistency_rounds` attribute).
95
+ consistency_rounds (int): The number of chains to generate if self-consistency is active (Informs `consistency_rounds` attribute).
96
+ The actual number generated is controlled by `num_return_sequences` in `generate()` or `generation_config`.
97
+ complexity_keywords (Optional[List[str]]): List of keywords to potentially trigger CoT (currently unused).
98
+ final_answer_tag (str): The specific string marker expected before the final answer.
99
  """
100
+ # Determine and set the device
101
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
102
+ logger.info("Initializing wrapper on device: %s", self.device)
103
+
104
+ # Move the model to the specified device
105
+ try:
106
+ self.model = model.to(self.device)
107
+ self.model.eval() # Set model to evaluation mode for consistent behavior
108
+ logger.info("Model moved to %s and set to eval mode.", self.device)
109
+ except Exception as e:
110
+ logger.error("Failed to move model to device %s: %s", self.device, e)
111
+ raise # Re-raise the exception after logging
112
+
113
  self.tokenizer = tokenizer
114
+
115
+ # Set core parameters
116
  self.max_length = max_length
117
  self.reasoning_steps_limit = reasoning_steps_limit
118
+ self.self_consistency = self_consistency # Attribute stored, actual generation count controlled elsewhere
119
+ self.consistency_rounds = max(1, consistency_rounds) if self_consistency else 1 # Attribute stored
120
+ self.complexity_keywords = complexity_keywords or list(DEFAULT_COMPLEXITY_KEYWORDS) # Ensure it's a mutable list
121
  self.final_answer_tag = final_answer_tag
122
+ # Compile regex pattern for final answer extraction
123
  self.final_answer_pattern = re.compile(
124
  re.escape(final_answer_tag) + r"\s*(.*)", re.IGNORECASE | re.DOTALL
125
  )
126
+ logger.debug("Final answer pattern compiled: %s", self.final_answer_pattern.pattern)
127
+ logger.debug("Step pattern: %s", DEFAULT_STEP_PATTERN.pattern)
128
 
129
+ # Attempt to find the underlying Hugging Face model and its config
130
+ # This is useful for accessing standard attributes like eos_token_id, etc.
131
  self._hf_model, self._hf_config = self._find_hf_model_and_config(self.model)
132
+
133
+ # Fallback to tokenizer settings if HF config isn't found
134
  if self._hf_config is None:
135
+ logger.warning("Underlying HF model config not found. Relying on tokenizer for eos/pad tokens and vocab size.")
136
+ # Create a pseudo-config with essential tokenizer info
137
  class PseudoConfig:
138
  def __init__(self, tok):
139
  self.eos_token_id = tok.eos_token_id
140
+ # Use eos_token_id as pad_token_id if pad_token_id is None (common for GPT-like models)
141
+ self.pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
142
+ # Fallback if both are None (less common but possible)
143
+ if self.pad_token_id is None:
144
+ logger.warning("Tokenizer pad_token_id and eos_token_id are both None. Generation might be unstable without padding.")
145
+ # Assign an arbitrary value or handle this externally if it happens in practice
146
+ # For now, keep it None, generation might fail or behave unexpectedly
147
+ pass # Keep pad_token_id as None
148
+
149
+ self.vocab_size = len(tok) # Vocabulary size from tokenizer
150
+
151
+ def __getattr__(self, name):
152
+ # Allow accessing other attributes, returning None if not found
153
+ # This prevents errors if generation_config tries to read something unexpected
154
+ logger.debug("Accessing undefined attribute '%s' on PseudoConfig. Returning None.", name)
155
+ return None
156
+
157
  self._hf_config = PseudoConfig(self.tokenizer)
158
+ logger.debug("Created PseudoConfig: eos_token_id=%s, pad_token_id=%s, vocab_size=%s",
159
+ self._hf_config.eos_token_id, self._hf_config.pad_token_id, self._hf_config.vocab_size)
160
+ else:
161
+ logger.info("Found underlying HF model config.")
162
+ logger.debug("HF Config: eos_token_id=%s, pad_token_id=%s, vocab_size=%s",
163
+ getattr(self._hf_config, 'eos_token_id', None),
164
+ getattr(self._hf_config, 'pad_token_id', None),
165
+ getattr(self._hf_config, 'vocab_size', None))
166
 
167
+
168
+ # --- Setup Generation Config ---
169
+ # Start with a base config, either provided or a default one
170
  if generation_config:
171
+ # Use from_dict and to_dict for safe merging/copying of GenerationConfig
172
  self.generation_config = GenerationConfig.from_dict(generation_config.to_dict())
173
+ logger.info("Initialized with provided GenerationConfig.")
174
  else:
175
+ # Create a default GenerationConfig using info from HF config or tokenizer fallback
176
  self.generation_config = GenerationConfig(
177
  eos_token_id=self._hf_config.eos_token_id,
178
  pad_token_id=self._hf_config.pad_token_id,
179
+ max_length=self.max_length, # Set max_length from wrapper param
180
+ # Add other common defaults if not provided
181
+ do_sample=True,
182
+ temperature=0.7,
183
+ top_p=0.95,
184
+ top_k=50,
185
+ num_return_sequences=1, # Default to 1 sequence
186
+ no_repeat_ngram_size=0, # Default to no ngram repetition prevention
187
  )
188
+ logger.info("Initialized with default GenerationConfig.")
189
 
190
+ # Ensure the underlying HF model (if found) is set to return dict outputs from generate
191
+ # This is necessary for accessing scores, hidden states etc. if needed, and for consistency.
192
+ # Use a check as some custom models might not have this attribute on their config.
193
+ if hasattr(self._hf_config, 'return_dict_in_generate'):
194
+ try:
195
+ setattr(self._hf_config, 'return_dict_in_generate', True)
196
+ logger.debug("Set _hf_config.return_dict_in_generate = True.")
197
+ except Exception as e:
198
+ logger.warning("Failed to set return_dict_in_generate on _hf_config: %s", e)
199
+ else:
200
+ logger.debug("_hf_config does not have return_dict_in_generate attribute.")
201
+
202
+
203
+ logger.info("ChainOfThoughtWrapper initialization complete on device: %s", self.device)
204
+ logger.debug("Initial GenerationConfig: %s", self.generation_config.to_dict())
205
 
 
206
 
207
  def _find_hf_model_and_config(self, obj: Any) -> Tuple[Optional[PreTrainedModel], Optional[Any]]:
208
+ """
209
+ Recursively searches for an underlying Hugging Face PreTrainedModel
210
+ and its configuration within a potentially wrapped object.
211
+
212
+ Args:
213
+ obj (Any): The object to inspect (could be the model itself or a wrapper).
214
+
215
+ Returns:
216
+ Tuple[Optional[PreTrainedModel], Optional[Any]]: The found HF model instance and its config.
217
+ Returns (None, None) if not found.
218
+ """
219
+ logger.debug("Searching for HF model in object of type: %s", type(obj))
220
+ # If the object is directly a PreTrainedModel and has a config
221
  if isinstance(obj, PreTrainedModel) and hasattr(obj, 'config'):
222
+ logger.debug("Found HF PreTrainedModel directly.")
223
  return obj, obj.config
224
+
225
+ # Check common attribute names where the base model might be stored
226
+ potential_attrs = ('model', 'base_model', 'transformer', 'hf_model')
227
+ for attr_name in potential_attrs:
228
+ m = getattr(obj, attr_name, None)
229
+ if m is not None:
230
+ logger.debug("Checking attribute '%s' of type %s", attr_name, type(m))
231
+ # Recursively search within the attribute
232
+ found_model, found_config = self._find_hf_model_and_config(m)
233
+ if found_model or found_config:
234
+ return found_model, found_config
235
+
236
+ # If no PreTrainedModel found, check if the object itself has a 'config' attribute
237
+ if hasattr(obj, 'config'):
238
+ logger.debug("Found config attribute on object, but no PreTrainedModel.")
239
+ return None, obj.config
240
+
241
+ logger.debug("No HF PreTrainedModel or config found.")
242
+ return None, None
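+ # Illustrative example of the kind of wrapping this search handles (hypothetical class, not part of this file):
+ #
+ # class MyAgent:
+ # def __init__(self, hf_model):
+ # self.model = hf_model # discovered via the 'model' attribute checked above
+ #
+ # _find_hf_model_and_config(MyAgent(gpt2_model)) would return (gpt2_model, gpt2_model.config).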
243
+
244
 
245
  def _inject_cot(self, prompt: str) -> str:
246
+ """
247
+ Injects the prescriptive Chain-of-Thought template into the user's prompt.
248
+
249
+ This method defines the expected format the model should follow for reasoning.
250
+
251
+ Args:
252
+ prompt (str): The original user prompt.
253
+
254
+ Returns:
255
+ str: The prompt with the CoT template appended.
256
+ """
257
+ # The template strongly guides the model to produce step-by-step reasoning
258
+ # followed by a specific tag for the final answer.
259
+ cot_prompt = (
260
+ f"{prompt.strip()}\n\n" # Use strip() to clean user prompt
261
+ "Let's analyze this problem logically, breaking it down step by step to reach the precise final answer.\n\n" # Enhanced instruction
262
+ "Reasoning Process:\n\n" # Clearer heading for steps
263
+ "Step 1: " # Start the first step explicitly
264
+ # More steps are not needed here, the model learns to continue the pattern
265
  )
266
+ logger.debug("Injected CoT template. Full prompt starts with: %s...", cot_prompt[:100].replace('\n', '\\n'))
267
+ return cot_prompt
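+ # For example, _inject_cot("What is 2 + 2?") would yield roughly:
+ #
+ # What is 2 + 2?
+ #
+ # Let's analyze this problem logically, breaking it down step by step to reach the precise final answer.
+ #
+ # Reasoning Process:
+ #
+ # Step 1: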
268
+
269
 
270
  @torch.no_grad()
271
  def generate(
  self,
  input_ids: torch.LongTensor,
274
  attention_mask: Optional[torch.LongTensor] = None,
275
  generation_config: Optional[GenerationConfig] = None,
276
+ num_return_sequences: int = 1, # This argument controls how many sequences are generated
277
+ **kwargs: Any # Allows passing arbitrary generation parameters
278
  ) -> Dict[str, Any]:
279
  """
280
+ Generates text using the wrapped model, enforcing Chain-of-Thought.
281
+
282
+ This method prepares the input by injecting the CoT template, calls the
283
+ underlying model's generate method, and then parses the raw outputs
284
+ to extract structured reasoning steps and final answers.
285
+
286
+ Args:
287
+ input_ids (torch.LongTensor): Tokenized input prompt (batch size 1 expected).
288
+ Shape [1, sequence_length].
289
+ attention_mask (Optional[torch.LongTensor]): Attention mask for the input.
290
+ Shape [1, sequence_length].
291
+ generation_config (Optional[GenerationConfig]): Specific generation config
292
+ for this call. Overrides defaults.
293
+ num_return_sequences (int): The number of independent sequences to generate.
294
+ This is crucial for Self-Consistency.
295
+ Comes from the GUI's 'num_chains'.
296
+ **kwargs (Any): Additional keyword arguments passed to the model's `generate` method.
297
+
298
+ Returns:
299
+ Dict[str, Any]: A dictionary containing:
300
+ - 'sequences' (torch.LongTensor): The raw generated token sequences.
301
+ - 'full_texts' (List[str]): The complete decoded text for each sequence.
302
+ - 'reasoning_steps' (List[List[str]]): List of parsed reasoning steps for each sequence.
303
+ - 'final_answers' (List[str]): List of parsed final answers for each sequence.
304
+ - 'consensus_answer' (Optional[str]): The consensus answer if self-consistency is active and possible (Handled by calling code).
305
  """
306
+ # Ensure input is on the correct device
307
+ input_ids = input_ids.to(self.device)
308
+ if attention_mask is not None:
309
+ attention_mask = attention_mask.to(self.device)
310
+
311
+ # Decode the original prompt text for CoT injection
312
+ # Assume batch size is 1 for the input prompt tensor [1, sequence_length]
313
+ if input_ids.size(0) != 1:
314
+ logger.warning("Batch size > 1 detected for input_ids (%d). CoT injection assumes batch size 1. Using the first item.", input_ids.size(0))
315
  prompt_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
316
 
317
+ # --- Inject CoT Template ---
318
+ # This is the core step that forces the model into a reasoning mode.
319
  cot_prompt = self._inject_cot(prompt_text)
320
+ logger.debug("Injected CoT prompt. Encoding...")
321
 
322
+ # --- Prepare Generation Configuration ---
323
+ # Merge the wrapper's default config with the call-specific config and kwargs.
324
+ # The num_return_sequences from the function argument takes precedence here.
325
+ cfg = GenerationConfig.from_dict(self.generation_config.to_dict()) # Start with wrapper's default
326
  if generation_config:
327
+ cfg.update(**generation_config.to_dict()) # Update with call-specific config
328
+ logger.debug("Updated GenerationConfig with call-specific config.")
329
+
330
+ # Explicitly set num_return_sequences from the function argument
331
  cfg.num_return_sequences = num_return_sequences
332
+ logger.info("Generating %d sequence(s).", cfg.num_return_sequences)
333
 
334
+ # Update with any remaining keyword arguments passed to generate()
335
+ for k, v in kwargs.items():
336
+ if hasattr(cfg, k):
337
+ setattr(cfg, k, v)
338
+ logger.debug("Updating GenerationConfig kwarg: %s=%s", k, v)
339
+ else:
+ # Not a GenerationConfig field: leave it for model.generate(), which may still accept it;
+ # all kwargs are forwarded to model.generate() below regardless.
+ if k not in GenerationConfig().__dict__: # not a standard generation parameter
+ logger.debug("Passing non-standard kwarg '%s' through to model.generate.", k)
347
 
348
+ logger.debug("Final GenerationConfig for call: %s", cfg.to_dict())
349
+
350
+ # --- Encode the CoT Prompt ---
351
+ # Max length for input should be total max_length minus max_new_tokens
352
+ # to leave space for the generation.
353
+ # Ensure padding and truncation are handled.
354
+ try:
355
+ enc = self.tokenizer(
356
+ cot_prompt,
357
+ return_tensors='pt',
358
+ padding='longest', # Pad to the longest sequence in the batch (always 1 here)
359
+ truncation=True, # Crucially, truncate if the prompt is too long
360
+ max_length=max(1, self.max_length - (cfg.max_new_tokens or 0)) # Leave room for generation; guard against an unset max_new_tokens
361
+ ).to(self.device)
362
+ logger.debug("Encoded CoT prompt. Input shape: %s", enc['input_ids'].shape)
363
+
364
+ except Exception as e:
365
+ logger.error("Failed to encode CoT prompt: %s", e)
366
+ raise # Re-raise the exception after logging
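+ # Worked example of the token budget (illustrative numbers): with max_length=1024 and
+ # max_new_tokens=256, the prompt is truncated to at most 1024 - 256 = 768 tokens, so the
+ # prompt plus the generated continuation stays within the 1024-token window.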
367
+
368
+
369
+ # --- Generate Text ---
370
+ # Call the underlying model's generate method with the prepared input and config.
371
+ # torch.no_grad() context is already applied to the whole method.
372
+ try:
373
+ logger.info("Calling model.generate()...")
374
+ start_time = time.time() # Measure generation time
375
+ out = self.model.generate(
376
+ input_ids=enc['input_ids'],
377
+ attention_mask=enc['attention_mask'],
378
+ generation_config=cfg,
379
+ **kwargs # Pass through any extra kwargs
380
+ )
381
+ elapsed_time = time.time() - start_time
382
+ logger.info("model.generate() finished in %.2f seconds.", elapsed_time)
383
+ logger.debug("Raw output shape: %s", out.shape)
384
+
385
+
386
+ except Exception as e:
387
+ logger.error("Model generation failed: %s", e)
388
+ # Attempt to clean up GPU memory in case of OOM or other errors
389
+ if torch.cuda.is_available():
390
+ torch.cuda.empty_cache()
391
+ gc.collect() # Trigger Python garbage collection
392
+ raise # Re-raise the exception after logging
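+ # If this raises an out-of-memory error, reducing max_new_tokens or num_return_sequences
+ # (fewer chains) on the next attempt is the usual mitigation.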
393
+
394
+ # --- Decode and Parse Outputs ---
395
+ # Decode the generated token sequences back into text.
396
+ logger.debug("Decoding and parsing outputs...")
397
+ decoded_outputs = self.tokenizer.batch_decode(out, skip_special_tokens=True)
398
+
399
+ # Process each decoded output to extract steps and final answer
400
+ parsed_results = [self._parse(text, cot_prompt) for text in decoded_outputs]
401
+
402
+ # Separate the parsed components into lists
403
+ all_steps = [r[0] for r in parsed_results]
404
+ all_finals = [r[1] for r in parsed_results]
405
+ all_full_texts = [r[2] for r in parsed_results] # The 'body' after removing template
406
+
407
+ logger.info("Generated and parsed %d sequence(s).", len(decoded_outputs))
408
+
409
+ # --- Return Results ---
410
+ # The calling code (e.g., the GUI) is responsible for implementing
411
+ # Self-Consistency voting based on the list of 'final_answers' provided here.
412
+ return {
413
+ 'sequences': out, # Return raw sequences in case they are needed
414
+ 'full_texts': all_full_texts, # Text body after template removal
415
+ 'reasoning_steps': all_steps,
416
+ 'final_answers': all_finals,
417
+ # 'consensus_answer' is not computed here; the caller (e.g. the GUI) performs the voting.
+ # The key is kept so the result structure matches what the caller expects.
+ 'consensus_answer': None # Placeholder, computed externally
420
+ }
421
 
 
 
 
 
 
 
 
422
 
423
  def _parse(self, text: str, cot_prompt: str) -> Tuple[List[str], str, str]:
424
+ """
425
+ Parses the generated text to extract reasoning steps and the final answer.
426
+
427
+ Applies regex patterns to find lines matching the step format and the
428
+ final answer tag. Includes cleanup for stray model artifacts.
429
+
430
+ Args:
431
+ text (str): The raw text output from the model for a single chain.
432
+ cot_prompt (str): The exact prompt text that was injected (used to remove it from the output).
433
+
434
+ Returns:
435
+ Tuple[List[str], str, str]: A tuple containing:
436
+ - A list of extracted reasoning step strings.
437
+ - The extracted final answer string.
438
+ - The full body of the generated text (after removing the prompt).
439
+ """
440
+ logger.debug("Parsing generated text...")
441
 
442
+ # Remove the exact injected prompt from the beginning of the text.
443
+ # This isolates the model's generated continuation.
444
+ body = text
445
+ if text.startswith(cot_prompt):
446
+ body = text[len(cot_prompt):].strip()
447
+ logger.debug("Removed CoT prompt (%d characters) from beginning.", len(cot_prompt))
448
+ else:
449
+ logger.warning("Generated text does not start with the injected CoT prompt. Parsing entire text.")
450
+ body = text.strip() # Just strip whitespace if template wasn't followed
451
+
452
+ # --- Cleanup stray model artifacts ---
453
+ # Remove common problematic tags or partial JSON structures that models sometimes emit.
454
+ # This makes the raw output cleaner before step/answer extraction.
455
+ logger.debug("Cleaning stray artifacts...")
456
  body = re.sub(r"<init>.*?</init>", "", body, flags=re.DOTALL)
457
  body = re.sub(r"<final_output>.*?</final_output>", "", body, flags=re.DOTALL)
458
+ # Note: stripping every {...} block is aggressive and may remove desired output if the model
+ # legitimately emits braces; consider making this cleanup optional or more targeted if needed.
461
  body = re.sub(r"\{.*?\}", "", body, flags=re.DOTALL)
462
+ logger.debug("Artifact cleanup complete.")
463
+
464
+
465
+ lines = [l.strip() for l in body.splitlines() if l.strip()] # Split into non-empty, stripped lines
466
+ steps = [] # List to store extracted steps
467
+ final_answer = "" # Variable to store the final answer
468
+
469
+ # --- Extract Steps and Final Answer ---
470
+ # Iterate through lines and apply regex patterns.
471
+ found_final_answer_line = False
472
+ for i, line in enumerate(lines):
473
+ # Check for reasoning step pattern
474
+ step_match = DEFAULT_STEP_PATTERN.match(line)
475
+ if step_match:
476
+ # If a step is found, add the captured group (the text after the number/tag)
477
+ steps.append(step_match.group(1).strip())
478
+ logger.debug("Extracted step %d: '%s'", len(steps), steps[-1][:50])
479
+ # Stop adding steps if we've reached a defined limit (though limit isn't currently enforced after parsing)
480
+ # if len(steps) >= self.reasoning_steps_limit:
481
+ # logger.debug("Reached reasoning steps limit (%d). Stopping step extraction.", self.reasoning_steps_limit)
482
+ # # Continue iterating to potentially find the final answer after the limit
483
+ # # break # DO NOT break if we still need to find the final answer tag after the limit
484
+
485
 
486
+ else:
487
+ # If it's not a step, check for the final answer tag
488
+ final_answer_match = self.final_answer_pattern.search(line)
489
+ if final_answer_match:
490
+ # If the final answer tag is found, extract the text following it
491
+ final_answer = final_answer_match.group(1).strip()
492
+ logger.debug("Extracted final answer tagged: '%s'", final_answer[:50])
493
+ found_final_answer_line = True
494
+ # Once the final answer tag is found, we can stop processing lines for *this specific pattern*
495
+ # However, the provided code breaks the loop entirely here.
496
+ # Keeping the break to match the original logic.
497
+ break # Stop processing lines after finding the tagged answer
498
+
499
+ # --- Fallback for Final Answer ---
500
+ # If the specific final answer tag was not found, assume the last non-step line
501
+ # is the intended final answer. This is a heuristic fallback.
502
+ if not found_final_answer_line:
503
+ logger.debug("Final answer tag not found. Applying fallback heuristic.")
504
+ # Find the last line that is not a step
505
+ last_non_step_line = ""
506
+ for line in reversed(lines): # Iterate backwards
507
+ if line.strip() and not DEFAULT_STEP_PATTERN.match(line):
508
+ last_non_step_line = line.strip()
509
+ logger.debug("Fallback: Last non-step line found: '%s'", last_non_step_line[:50])
510
+ break # Found the last non-step line
511
 
512
+ if last_non_step_line:
513
+ # Check if the last non-step line *contains* the final answer tag,
514
+ # even if it didn't *start* with it or was the last line processed.
515
+ # This handles cases where the tag might be mid-line or in a different format.
516
+ fa_match_fallback = self.final_answer_pattern.search(last_non_step_line)
517
+ if fa_match_fallback:
518
+ final_answer = fa_match_fallback.group(1).strip()
519
+ logger.debug("Fallback found tagged answer in last non-step line: '%s'", final_answer[:50])
520
+ else:
521
+ # If no tag in the last non-step line, just use the line itself
522
+ final_answer = last_non_step_line
523
+ logger.debug("Fallback using last non-step line as answer: '%s'", final_answer[:50])
524
  else:
525
+ # If no non-empty lines were found, the final answer is empty
526
+ final_answer = ""
527
+ logger.debug("No lines found in body. Final answer is empty.")
528
+
529
 
530
+ logger.debug("Parsing complete. Steps found: %d, Final Answer: '%s'", len(steps), final_answer[:50])
531
+
532
+ return steps, final_answer, body # Return steps, final answer, and the cleaned body text
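+ # Illustrative parse (hypothetical model output, assuming the final-answer pattern matches "Final Answer:"):
+ #
+ # body = "Step 1: 2 + 2 means adding two and two.\nStep 2: The sum is 4.\nFinal Answer: 4"
+ # -> steps = ["2 + 2 means adding two and two.", "The sum is 4."], final_answer = "4"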
533
 
 
534
 
535
  def resize_token_embeddings(self, new_size: int):
536
+ """
537
+ Resizes the model's token embeddings, useful after adding new tokens
538
+ to the tokenizer (like a custom PAD token).
539
+
540
+ Only works if the underlying model object has a `resize_token_embeddings` method.
541
+
542
+ Args:
543
+ new_size (int): The new size of the vocabulary/embedding layer.
544
+ Should match the size of the tokenizer's vocabulary.
545
+ """
546
+ # Find the actual HF model if wrapped
547
+ hf_model_instance, _ = self._find_hf_model_and_config(self.model)
548
+
549
+ if hasattr(hf_model_instance, 'resize_token_embeddings'):
550
+ try:
551
+ old_size = hf_model_instance.get_input_embeddings().weight.size(0)
552
+ if new_size != old_size:
553
+ hf_model_instance.resize_token_embeddings(new_size)
554
+ logger.info("Resized model token embeddings from %d to %d.", old_size, new_size)
555
+ # Update model config's vocab size if available
556
+ if hasattr(hf_model_instance, 'config') and hasattr(hf_model_instance.config, 'vocab_size'):
557
+ hf_model_instance.config.vocab_size = new_size
558
+ logger.debug("Updated model config vocab_size to %d.", new_size)
559
+ else:
560
+ logger.info("Embedding size is already %d, no resizing needed.", new_size)
561
+ except Exception as e:
562
+ logger.error("Failed to resize token embeddings: %s", e)
563
+ # Attempt cleanup
564
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
565
+ gc.collect()
566
  else:
567
+ logger.error("Cannot resize token embeddings: The underlying model object does not have a 'resize_token_embeddings' method.")
568
+
569
+
570
+ # Example Usage (Illustrative - requires a real HF model and tokenizer)
571
+ if __name__ == "__main__":
572
+ print("--- ChainOfThoughtWrapper Example Usage ---")
573
+ print("This block requires a Hugging Face model to run.")
574
+ print("Loading a small dummy model for demonstration...")
575
+
576
+ # You would replace this with your actual model loading logic
577
+ try:
+ # Demo-only imports; harmless if these names are already imported at module level.
+ from collections import Counter
+ from transformers import AutoTokenizer, AutoModelForCausalLM
578
+ # Use a tiny, fast model for a quick test
579
+ model_id = "hf-internal-testing/tiny-random-gpt2"
580
+ device = "cuda" if torch.cuda.is_available() else "cpu"
581
+
582
+ logger.info(f"Attempting to load model {model_id} on {device}...")
583
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
584
+ model = AutoModelForCausalLM.from_pretrained(model_id)
585
+
586
+ # Ensure pad token is set for generation (common requirement)
587
+ if tokenizer.pad_token_id is None:
588
+ if tokenizer.eos_token_id is not None:
589
+ tokenizer.pad_token_id = tokenizer.eos_token_id
590
+ else:
591
+ # Add a pad token if neither eos nor pad exists
592
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
593
+ model.resize_token_embeddings(len(tokenizer)) # Resize embeddings after adding token
594
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
595
+ logger.warning("Added and set [PAD] token, resized embeddings.")
596
+
597
+ # Instantiate the wrapper
598
+ # Simulate parameters that would come from the GUI
599
+ simulated_gen_config = GenerationConfig(
600
+ max_new_tokens=100,
601
+ temperature=0.8,
602
+ do_sample=True,
603
+ num_return_sequences=2, # Simulate asking for 2 chains
604
+ pad_token_id=tokenizer.pad_token_id, # Pass pad_token_id explicitly
605
+ eos_token_id=tokenizer.eos_token_id, # Pass eos_token_id explicitly
606
+ )
607
+
608
+ cot_wrapper = ChainOfThoughtWrapper(
609
+ model=model,
610
+ tokenizer=tokenizer,
611
+ generation_config=simulated_gen_config,
612
+ device=device,
613
+ self_consistency=True, # Simulate SC enabled
614
+ consistency_rounds=2, # Simulate consistency rounds setting
615
+ )
616
+
617
+ # Prepare input prompt
618
+ prompt_text = "What is 2 + 2? Think step-by-step."
619
+ input_enc = tokenizer(prompt_text, return_tensors='pt').to(device)
620
+
621
+ logger.info(f"Generating reasoning for prompt: '{prompt_text}'")
622
+
623
+ # Generate outputs
624
+ # The num_return_sequences from simulated_gen_config will be used here
625
+ outputs = cot_wrapper.generate(
626
+ input_ids=input_enc['input_ids'],
627
+ attention_mask=input_enc['attention_mask']
628
+ )
629
+
630
+ # Process results (including simulated Self-Consistency voting logic)
631
+ print("\n--- Generation Results ---")
632
+ for i, (full_text, steps, final_answer) in enumerate(zip(outputs['full_texts'], outputs['reasoning_steps'], outputs['final_answers'])):
633
+ print(f"\n--- Chain {i+1} ---")
634
+ print("Full Text:")
635
+ print(full_text)
636
+ print("\nReasoning Steps:")
637
+ if steps:
638
+ for j, step in enumerate(steps):
639
+ print(f" Step {j+1}: {step}")
640
+ else:
641
+ print(" [No steps parsed]")
642
+ print("\nFinal Answer:")
643
+ print(f" {final_answer or '[No final answer parsed]'}")
644
+
645
+ # --- Simulate Self-Consistency Voting (as would be done in GUI) ---
646
+ print("\n--- Self-Consistency Voting ---")
647
+ final_answers = [ans for ans in outputs['final_answers'] if ans.strip()] # Filter empty answers
648
+ if final_answers:
649
+ answer_counts = Counter(final_answers)
650
+ most_common_answer, count = answer_counts.most_common(1)[0]
651
+ print(f"Raw Answers Submitted for Voting: {final_answers}")
652
+ print(f"Answer Counts: {dict(answer_counts)}")
653
+ print(f"Consensus Answer: '{most_common_answer}' (Voted by {count} chain(s))")
654
+ else:
655
+ print("No valid final answers found for voting.")
656
+
657
+
658
+ except Exception as e:
659
+ logger.error("Example usage failed: %s", e)
660
+ import traceback
661
+ traceback.print_exc() # Print detailed traceback for the example failure
662
+
663
+ print("\n--- Example Usage End ---")