ayjays132 committed · verified
Commit aaa0e51 · 1 parent: 0b507b5

Upload 2 files

Files changed (2):
  1. chain_of_thought_gui.py +496 -367
  2. chain_of_thought_wrapper.py +670 -406
chain_of_thought_gui.py CHANGED
@@ -5,14 +5,15 @@ NeuroReasoner Chain-of-Thought GUI (Dark Theme Enhanced)
 A premium Streamlit app for step-by-step reasoning
 across any Hugging Face model (causal or seq2seq).
 Featuring a dark theme, model-type detection, self-consistency
-sampling, and robust handling.
+sampling & voting, robust handling, and GPU telemetry.
 """
 import os
 import time
+import re # Needed for answer normalization
 import streamlit as st
 import torch
 import pynvml # For GPU telemetry
-import numpy as np
+import numpy as np # Imported, but currently unused in core logic
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -23,58 +24,66 @@ from transformers import (
 )
 from collections import Counter # For self-consistency voting
 import gc # Import garbage collector
-
-# Assuming chain_of_thought_wrapper.py is in the same directory
-# and is designed to work with standard Hugging Face models and GenerationConfig.
-# Make sure the wrapper correctly handles num_return_sequences for CoT and SC,
-# and returns the expected dictionary structure:
-# {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'consensus_answer': '...'}
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+# --- Import the Enhanced ChainOfThoughtWrapper ---
+# Assuming chain_of_thought_wrapper.py is in the same directory.
+# The wrapper is expected to:
+#   1. Accept model, tokenizer, GenerationConfig (base), device, etc. in __init__.
+#   2. Have a .generate() method that takes input_text (str), GenerationConfig (overrides),
+#      and, crucially, num_return_sequences (int) to generate multiple chains efficiently.
+#   3. Return a dictionary with keys 'full_texts', 'reasoning_steps', 'final_answers' (lists).
 try:
     from chain_of_thought_wrapper import ChainOfThoughtWrapper
 except ImportError:
-    st.error("Error: chain_of_thought_wrapper.py not found. Please ensure it's in the same directory.")
-    st.stop()
+    st.error("Error: chain_of_thought_wrapper.py not found. Please ensure the enhanced wrapper script is in the same directory.")
+    st.stop() # Halt execution if the wrapper is not found

+# --- Logging Setup for GUI ---
+# Use Streamlit's built-in logging or configure a separate logger.
+# For this example, we'll keep it simple and rely mostly on st.status and st.exception;
+# if needed, a more detailed logger could be configured here.
+
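For orientation, the wrapper contract spelled out in the comments above implies GUI-side usage roughly like this sketch (illustrative only; the real API is whatever chain_of_thought_wrapper.py defines):

    wrapper = ChainOfThoughtWrapper(model=model, tokenizer=tokenizer,
                                    generation_config=base_cfg, device="cuda")
    out = wrapper.generate(input_text="What is 2 + 2?", num_return_sequences=3)
    # Per the contract, each key holds one entry per generated chain:
    full_texts = out["full_texts"]
    steps = out["reasoning_steps"]
    answers = out["final_answers"]
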
 # --- Page Configuration ---
 st.set_page_config(
     page_title="🧠 NeuroReasoner CoT GUI",
     page_icon="🧠",
-    layout="wide",
-    initial_sidebar_state="expanded",
+    layout="wide", # Use wide layout
+    initial_sidebar_state="expanded", # Sidebar open by default
     menu_items={
-        'Get Help': 'https://github.com/your_repo_link_here', # Replace or remove
-        'Report a bug': "https://github.com/your_repo_link_here/issues", # Replace or remove
+        'Get Help': 'https://github.com/ayjays132/NeuroReasoner', # Example repo link
+        'Report a bug': "https://github.com/ayjays132/NeuroReasoner/issues", # Example repo issues link
         'About': """
         **NeuroReasoner Chain-of-Thought GUI**
-        An open-source interface powered by Hugging Face models and the NeuroReasoner wrapper.
+        An open-source interface powered by Hugging Face models and the enhanced NeuroReasoner wrapper.
         Explore step-by-step reasoning with various language models.
+        \n\n**Features:** Dark Theme, GPU Telemetry, Model Caching, Self-Consistency Voting,
+        Robust Generation Parameters, Support for Causal and Seq2Seq models.
         """
     }
 )

 # --- Dark Theme CSS ---
+# Comprehensive CSS for a professional dark theme inspired by VS Code.
 st.markdown("""
 <style>
 /* Overall Page Background & Text (Dark Theme) */
-body {
+/* Target the main container and the root app div */
+.stApp {
     background-color: #1E1E1E; /* Dark grey background */
     color: #D4D4D4; /* Light grey text */
     font-family: 'Segoe UI', Roboto, Arial, sans-serif;
 }
-.stApp {
-    background-color: #1E1E1E;
-    color: #D4D4D4;
-}

 /* Sidebar Styling */
 .stSidebar {
     background-color: #2D2D2D; /* Slightly lighter dark grey for sidebar */
     padding: 2rem 1rem;
     border-right: 1px solid #3E3E3E; /* Subtle border */
+    color: #D4D4D4; /* Ensure text in sidebar is light */
 }
 .stSidebar h1, .stSidebar h2, .stSidebar h3 {
-    color: #569CD6; /* Visual Studio Code blue for sidebar headers */
+    color: #569CD6 !important; /* Visual Studio Code blue for sidebar headers */
 }
 .stSidebar label {
     color: #D4D4D4 !important; /* Ensure sidebar labels are visible */
@@ -82,9 +91,7 @@ st.markdown("""


 /* Main Content Area */
-.stContainer {
-    padding: 2rem;
-}
+/* No specific background needed here, .stApp covers it */

 /* Titles and Headers */
 h1, h2, h3, h4, h5, h6 {
@@ -107,6 +114,7 @@ st.markdown("""
     font-weight: bold;
     transition: background-color 0.2s ease, transform 0.1s ease;
     box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
+    margin-top: 1.65rem; /* Add top margin to align with text area */
 }
 .stButton>button:hover {
     background-color: #27633A; /* Lighter green on hover */
@@ -120,7 +128,9 @@ st.markdown("""


 /* Text areas and inputs */
-.stTextArea textarea, .stTextInput input {
+/* Target specific classes used by Streamlit for input/text areas */
+div[data-baseweb="textarea"] textarea,
+div[data-baseweb="input"] input {
     border: 1px solid #3E3E3E; /* Dark border */
     border-radius: 0.4rem;
     padding: 0.75rem;
@@ -129,26 +139,28 @@ st.markdown("""
     color: #D4D4D4; /* Light text */
     box-shadow: inset 1px 1px 3px rgba(0, 0, 0, 0.2);
 }
-.stTextArea label, .stTextInput label {
+div[data-baseweb="textarea"] label,
+div[data-baseweb="input"] label,
+.stSlider label, .stSelectbox label, .stCheckbox label {
     font-weight: bold;
     color: #9CDCFE !important; /* Light blue labels */
     margin-bottom: 0.5rem;
     display: block;
 }
-/* Streamlit status box styling */
-.st-emotion-cache-vj1l9j { /* Target the status box content div */
+/* Streamlit status box styling - Target the main container and its contents */
+.st-emotion-cache-vj1l9j { /* This class might change with Streamlit versions */
     background-color: #2D2D2D; /* Match sidebar background */
     border: 1px solid #3E3E3E;
     border-radius: 0.5rem;
     padding: 1rem;
     margin-bottom: 1rem;
 }
-.st-emotion-cache-vj1l9j .stMarkdown p { /* Style text inside status */
-    color: #D4D4D4 !important;
-}
-/* Status box icons/text (might need to target specific internal classes) */
-.st-emotion-cache-vj1l9j .stAlert {
-    background-color: transparent !important; /* Don't want alert backgrounds inside status */
+.st-emotion-cache-vj1l9j .stMarkdown p,
+.st-emotion-cache-vj1l9j .stAlert { /* Style text and alerts inside status */
+    color: #D4D4D4 !important;
+    background-color: transparent !important; /* Don't want alert backgrounds inside status */
+    border: none !important; /* No borders for alerts inside status */
+    padding: 0.5rem 0 !important; /* Adjust padding */
 }

@@ -159,11 +171,12 @@ st.markdown("""
     padding: 1rem;
     font-size: 1rem;
     border-left: 5px solid transparent; /* Base style */
+    color: #D4D4D4; /* Default text color for alerts */
 }
-.stAlert.stAlert-info { border-left-color: #569CD6; background-color: #2A3E52; color: #9CDCFE; } /* Dark blue info */
-.stAlert.stAlert-success { border-left-color: #4EC9B0; background-color: #28403A; color: #7AC7A3; } /* Dark teal success */
-.stAlert.stAlert-warning { border-left-color: #DCDCAA; background-color: #454032; color: #FFDAA6; } /* Dark yellow warning */
-.stAlert.stAlert-error { border-left-color: #F44747; background-color: #4A3030; color: #F48787; } /* Dark red error */
+.stAlert.stAlert-info { border-left-color: #569CD6; background-color: #2A3E52; } /* Dark blue info */
+.stAlert.stAlert-success { border-left-color: #4EC9B0; background-color: #28403A; } /* Dark teal success */
+.stAlert.stAlert-warning { border-left-color: #DCDCAA; background-color: #454032; } /* Dark yellow warning */
+.stAlert.stAlert-error { border-left-color: #F44747; background-color: #4A3030; } /* Dark red error */


 /* Expander styling */
@@ -192,6 +205,10 @@ st.markdown("""
     margin-top: 0;
     color: #D4D4D4;
 }
+.streamlit-expanderContent .stMarkdown p {
+    color: #D4D4D4 !important; /* Ensure text inside expanders is light */
+}
+

 /* Labels for the output text areas */
 .output-label {
@@ -204,7 +221,8 @@ st.markdown("""
 }

 /* Custom class for output text areas to differentiate from input */
-.output-text-area textarea {
+/* Need to target the specific Streamlit internal class for the text area */
+.output-text-area div[data-baseweb="textarea"] textarea {
     background-color: #1E1E1E; /* Even darker background for outputs */
     border: 1px solid #3E3E3E;
     border-radius: 0.4rem;
@@ -238,11 +256,13 @@ st.markdown("""
     font-weight: bold;
 }
 .consensus-answer strong {
-    color: #4EC9B0; /* Teal for "Consensus Answer" label */
-}
-.consensus-answer div {
-    color: #D4D4D4; /* Ensure the answer text is light */
+    color: #4EC9B0 !important; /* Teal for "Consensus Answer" label */
 }
+.consensus-answer p {
+    color: #D4D4D4 !important; /* Ensure the answer text is light */
+    margin: 0 !important; /* Remove default paragraph margins */
+    padding: 0 !important; /* Remove default paragraph padding */
+}


 </style>
@@ -250,35 +270,42 @@ st.markdown("""

 # --- GPU Telemetry Setup ---
+# Initialize NVML for GPU monitoring if available
 try:
     pynvml.nvmlInit()
     GPU_AVAILABLE = True
+    # Get the number of devices to pick the first one (index 0)
+    GPU_COUNT = pynvml.nvmlDeviceGetCount()
+    if GPU_COUNT == 0:
+        GPU_AVAILABLE = False
+        st.warning("NVML initialized but no NVIDIA GPUs found.")
 except Exception:
     GPU_AVAILABLE = False
+    # st.info("NVIDIA Management Library (pynvml) not found or failed to initialize. GPU telemetry disabled.")

 # Use st.empty to hold the telemetry status text, defined *outside* cached functions
 telemetry_placeholder = st.empty()

 def update_telemetry():
     """Updates the telemetry display in the dedicated placeholder."""
-    telemetry_text = "[Checking System Status...]"
+    telemetry_text = "📊 System Status: [Initializing...]"
     if not GPU_AVAILABLE or not torch.cuda.is_available():
         telemetry_text = "📊 System Status: [No GPU Available]"
     else:
         try:
-            h = pynvml.nvmlDeviceGetHandleByIndex(0)
-            u = pynvml.nvmlDeviceGetUtilizationRates(h)
-            m = pynvml.nvmlDeviceGetMemoryInfo(h)
-            mem_used_mb = m.used // 1024**2
-            mem_total_mb = m.total // 1024**2
-            telemetry_text = f"📊 System Status: GPU {u.gpu}% | Mem {mem_used_mb}/{mem_total_mb} MB"
+            handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Use the first GPU
+            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
+            memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            mem_used_mb = memory.used // 1024**2
+            mem_total_mb = memory.total // 1024**2
+            telemetry_text = f"📊 System Status: GPU 0: {utilization.gpu}% | Mem {mem_used_mb}/{mem_total_mb} MB"
         except Exception:
+            # If NVML fails after initialization, report error
             telemetry_text = "📊 System Status: [Telemetry Error]"

-    # Use markdown with a custom class for styling
+    # Use markdown with a custom class for styling the container
     telemetry_placeholder.markdown(f'<div class="telemetry-box">{telemetry_text}</div>', unsafe_allow_html=True)

-
 # Initial telemetry update when the script starts
 update_telemetry()
@@ -286,59 +313,83 @@ update_telemetry()
 # --- Caching Model Loading (Core Logic Only) ---
 # Use st.cache_resource for heavy objects like models and tokenizers.
 # This function MUST NOT call Streamlit elements that affect the layout
-# or state outside of its own scope.
-@st.cache_resource(show_spinner=False) # Spinner handled manually
+# or state outside of its own scope (except for the final return value).
+@st.cache_resource(show_spinner=False) # Spinner handled manually in safe_load_model_with_status
 def _load_model_and_tokenizer_cached(model_name: str, device: str, forced_model_type: str = None):
     """
-    Loads the model and tokenizer. This function is cached and should
-    contain minimal Streamlit calls to avoid caching issues.
+    Loads the model and tokenizer. This function is cached by Streamlit.
+    It should perform resource-intensive loading only.
     """
+    # Use low_cpu_mem_usage=True to reduce RAM usage during loading
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True)
     is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
    detected_type = "Seq2Seq" if is_encoder_decoder else "Causal"

     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    # Ensure padding token is set for generation robustness
+
+    # Ensure padding token is set for generation robustness, especially for batching (num_return_sequences)
+    # This mirrors the logic in the wrapper's __init__ but is good to do here too
+    # before the model is potentially loaded with a different vocab size.
     if tokenizer.pad_token is None:
         if tokenizer.eos_token is not None:
             tokenizer.pad_token = tokenizer.eos_token
+            # Set the ID explicitly as well
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+            # logger.warning(f"Tokenizer pad_token is None, using eos_token '{tokenizer.eos_token}' as pad_token.")
         else:
-            # Fallback - adding tokens might require resizing model embeddings
-            # which is complex and model-dependent. This is a basic attempt.
-            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-            tokenizer.pad_token = '[PAD]' # Set the attribute
-            # Attempt to get the new pad token ID - may not work for all tokenizers
-            try:
-                tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
-            except Exception:
-                tokenizer.pad_token_id = None # Indicate failure to get ID
+            # Fallback: Add a new pad token if neither eos nor pad exists.
+            # The wrapper's init will handle resizing embeddings if possible.
+            # logger.warning("Tokenizer has no pad_token and no eos_token. Adding a new [PAD] token.")
+            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+            # Need to get the ID for the new token
+            tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
+            # logger.info(f"Added new [PAD] token with ID {tokenizer.pad_token_id}.")

     # Determine the model class based on detection or forced selection
     actual_model_type = forced_model_type if forced_model_type != "Auto" else detected_type
+    model = None # Initialize model to None

-    if actual_model_type == "Seq2Seq":
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config, trust_remote_code=True)
-    elif actual_model_type == "Causal":
-        model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
-    else:
-        raise ValueError(f"Unsupported model type selected: {actual_model_type}. Please select 'Auto', 'Causal', or 'Seq2Seq'.")
+    try:
+        if actual_model_type == "Seq2Seq":
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config, trust_remote_code=True)
+        elif actual_model_type == "Causal":
+            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
+        else:
+            raise ValueError(f"Unsupported model type selected: {actual_model_type}. Please select 'Auto', 'Causal', or 'Seq2Seq'.")

-    model.to(device)
-    model.eval() # Crucial for consistent inference behavior and disabling dropout etc.
+        # Move model to device and set to eval mode
+        model.to(device)
+        model.eval() # Crucial for consistent inference behavior and disabling dropout etc.

-    # Ensure return_dict_in_generate is True for structured outputs
-    if not getattr(model.config, 'return_dict_in_generate', False):
-        model.config.return_dict_in_generate = True
+        # Ensure return_dict_in_generate is True for structured outputs (needed by wrapper for scores/sequences)
+        if not getattr(model.config, 'return_dict_in_generate', False):
+            model.config.return_dict_in_generate = True
+        # Request output_scores by default for potential future CISC use in the GUI voter
+        if not getattr(model.config, 'output_scores', False):
+            model.config.output_scores = True

+        # The wrapper's __init__ will perform its own pad token handling and embedding resizing check.
+        # We ensure the tokenizer passed to it has a pad_token_id.
+
+    except Exception as e:
+        # Clean up resources if model loading failed
+        if model is not None: del model
+        if tokenizer is not None: del tokenizer
+        if torch.cuda.is_available(): torch.cuda.empty_cache()
+        gc.collect()
+        raise e # Re-raise the exception for the caller to handle status updates

     return model, tokenizer, actual_model_type

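Since `_load_model_and_tokenizer_cached` is wrapped in `st.cache_resource`, Streamlit keys the heavy load on the argument tuple, so reruns with unchanged settings reuse the same objects. A minimal standalone sketch of that caching behavior (the `load` function here is hypothetical, not part of this app):

    import streamlit as st

    @st.cache_resource(show_spinner=False)
    def load(name: str, device: str):
        return {"name": name, "device": device}  # stands in for a heavy model load

    a = load("gpt2", "cpu")
    b = load("gpt2", "cpu")   # cache hit: the same object is returned
    assert a is b
    c = load("gpt2", "cuda")  # different argument tuple -> fresh load
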
 # --- Wrapper function to handle status reporting for cached loading ---
 def safe_load_model_with_status(model_name: str, device: str, forced_model_type: str = None):
     """
-    Calls the cached loading function and handles Streamlit status updates.
+    Calls the cached loading function (_load_model_and_tokenizer_cached)
+    and handles Streamlit status updates and error reporting.
     """
+    # Use st.status here, defined outside the cached function, for live updates
     status_text = f"🌐 Loading model '{model_name}' on device '{device}'..."
-    # Use st.status here, defined outside the cached function
     with st.status(status_text, expanded=True) as status_box:
         status_box.write("Checking system status...")
         update_telemetry() # Update the separate telemetry box
@@ -352,9 +403,9 @@ def safe_load_model_with_status(model_name: str, device: str, forced_model_type: str = None):
             forced_model_type=forced_model_type
         )

-        # Report padding token status if available
+        # Report padding token status after loading
         if tokenizer and tokenizer.pad_token_id is None:
-            status_box.warning(f"Tokenizer has no pad_token_id. Generation might fail for models requiring padding (e.g., batching).")
+            status_box.warning(f"Tokenizer has no pad_token_id. Batch generation (Self-Consistency) might be unstable.")
         elif tokenizer:
             status_box.write(f"Tokenizer pad_token_id set to {tokenizer.pad_token_id}.")

@@ -366,21 +417,81 @@ def safe_load_model_with_status(model_name: str, device: str, forced_model_type: str = None):
     except Exception as e:
         status_box.error(f"❌ Model loading failed.")
         update_telemetry() # Final telemetry update after error
-        st.exception(e) # Display the full exception traceback
-        # Clean up resources in case of failure before returning None
-        # These are manual attempts; cache handles cleanup on its own state changes
-        # but explicit cleanup is good practice on error paths.
-        try:
-            if 'model' in locals() and model is not None: del model
-        except NameError: pass
-        try:
-            if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
-        except NameError: pass
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        gc.collect()
+        st.exception(e) # Display the full exception traceback within the status box
+        # No need for manual cleanup here, as the exception in the cached function
+        # should have triggered cleanup within that function, and Streamlit's
+        # cache resource management handles state on failure.
         return None, None, None # Return None on failure


+# --- Self-Consistency Voting Logic ---
+def normalize_answer(answer: str) -> str:
+    """
+    Normalizes a string answer for robust comparison during voting.
+    - Converts to lowercase.
+    - Strips leading/trailing whitespace.
+    - Removes common punctuation.
+    - Can be extended with more sophisticated normalization (e.g., number words to digits).
+    """
+    if not isinstance(answer, str):
+        return "" # Handle non-string inputs
+
+    # Simple normalization: lowercase, strip whitespace, remove common punctuation
+    normalized = answer.lower().strip()
+    # Remove common trailing characters like periods, commas, etc.
+    normalized = re.sub(r'[.,!?;:]+$', '', normalized).strip()
+    # Remove common leading "Answer: " or similar preambles (case-insensitive)
+    normalized = re.sub(r'^\s*(?:the answer is|result|output)\s*[:\-]?\s*', '', normalized, flags=re.IGNORECASE).strip()
+    # Add more normalization rules if needed (e.g., handling "forty two" vs "42")
+
+    return normalized
+
+def perform_self_consistency_voting(final_answers: List[str]) -> Tuple[Optional[str], Dict[str, int]]:
+    """
+    Performs simple majority voting on a list of final answers.
+    Filters out empty answers and normalizes them before voting.
+
+    Args:
+        final_answers (List[str]): A list of raw final answer strings from the wrapper.
+
+    Returns:
+        Tuple[Optional[str], Dict[str, int]]: A tuple containing:
+            - The winning (most common) normalized answer, or None if no valid answers.
+            - A dictionary mapping normalized answers to their vote counts.
+    """
+    if not final_answers:
+        return None, {}
+
+    # 1. Filter out empty or non-string answers
+    valid_answers = [ans for ans in final_answers if isinstance(ans, str) and ans.strip()]
+
+    if not valid_answers:
+        return None, {}
+
+    # 2. Normalize answers
+    normalized_answers = [normalize_answer(ans) for ans in valid_answers]
+    # Filter out answers that became empty after normalization
+    normalized_answers = [ans for ans in normalized_answers if ans.strip()]
+
+    if not normalized_answers:
+        return None, {}
+
+    # 3. Perform majority voting
+    answer_counts = Counter(normalized_answers)
+
+    # 4. Determine the consensus answer
+    # most_common(1) returns a list like [('answer', count)]
+    most_common_item = answer_counts.most_common(1)
+
+    if most_common_item:
+        consensus_answer = most_common_item[0][0]
+        return consensus_answer, dict(answer_counts)
+    else:
+        # This case should ideally not happen if normalized_answers is not empty
+        return None, dict(answer_counts)
+
+
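A quick worked example of the voting helpers added above (values traced by hand from `normalize_answer`'s rules):

    answers = ["The answer is 5.", "5", "five", "5!", ""]
    # normalize_answer("The answer is 5.") -> "5"  (preamble and trailing "." removed)
    # normalize_answer("5!")               -> "5"  (trailing "!" removed)
    # "" is filtered out before voting
    consensus, counts = perform_self_consistency_voting(answers)
    # consensus == "5", counts == {"5": 3, "five": 1}
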
 # --- Sidebar Configuration ---
 with st.sidebar:
     st.header("⚙️ Core Settings")
@@ -389,47 +500,64 @@ with st.sidebar:
     with st.expander("🧠 Model Configuration", expanded=True):
         model_name = st.text_input(
             "Hugging Face Model ID or Path",
-            "ayjays132/NeuroReasoner-1-NR-1",
-            help="Enter the model ID from huggingface.co or a local path."
+            "ayjays132/NeuroReasoner-1-NR-1", # Default model
+            help="Enter the model ID from huggingface.co or a local path. Changing this requires reloading."
         )

         # --- Dynamic Model Type Detection ---
-        detected_type = "Unknown (Enter Model ID)"
-        # Options match the strings used in the loading function
-        model_type_options = ["Auto", "Causal", "Seq2Seq"]
-        default_model_type_index = model_type_options.index("Auto")
-
         # Attempt to load config to detect type without caching (lightweight check)
+        # This provides immediate feedback on the likely model type.
+        detected_type_display = "Unknown (Enter Model ID)"
+        initial_config = None
         try:
             if model_name and model_name.strip(): # Only attempt if input is not empty
-                initial_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True)
-                is_encoder_decoder_initial = getattr(initial_config, "is_encoder_decoder", False)
-                detected_type = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
+                with st.spinner("Detecting model type..."): # Small spinner for detection
+                    # Use from_pretrained without loading weights (low_cpu_mem_usage=True helps)
+                    initial_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True)
+                is_encoder_decoder_initial = getattr(initial_config, "is_encoder_decoder", False)
+                detected_type_display = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
+                # Store the actual detected type for use in the selectbox default index
+                actual_detected_type_for_index = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
             else:
-                detected_type = "Unknown (Enter Model ID)"
+                detected_type_display = "Unknown (Enter Model ID)"
+                actual_detected_type_for_index = "Auto" # Default to Auto index if no model name
+
         except Exception:
-            detected_type = "Unknown (Config Load Error)" # Indicate config load itself failed
+            detected_type_display = "Unknown (Config Load Error)" # Indicate config load itself failed
+            actual_detected_type_for_index = "Auto" # Default to Auto index if config load fails
+
+
+        # Options match the strings used in the loading function
+        model_type_options = ["Auto", "Causal", "Seq2Seq"]
+        # Set the default index based on the detected type, falling back to "Auto"
+        try:
+            default_model_type_index = model_type_options.index(actual_detected_type_for_index) if actual_detected_type_for_index in model_type_options else 0 # Default to Auto (index 0)
+        except ValueError:
+            default_model_type_index = 0 # Should not happen if logic is correct, but safety fallback

         forced_model_type = st.selectbox(
             "Architecture Type",
             model_type_options,
-            index=default_model_type_index,
-            help=f"Detected: {detected_type}. 'Auto' uses the detected type. Select manually if detection is incorrect or overridden."
+            index=default_model_type_index, # Use the detected type as the default selection
+            help=f"Detected: {detected_type_display}. 'Auto' uses the detected type. Select manually if detection is incorrect or overridden."
         )

         # --- Device Selection ---
+        # List available devices, prioritizing CUDA if available
         available_devices = ["cpu"]
         if torch.cuda.is_available():
-            available_devices.insert(0, "cuda") # Put cuda first if available
+            available_devices.insert(0, "cuda") # Put cuda first if available

         device = st.selectbox(
             "Device",
             available_devices,
+            index=(0 if "cuda" in available_devices else 0), # Default to cuda if available, else cpu
             help="Select the hardware device for computation (GPU recommended)."
         )

         st.markdown("""
-        <small>💡 Changing model settings requires reloading the model.</small>
+        <small>💡 Changing model settings requires reloading the model.</small>
         """, unsafe_allow_html=True)

@@ -439,74 +567,71 @@ with st.sidebar:
     st.markdown("Define how the AI generates reasoning steps and answers.")

     with st.expander("Basic Parameters", expanded=True):
-        # Finalized 'Number of Reasoning Chains' parameter
+        # Number of Reasoning Chains slider
         num_chains = st.slider(
             "Number of Reasoning Chains",
             min_value=1,
-            max_value=15, # Kept the higher max for more robustness
-            value=5, # Kept the default of 5
-            help="How many independent reasoning chains to generate for analyzing the problem. More chains can improve Self-Consistency but take longer."
+            max_value=15, # Allow generating up to 15 chains
+            value=5, # Default to 5 chains for a good balance
+            help="How many independent reasoning chains to generate for analyzing the problem. More chains can improve Self-Consistency but take longer and use more memory."
         )

-        # Finalized 'No-repeat Ngram Size' parameter
-        no_repeat_ngram_size = st.slider( # Using the standard name for GenerationConfig
-            "No-repeat Ngram Size",
-            min_value=0,
-            max_value=10,
-            value=3,
-            help="Avoids generating repeating sequences of N tokens. Set to 0 to disable."
-        )
-
-        # Self-Consistency checkbox remains
-        self_consistency = st.checkbox(
+        # Self-Consistency checkbox
+        self_consistency_enabled_gui = st.checkbox(
             "Enable Self-Consistency Voting",
-            value=True,
-            help="When enabled, the system generates multiple chains and identifies the most common final answer as the consensus. Requires 'Number of Reasoning Chains' > 1."
+            value=True, # Default to enabled
+            help="When enabled, the system generates multiple chains and identifies the most common final answer as the consensus via majority voting. Requires 'Number of Reasoning Chains' > 1."
         )

         # Conditional warning if Self-Consistency is on but num_chains is 1
-        if self_consistency and num_chains <= 1:
-            st.warning("Self-Consistency is most effective with 2 or more chains.") # Slightly rephrased warning
+        if self_consistency_enabled_gui and num_chains <= 1:
+            st.warning("Self-Consistency voting is most effective with 2 or more chains.")


     # Advanced Parameters
     with st.expander("🧪 Advanced Sampling Parameters"):
+        # Using standard parameter names from Hugging Face GenerationConfig
         max_new_tokens = st.slider(
-            "Max Tokens per Chain",
-            50, 2048, 768,
-            help="Maximum number of new tokens to generate for *each* individual reasoning chain. Adjust based on complexity expected."
+            "Max New Tokens per Chain",
+            50, 2048, 768, # Min, Max, Default
+            help="Maximum number of new tokens to generate for *each* individual reasoning chain. Adjust based on expected reasoning complexity and answer length."
         )
         temperature = st.slider(
             "Temperature",
-            0.0, 2.0, 0.8,
-            help="Controls the randomness of sampling. 0.0 is deterministic (greedy). Higher values increase diversity."
+            0.0, 2.0, 0.8, # Min, Max, Default
+            step=0.05, # Allow finer control
+            help="Controls the randomness of sampling. 0.0 is deterministic (greedy). Higher values increase diversity in reasoning paths."
         )
         top_k = st.slider(
             "Top-k",
-            0, 100, 50,
+            0, 200, 50, # Min, Max, Default (increased max k)
             help="Filter to consider only the top_k most likely tokens at each step (0 disables). Used with sampling."
         )
         top_p = st.slider(
             "Top-p (Nucleus Sampling)",
-            0.0, 1.0, 0.95,
+            0.0, 1.0, 0.95, # Min, Max, Default
+            step=0.01, # Allow finer control
             help="Filter to consider tokens with cumulative probability below top_p (0.0 disables). Used with sampling."
         )
+        repetition_penalty = st.slider(
+            "Repetition Penalty",
+            1.0, 2.0, 1.1, # Min, Max, Default
+            step=0.05, # Allow finer control
+            help="Penalizes repeated tokens or sequences. Higher values reduce repetition in the output."
+        )
+        no_repeat_ngram_size = st.slider(
+            "No-repeat Ngram Size",
+            0, 10, 0, # Min, Max, Default (changed default to 0, often less needed with repetition penalty)
+            help="Avoids repeating sequences of N tokens. Set to 0 to disable. Can help prevent loops in reasoning."
+        )
         do_sample = st.checkbox(
             "Enable Sampling",
-            value=True,
-            help="If checked, uses probabilistic sampling (controlled by Temperature, Top-k, Top-p). If unchecked, uses greedy decoding."
+            value=True, # Default to enabled
+            help="If checked, uses probabilistic sampling (controlled by Temperature, Top-k, Top-p). If unchecked, uses greedy decoding (deterministic)."
         )
         if not do_sample:
             st.info("Sampling disabled. Temperature, Top-k, and Top-p will be ignored.")

-        no_repeat_ngram_size = st.slider(
-            "No-repeat Ngram Size",
-            0, 10, 3,
-            help="Avoids repeating sequences of N tokens. Set to 0 to disable."
-        )
-        # Optional: Add a seed for reproducibility if desired
-        # generation_seed = st.number_input("Generation Seed (Optional)", value=-1, help="Set a positive integer for reproducible generation.")
-

     st.markdown("---") # Visual separator
@@ -528,15 +653,15 @@ with input_container:
     with prompt_col:
         prompt = st.text_area(
             "📝 Enter your query or problem:",
-            height=150,
+            height=180, # Increased height for better input experience
             placeholder="Example: If a train travels at 60 mph and a car at 40 mph, starting at the same time from cities 300 miles apart, how long until they meet? Think step-by-step.",
-            key="user_prompt" # Added key for stability
+            key="user_prompt" # Unique key for the widget
         )

     with button_col:
-        # Add some vertical space to align the button nicely
-        st.markdown("<div style='height: 3.5rem;'></div>", unsafe_allow_html=True)
-        run_button = st.button("✨ Generate Reasoning", use_container_width=True, key="generate_button") # Added key
+        # Add some vertical space to align the button nicely with the text area
+        st.markdown("<div style='height: 3.25rem;'></div>", unsafe_allow_html=True) # Adjusted height
+        run_button = st.button("✨ Generate Reasoning", use_container_width=True, key="generate_button") # Unique key

 # Container for status updates and results
 results_container = st.container()
@@ -546,223 +671,227 @@ results_container = st.container()
 if run_button:
     if not prompt or not prompt.strip():
         results_container.warning("Please enter a prompt to begin generation.")
-        st.stop() # Stop execution until prompt is entered
-
-    # --- Prepare for Generation ---
-    # Load model and tokenizer (handles caching internally with st.cache_resource
-    # via safe_load_model_with_status which also reports status)
-    # This happens only when the button is clicked and parameters might have changed
-    model, tokenizer, loaded_model_type = safe_load_model_with_status(model_name, device, forced_model_type)
-
-    if model is None or tokenizer is None:
-        # Error was already shown by safe_load_model_with_status
-        st.error("Model or tokenizer failed to load. Please check settings and traceback above.")
-        st.stop() # Stop if loading failed
-
-
-    # --- Configure Generation ---
-    # Use a status box for ongoing generation process
-    with results_container:
-        st.markdown("---") # Separator before results
-        generation_status = st.status("Preparing generation config...", expanded=True)
-        update_telemetry() # Update telemetry while status is active
-
-
-    try:
-        # Build GenerationConfig based on sidebar parameters
-        # num_return_sequences should match num_chains for the wrapper to process them
-        gen_cfg = GenerationConfig(
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            do_sample=do_sample,
-            num_return_sequences=num_chains, # <--- Corrected: Use num_chains directly
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id,
-            return_dict_in_generate=True,
-            output_scores=False,
-            output_attentions=False,
-            output_hidden_states=False,
-            use_cache=True,
-        )
-        generation_status.write(f"Generation parameters set: {gen_cfg.to_dict()}")
-        update_telemetry()
-
-        cfg = GenerationConfig(
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            do_sample=True,
-            num_return_sequences=num_chains,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
-        )
-    except Exception as e:
-        generation_status.error(f"❌ Failed to create GenerationConfig: {e}")
-        st.exception(e)
-        st.stop()
-
-    # --- Instantiate Wrapper ---
-    try:
-        generation_status.write("Initializing Chain-of-Thought wrapper...")
-        # Pass the configured generation config to the wrapper
-        # The wrapper should internally use num_return_sequences from gen_cfg
-        cot_wrapper = ChainOfThoughtWrapper(
-            model=model,
-            tokenizer=tokenizer,
-            generation_config=cfg,
-            device=device,
-            self_consistency=self_consistency,
-            consistency_rounds=(num_chains if self_consistency else 1)
-        )
-        generation_status.write("Wrapper initialized.")
-        update_telemetry()
-
-    except Exception as e:
-        generation_status.error(f"❌ Failed to initialize CoT wrapper: {e}")
-        st.exception(e)
-        st.stop()
-
-    # --- Tokenize Input ---
-    try:
-        generation_status.write("Tokenizing input prompt...")
-        # Use model_max_length or a reasonable cap for input length
-        max_input_length = tokenizer.model_max_length
-        if max_input_length is None or max_input_length > 4096: # Cap input length if tokenizer reports None or very large
-            max_input_length = 4096
-            if tokenizer.model_max_length is None:
-                generation_status.warning(f"Tokenizer has no model_max_length, capping input to {max_input_length}.")
-
-        enc = tokenizer(
-            prompt,
-            return_tensors='pt',
-            padding='longest', # Pad to the longest sequence in the batch (batch size is 1 here)
-            truncation=True,
-            max_length=max_input_length, # Use a proper max length for the input
-        ).to(device)
-        generation_status.write(f"Input token length: {enc['input_ids'].shape[1]}")
-        update_telemetry()
-
-    except Exception as e:
-        generation_status.error(f"❌ Tokenization failed: {e}")
-        st.exception(e)
-        st.stop()
-
-    # --- Generate ---
-    generation_status.update(label=f"⏳ Generating {num_chains} reasoning chains...", state="running")
-    start_time = time.time()
+        # No need to st.stop() here, warning is sufficient
+    else:
+        # --- Prepare for Generation ---
+        # Load model and tokenizer (handles caching internally via safe_load_model_with_status)
+        # This happens only when the button is clicked and parameters might have changed
+        model, tokenizer, loaded_model_type = safe_load_model_with_status(model_name, device, forced_model_type)
+
+        if model is None or tokenizer is None:
+            # Error was already shown by safe_load_model_with_status
+            results_container.error("Model or tokenizer failed to load. Please check settings and traceback above.")
+            # No need to st.stop() here, error message is displayed
+
+        else: # Model and tokenizer loaded successfully
+            # --- Configure GenerationConfig ---
+            # Build the GenerationConfig object from sidebar parameters
+            # This config will be passed to the wrapper's __init__ as the base config
+            # and to the wrapper's generate() call as overrides.
+            # The wrapper's generate method will ultimately use these settings.
+            try:
+                # Base config for the wrapper's __init__ - defines default behavior
+                base_gen_config = GenerationConfig(
+                    max_new_tokens=max_new_tokens, # Use sidebar value
+                    temperature=temperature, # Use sidebar value
+                    top_k=top_k, # Use sidebar value
+                    top_p=top_p, # Use sidebar value
+                    do_sample=do_sample, # Use sidebar value
+                    repetition_penalty=repetition_penalty, # Use sidebar value
+                    no_repeat_ngram_size=no_repeat_ngram_size, # Use sidebar value
+                    eos_token_id=tokenizer.eos_token_id, # Always pass eos_token_id from tokenizer
+                    pad_token_id=tokenizer.pad_token_id, # Always pass pad_token_id from tokenizer
+                    # Other parameters like num_beams, etc., could be added here if exposed in sidebar
+                )
+
+            except Exception as e:
+                results_container.error(f"❌ Failed to create GenerationConfig from parameters: {e}")
+                st.exception(e)
+                st.stop() # Stop if config creation fails
+
+            # --- Instantiate Wrapper ---
+            # Use a status box for ongoing generation process
+            with results_container:
+                st.markdown("## ⏳ Generation Progress") # Use a clear header for the status section
+                generation_status = st.status("Initializing Chain-of-Thought wrapper...", expanded=True)
+                update_telemetry() # Update telemetry while status is active
+
+            try:
+                # Pass the loaded model, tokenizer, device, and relevant settings.
+                # The wrapper uses the generation_config passed here as its base defaults.
+                # Self-consistency settings from GUI are passed to wrapper's init attributes.
+                cot_wrapper = ChainOfThoughtWrapper(
+                    model=model,
+                    tokenizer=tokenizer,
+                    # Pass the configured generation params as the base config for the wrapper
+                    generation_config=base_gen_config,
+                    device=device,
+                    # Pass GUI self-consistency settings
+                    # The wrapper uses self_consistency_enabled to decide if it *should* generate >1 chains
+                    # when num_return_sequences is not explicitly passed to generate().
+                    # We *are* explicitly passing num_return_sequences to generate(),
+                    # so these init flags primarily inform internal wrapper behavior/logging.
+                    self_consistency_enabled=self_consistency_enabled_gui,
+                    consistency_rounds=num_chains # Inform the wrapper about intended rounds
+                    # Pass other wrapper-specific init args if needed (e.g., custom tags)
+                    # final_answer_tag="Final Answer:" # Example if different from default
+                )
+                generation_status.write("Wrapper initialized.")
+                update_telemetry()
+
+            except Exception as e:
+                generation_status.error(f"❌ Failed to initialize CoT wrapper: {e}")
+                st.exception(e)
+                # Clean up resources in case of failure
+                if torch.cuda.is_available(): torch.cuda.empty_cache()
+                gc.collect()
+                st.stop() # Stop if wrapper initialization fails
+
+            # --- Generate ---
+            # Call the wrapper's generate method.
+            # Pass the original input text.
+            # Explicitly pass num_return_sequences to request the desired number of chains.
+            generation_status.update(label=f"⏳ Generating {num_chains} reasoning chain(s)...", state="running")
+            start_time = time.time()

-    try:
-        # Call the wrapper's generate method
-        # It should handle the loop for multiple chains and self-consistency internally
-        outputs = cot_wrapper.generate(
-            input_ids=enc['input_ids'],
-            attention_mask=enc['attention_mask'],
-            # Pass any other necessary arguments to your wrapper's generate method
-        )
-        # Expected `outputs` dict structure: {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'consensus_answer': '...'}
-        # The wrapper should handle extracting steps/answers if needed.
-
-    except Exception as e:
-        generation_status.error(f"❌ Generation failed: {e}")
-        st.exception(e)
-        # Clean up resources after potential OOM or other errors
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        gc.collect() # Python garbage collection
-        st.stop()
-
-    elapsed_time = time.time() - start_time
-    generation_status.update(label=f"✨ Generation complete in {elapsed_time:.2f}s", state="complete")
-    update_telemetry() # Final telemetry update after successful generation
-
-    # --- Display Results ---
-    with results_container:
-        st.markdown("## 📚 Reasoning Output")
-
-        # Display Self-Consistency Consensus first if enabled and results are available
-        if self_consistency and outputs and 'consensus_answer' in outputs and outputs.get('final_answers'):
-            consensus = outputs.get('consensus_answer')
-            answers = outputs.get('final_answers', [])
-
-            st.markdown('<div class="consensus-answer">', unsafe_allow_html=True)
-            st.write("💡 **Consensus Answer (Self-Consistency):**")
-            st.write(consensus if consensus else "[Could not determine consensus]")
-            st.markdown('</div>', unsafe_allow_html=True)
-
-            if answers and len(answers) > 1: # Only show distribution if more than one answer was found
-                st.markdown("###### Answer Distribution:")
-                answer_counts = Counter(answers)
-                # Display sorted distribution
-                for ans, count in answer_counts.most_common():
-                    st.write(f"- '{ans}' ({count} {'vote' if count == 1 else 'votes'})")
-                st.markdown("---") # Separator
-
-        # Display individual chains
-        full_texts = outputs.get('full_texts', [])
-        reasoning_steps = outputs.get('reasoning_steps', [])
-        final_answers = outputs.get('final_answers', [])
-
-        if not full_texts:
-            st.warning("No reasoning chains were generated.")
-        else:
-            st.markdown(f"### Individual Chains ({len(full_texts)} generated)")
-            # Iterate and display each chain in an expander
-            # Ensure lists are iterable, even if empty
-            full_texts = full_texts if isinstance(full_texts, list) else []
-            reasoning_steps = reasoning_steps if isinstance(reasoning_steps, list) else []
-            final_answers = final_answers if isinstance(final_answers, list) else []
-
-            # Pad lists to the same length in case the wrapper returned inconsistent outputs
-            max_len_outputs = max(len(full_texts), len(reasoning_steps), len(final_answers))
-            full_texts.extend(["[N/A - Generation Failed for this chain]"] * (max_len_outputs - len(full_texts)))
-            reasoning_steps.extend([[]] * (max_len_outputs - len(reasoning_steps)))
-            final_answers.extend(["[N/A]"] * (max_len_outputs - len(final_answers)))
-
-            for idx, (text, steps, ans) in enumerate(zip(full_texts, reasoning_steps, final_answers), 1):
-                # Use try-except just in case a single chain output is malformed
             try:
-                # Expander for each chain, starting collapsed
-                with st.expander(f"Chain {idx}", expanded=False):
-                    # Use custom class for styling the label
-                    st.markdown('<div class="output-label">Full Generated Text:</div>', unsafe_allow_html=True)
-                    # Use custom class for styling the text area background
-                    st.text_area(f"chain_text_area_{idx}", text, height=250, label_visibility="collapsed", help="The complete generated output for this chain.")
-
-                    if steps and isinstance(steps, list):
-                        st.markdown('<div class="output-label">Reasoning Steps:</div>', unsafe_allow_html=True)
-                        # Display steps as a list
-                        if steps:
-                            for i, step in enumerate(steps, 1):
-                                if isinstance(step, str) and step.strip():
-                                    st.write(f"**Step {i}:** {step.strip()}")
-                                elif not isinstance(step, str):
-                                    st.warning(f"Step {i} has invalid format.")
-                        else:
-                            st.info("No specific steps were extracted for this chain.")
-
-                    st.markdown('<div class="output-label">Final Answer:</div>', unsafe_allow_html=True)
-                    st.write(f"**{ans if ans else '[No answer extracted]'}**")
-
-                    # Optional: Add a separator between chain sections
-                    st.markdown("---", help="End of Chain details.")
-
-            except Exception as chain_e:
-                st.error(f"Error displaying Chain {idx}: {chain_e}")
-                st.exception(chain_e)
-
-        st.markdown("---") # Final separator
-        st.info("Generation process concluded. Review the chains above.")
-
-        # Clean up GPU memory after generation is complete and results are displayed
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect() # Python garbage collection
+                # The wrapper's generate takes input_text and optional config/num_return_sequences overrides.
+                # We rely on the wrapper's internal config logic and pass the desired number of sequences.
+                outputs = cot_wrapper.generate(
+                    input_text=prompt,
+                    # Optional: Pass a GenerationConfig override for this specific call if needed, e.g.:
+                    # generation_config=GenerationConfig(temperature=temperature + 0.1),
+                    # Pass the requested number of chains directly to the generate method:
+                    num_return_sequences=num_chains,
+                )
+                # Expected `outputs` dict structure: {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'generation_scores': [...]}
+
+            except Exception as e:
+                generation_status.error(f"❌ Generation failed: {e}")
+                st.exception(e) # Display the full exception traceback
+                # Clean up resources after potential OOM or other errors
+                if torch.cuda.is_available(): torch.cuda.empty_cache()
+                gc.collect() # Python garbage collection
+                # No need to st.stop() here, error message is displayed
+                outputs = None # Ensure outputs is None if generation fails
+
+            elapsed_time = time.time() - start_time
+
+            # Update final status based on success or failure
+            if outputs is not None:
+                generation_status.update(label=f"✨ Generation complete in {elapsed_time:.2f}s", state="complete")
+            else:
+                generation_status.update(label=f"❌ Generation failed after {elapsed_time:.2f}s", state="error")
+
+            update_telemetry() # Final telemetry update after generation attempt
+
+            # --- Process and Display Results ---
+            if outputs is not None: # Only display if generation was attempted and returned results
+                with results_container:
+                    st.markdown("## 📚 Reasoning Output")
+
+                    # Display Self-Consistency Consensus first if enabled and results are available
+                    # Implement the voting logic here in the GUI using the wrapper's output
+                    final_answers_list = outputs.get('final_answers', [])
+
+                    if self_consistency_enabled_gui and final_answers_list:
+                        # Perform the actual voting
+                        consensus_answer, answer_distribution_dict = perform_self_consistency_voting(final_answers_list)
+                        # Convert dict to Counter for sorting by count in display
+                        answer_distribution = Counter(answer_distribution_dict)
+
+                        st.markdown('<div class="consensus-answer">', unsafe_allow_html=True)
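Taken together, the new run-button flow reduces to this outline (a sketch with error handling elided, assuming the wrapper API described in the comments at the top of the diff):

    model, tokenizer, _ = safe_load_model_with_status(model_name, device, forced_model_type)
    if model is not None and tokenizer is not None:
        wrapper = ChainOfThoughtWrapper(model=model, tokenizer=tokenizer,
                                        generation_config=base_gen_config, device=device,
                                        self_consistency_enabled=True,
                                        consistency_rounds=num_chains)
        outputs = wrapper.generate(input_text=prompt, num_return_sequences=num_chains)
        consensus, counts = perform_self_consistency_voting(outputs["final_answers"])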
5
  A premium Streamlit app for step-by-step reasoning
6
  across any Hugging Face model (causal or seq2seq).
7
  Featuring a dark theme, model-type detection, self-consistency
8
+ sampling & voting, robust handling, and GPU telemetry.
9
  """
10
  import os
11
  import time
12
+ import re # Needed for answer normalization
13
  import streamlit as st
14
  import torch
15
  import pynvml # For GPU telemetry
16
+ import numpy as np # Imported, but currently unused in core logic
17
  from transformers import (
18
  AutoConfig,
19
  AutoTokenizer,
 
24
  )
25
  from collections import Counter # For self-consistency voting
26
  import gc # Import garbage collector
27
+ from typing import Any, Dict, List, Optional, Tuple, Union
28
+
29
+ # --- Import the Enhanced ChainOfThoughtWrapper ---
30
+ # Assuming chain_of_thought_wrapper.py is in the same directory.
31
+ # The wrapper is expected to:
32
+ # 1. Accept model, tokenizer, GenerationConfig (base), device, etc. in __init__.
33
+ # 2. Have a .generate() method that takes input_text (str), GenerationConfig (overrides),
34
+ # and crucially, num_return_sequences (int) to generate multiple chains efficiently.
35
+ # 3. Return a dictionary with keys 'full_texts', 'reasoning_steps', 'final_answers' (lists).
36
  try:
37
  from chain_of_thought_wrapper import ChainOfThoughtWrapper
38
  except ImportError:
39
+ st.error("Error: chain_of_thought_wrapper.py not found. Please ensure the enhanced wrapper script is in the same directory.")
40
+ st.stop() # Halt execution if the wrapper is not found
41
 
42
+ # --- Logging Setup for GUI ---
43
+ # Use Streamlit's built-in logging or configure a separate logger
44
+ # For this example, we'll keep it simple and rely mostly on st.status and st.exception
45
+ # if needed, a more detailed logger could be configured here.
46
 
47
  # --- Page Configuration ---
48
  st.set_page_config(
49
  page_title="🧠 NeuroReasoner CoT GUI",
50
  page_icon="🧠",
51
+ layout="wide", # Use wide layout
52
+ initial_sidebar_state="expanded", # Sidebar open by default
53
  menu_items={
54
+ 'Get Help': 'https://github.com/ayjays132/NeuroReasoner', # Example repo link
55
+ 'Report a bug': "https://github.com/ayjays132/NeuroReasoner/issues", # Example repo issues link
56
  'About': """
57
  **NeuroReasoner Chain-of-Thought GUI**
58
+ An open-source interface powered by Hugging Face models and the enhanced NeuroReasoner wrapper.
59
  Explore step-by-step reasoning with various language models.
60
+ \n\n**Features:** Dark Theme, GPU Telemetry, Model Caching, Self-Consistency Voting,
61
+ Robust Generation Parameters, Support for Causal and Seq2Seq models.
62
  """
63
  }
64
  )
65
 
66
  # --- Dark Theme CSS ---
67
+ # Comprehensive CSS for a professional dark theme inspired by VS Code.
68
  st.markdown("""
69
  <style>
70
  /* Overall Page Background & Text (Dark Theme) */
71
+ /* Target the main container and the root app div */
72
+ .stApp {
73
  background-color: #1E1E1E; /* Dark grey background */
74
  color: #D4D4D4; /* Light grey text */
75
  font-family: 'Segoe UI', Roboto, Arial, sans-serif;
76
  }
 
 
 
 
77
 
78
  /* Sidebar Styling */
79
  .stSidebar {
80
  background-color: #2D2D2D; /* Slightly lighter dark grey for sidebar */
81
  padding: 2rem 1rem;
82
  border-right: 1px solid #3E3E3E; /* Subtle border */
83
+ color: #D4D4D4; /* Ensure text in sidebar is light */
84
  }
85
  .stSidebar h1, .stSidebar h2, .stSidebar h3 {
86
+ color: #569CD6 !important; /* Visual Studio Code blue for sidebar headers */
87
  }
88
  .stSidebar label {
89
  color: #D4D4D4 !important; /* Ensure sidebar labels are visible */
 
91
 
92
 
93
  /* Main Content Area */
94
+ /* No specific background needed here, .stApp covers it */
 
 
95
 
96
  /* Titles and Headers */
97
  h1, h2, h3, h4, h5, h6 {
 
114
  font-weight: bold;
115
  transition: background-color 0.2s ease, transform 0.1s ease;
116
  box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
117
+ margin-top: 1.65rem; /* Add top margin to align with text area */
118
  }
119
  .stButton>button:hover {
120
  background-color: #27633A; /* Lighter green on hover */
 
128
 
129
 
130
  /* Text areas and inputs */
131
+ /* Target specific classes used by Streamlit for input/text areas */
132
+ div[data-baseweb="textarea"] textarea,
133
+ div[data-baseweb="input"] input {
134
  border: 1px solid #3E3E3E; /* Dark border */
135
  border-radius: 0.4rem;
136
  padding: 0.75rem;
 
139
  color: #D4D4D4; /* Light text */
140
  box-shadow: inset 1px 1px 3px rgba(0, 0, 0, 0.2);
141
  }
142
+ div[data-baseweb="textarea"] label,
143
+ div[data-baseweb="input"] label,
144
+ .stSlider label, .stSelectbox label, .stCheckbox label {
145
  font-weight: bold;
146
  color: #9CDCFE !important; /* Light blue labels */
147
  margin-bottom: 0.5rem;
148
  display: block;
149
  }
150
+ /* Streamlit status box styling - Target the main container and its contents */
151
+ .st-emotion-cache-vj1l9j { /* This class might change with Streamlit versions */
152
  background-color: #2D2D2D; /* Match sidebar background */
153
  border: 1px solid #3E3E3E;
154
  border-radius: 0.5rem;
155
  padding: 1rem;
156
  margin-bottom: 1rem;
157
  }
158
+ .st-emotion-cache-vj1l9j .stMarkdown p,
159
+ .st-emotion-cache-vj1l9j .stAlert { /* Style text and alerts inside status */
160
+ color: #D4D4D4 !important;
161
+ background-color: transparent !important; /* Don't want alert backgrounds inside status */
162
+ border: none !important; /* No borders for alerts inside status */
163
+ padding: 0.5rem 0 !important; /* Adjust padding */
164
  }
165
 
166
 
 
171
  padding: 1rem;
172
  font-size: 1rem;
173
  border-left: 5px solid transparent; /* Base style */
174
+ color: #D4D4D4; /* Default text color for alerts */
175
  }
176
+ .stAlert.stAlert-info { border-left-color: #569CD6; background-color: #2A3E52; } /* Dark blue info */
177
+ .stAlert.stAlert-success { border-left-color: #4EC9B0; background-color: #28403A; } /* Dark teal success */
178
+ .stAlert.stAlert-warning { border-left-color: #DCDCAA; background-color: #454032; } /* Dark yellow warning */
179
+ .stAlert.stAlert-error { border-left-color: #F44747; background-color: #4A3030; } /* Dark red error */
180
 
181
 
182
  /* Expander styling */
 
205
  margin-top: 0;
206
  color: #D4D4D4;
207
  }
208
+ .streamlit-expanderContent .stMarkdown p {
209
+ color: #D4D4D4 !important; /* Ensure text inside expanders is light */
210
+ }
211
+
212
 
213
  /* Labels for the output text areas */
214
  .output-label {
 
221
  }
222
 
223
  /* Custom class for output text areas to differentiate from input */
224
+ /* Need to target the specific Streamlit internal class for the text area */
225
+ .output-text-area div[data-baseweb="textarea"] textarea {
226
  background-color: #1E1E1E; /* Even darker background for outputs */
227
  border: 1px solid #3E3E3E;
228
  border-radius: 0.4rem;
 
256
  font-weight: bold;
257
  }
258
  .consensus-answer strong {
259
+ color: #4EC9B0 !important; /* Teal for "Consensus Answer" label */
 
 
 
260
  }
261
+ .consensus-answer p {
262
+ color: #D4D4D4 !important; /* Ensure the answer text is light */
263
+ margin: 0 !important; /* Remove default paragraph margins */
264
+ padding: 0 !important; /* Remove default paragraph padding */
265
+ }
266
 
267
 
268
  </style>
 
270
 
271
 
272
  # --- GPU Telemetry Setup ---
273
+ # Initialize NVML for GPU monitoring if available
274
  try:
275
  pynvml.nvmlInit()
276
  GPU_AVAILABLE = True
277
+ # Get the number of devices to pick the first one (index 0)
278
+ GPU_COUNT = pynvml.nvmlDeviceGetCount()
279
+ if GPU_COUNT == 0:
280
+ GPU_AVAILABLE = False
281
+ st.warning("NVML initialized but no NVIDIA GPUs found.")
282
  except Exception:
283
  GPU_AVAILABLE = False
284
+ # st.info("NVIDIA Management Library (pynvml) not found or failed to initialize. GPU telemetry disabled.")
285
 
286
  # Use st.empty to hold the telemetry status text, defined *outside* cached functions
287
  telemetry_placeholder = st.empty()
288
 
289
  def update_telemetry():
290
  """Updates the telemetry display in the dedicated placeholder."""
291
+ telemetry_text = "📊 System Status: [Initializing...]"
292
  if not GPU_AVAILABLE or not torch.cuda.is_available():
293
  telemetry_text = "📊 System Status: [No GPU Available]"
294
  else:
295
  try:
296
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Use the first GPU
297
+ utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
298
+ memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
299
+ mem_used_mb = memory.used // 1024**2
300
+ mem_total_mb = memory.total // 1024**2
301
+ telemetry_text = f"📊 System Status: GPU 0: {utilization.gpu}% | Mem {mem_used_mb}/{mem_total_mb} MB"
302
  except Exception:
303
+ # If NVML fails after initialization, report error
304
  telemetry_text = "📊 System Status: [Telemetry Error]"
305
 
306
+ # Use markdown with a custom class for styling the container
307
  telemetry_placeholder.markdown(f'<div class="telemetry-box">{telemetry_text}</div>', unsafe_allow_html=True)
308
 
 
309
  # Initial telemetry update when the script starts
310
  update_telemetry()
311
 
 
313
  # --- Caching Model Loading (Core Logic Only) ---
314
  # Use st.cache_resource for heavy objects like models and tokenizers.
315
  # This function MUST NOT call Streamlit elements that affect the layout
316
+ # or state outside of its own scope (except for the final return value).
317
+ @st.cache_resource(show_spinner=False) # Spinner handled manually in safe_load_model_with_status
318
  def _load_model_and_tokenizer_cached(model_name: str, device: str, forced_model_type: str = None):
319
  """
320
+ Loads the model and tokenizer. This function is cached by Streamlit.
321
+ It should perform resource-intensive loading only.
322
  """
323
+ # Use low_cpu_mem_usage=True to reduce RAM usage during loading
324
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True)
325
  is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
326
  detected_type = "Seq2Seq" if is_encoder_decoder else "Causal"
327
 
328
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
329
+
330
+ # Ensure padding token is set for generation robustness, especially for batching (num_return_sequences)
331
+ # This mirrors the logic in the wrapper's __init__ but is good to do here too
332
+ # before the model is potentially loaded with a different vocab size.
333
  if tokenizer.pad_token is None:
334
  if tokenizer.eos_token is not None:
335
  tokenizer.pad_token = tokenizer.eos_token
336
+ # Set the ID explicitly as well
337
+ tokenizer.pad_token_id = tokenizer.eos_token_id
338
+ # logger.warning(f"Tokenizer pad_token is None, using eos_token '{tokenizer.eos_token}' as pad_token.")
339
  else:
340
+ # Fallback: Add a new pad token if neither eos nor pad exists.
341
+ # The wrapper's init will handle resizing embeddings if possible.
342
+ # logger.warning("Tokenizer has no pad_token and no eos_token. Adding a new [PAD] token.")
343
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
344
+ # Need to get the ID for the new token
345
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
346
+ # logger.info(f"Added new [PAD] token with ID {tokenizer.pad_token_id}.")
347
+
 
348
 
349
  # Determine the model class based on detection or forced selection
350
  actual_model_type = forced_model_type if forced_model_type != "Auto" else detected_type
351
+ model = None # Initialize model to None
352
 
353
+ try:
354
+ if actual_model_type == "Seq2Seq":
355
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config, trust_remote_code=True)
356
+ elif actual_model_type == "Causal":
357
+ model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
358
+ else:
359
+ raise ValueError(f"Unsupported model type selected: {actual_model_type}. Please select 'Auto', 'Causal', or 'Seq2Seq'.")
360
+
361
+ # Move model to device and set to eval mode
362
+ model.to(device)
363
+ model.eval() # Crucial for consistent inference behavior and disabling dropout etc.
364
 
365
+ # Ensure return_dict_in_generate is True for structured outputs (needed by wrapper for scores/sequences)
366
+ if not getattr(model.config, 'return_dict_in_generate', False):
367
+ model.config.return_dict_in_generate = True
368
+ # Request output_scores by default for potential future CISC use in the GUI voter
369
+ if not getattr(model.config, 'output_scores', False):
370
+ model.config.output_scores = True
371
 
372
+ # The wrapper's __init__ will perform its own pad token handling and embedding resizing check.
373
+ # We ensure the tokenizer passed to it has a pad_token_id.
374
+
375
+ except Exception as e:
376
+ # Clean up resources if model loading failed
377
+ if model is not None: del model
378
+ if tokenizer is not None: del tokenizer
379
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
380
+ gc.collect()
381
+ raise e # Re-raise the exception for the caller to handle status updates
382
 
383
  return model, tokenizer, actual_model_type
384
 
385
  # --- Wrapper function to handle status reporting for cached loading ---
386
  def safe_load_model_with_status(model_name: str, device: str, forced_model_type: str = None):
387
  """
388
+ Calls the cached loading function (_load_model_and_tokenizer_cached)
389
+ and handles Streamlit status updates and error reporting.
390
  """
391
+ # Use st.status here, defined outside the cached function, for live updates
392
  status_text = f"🌐 Loading model '{model_name}' on device '{device}'..."
 
393
  with st.status(status_text, expanded=True) as status_box:
394
  status_box.write("Checking system status...")
395
  update_telemetry() # Update the separate telemetry box
 
403
  forced_model_type=forced_model_type
404
  )
405
 
406
+ # Report padding token status after loading
407
  if tokenizer and tokenizer.pad_token_id is None:
408
+ status_box.warning(f"Tokenizer has no pad_token_id. Batch generation (Self-Consistency) might be unstable.")
409
  elif tokenizer:
410
  status_box.write(f"Tokenizer pad_token_id set to {tokenizer.pad_token_id}.")
411
 
 
417
  except Exception as e:
418
  status_box.error(f"❌ Model loading failed.")
419
  update_telemetry() # Final telemetry update after error
420
+ st.exception(e) # Display the full exception traceback within the status box
421
+ # No need for manual cleanup here, as the exception in the cached function
422
+ # should have triggered cleanup within that function, and Streamlit's
423
+ # cache resource management handles state on failure.
 
 
 
 
 
 
 
 
424
  return None, None, None # Return None on failure
425
 
426
 
427
+ # --- Self-Consistency Voting Logic ---
428
+ def normalize_answer(answer: str) -> str:
429
+ """
430
+ Normalizes a string answer for robust comparison during voting.
431
+ - Converts to lowercase.
432
+ - Strips leading/trailing whitespace.
433
+ - Removes common punctuation.
434
+ - Can be extended with more sophisticated normalization (e.g., number words to digits).
435
+ """
436
+ if not isinstance(answer, str):
437
+ return "" # Handle non-string inputs
438
+
439
+ # Simple normalization: lowercase, strip whitespace, remove common punctuation
440
+ normalized = answer.lower().strip()
441
+ # Remove common trailing characters like periods, commas, etc.
442
+ normalized = re.sub(r'[.,!?;:]+$', '', normalized).strip()
443
+ # Remove common leading "Answer: " or similar preambles (case-insensitive)
444
+ normalized = re.sub(r'^\s*(?:the answer is|result|output)\s*[:\-]?\s*', '', normalized, flags=re.IGNORECASE).strip()
445
+ # Add more normalization rules if needed (e.g., handling "forty two" vs "42")
446
+
447
+ return normalized
448
+
449
+ def perform_self_consistency_voting(final_answers: List[str]) -> Tuple[Optional[str], Dict[str, int]]:
450
+ """
451
+ Performs simple majority voting on a list of final answers.
452
+ Filters out empty answers and normalizes them before voting.
453
+
454
+ Args:
455
+ final_answers (List[str]): A list of raw final answer strings from the wrapper.
456
+
457
+ Returns:
458
+ Tuple[Optional[str], Dict[str, int]]: A tuple containing:
459
+ - The winning (most common) normalized answer, or None if no valid answers.
460
+ - A dictionary mapping normalized answers to their vote counts.
461
+ """
462
+ if not final_answers:
463
+ return None, {}
464
+
465
+ # 1. Filter out empty or non-string answers
466
+ valid_answers = [ans for ans in final_answers if isinstance(ans, str) and ans.strip()]
467
+
468
+ if not valid_answers:
469
+ return None, {}
470
+
471
+ # 2. Normalize answers
472
+ normalized_answers = [normalize_answer(ans) for ans in valid_answers]
473
+ # Filter out answers that became empty after normalization
474
+ normalized_answers = [ans for ans in normalized_answers if ans.strip()]
475
+
476
+ if not normalized_answers:
477
+ return None, {}
478
+
479
+
480
+ # 3. Perform majority voting
481
+ answer_counts = Counter(normalized_answers)
482
+
483
+ # 4. Determine the consensus answer
484
+ # most_common(1) returns a list like [('answer', count)]
485
+ most_common_item = answer_counts.most_common(1)
486
+
487
+ if most_common_item:
488
+ consensus_answer = most_common_item[0][0]
489
+ return consensus_answer, dict(answer_counts)
490
+ else:
491
+ # This case should ideally not happen if normalized_answers is not empty
492
+ return None, dict(answer_counts)
493
+
494
+
495
  # --- Sidebar Configuration ---
496
  with st.sidebar:
497
  st.header("⚙️ Core Settings")
 
500
  with st.expander("🧠 Model Configuration", expanded=True):
501
  model_name = st.text_input(
502
  "Hugging Face Model ID or Path",
503
+ "ayjays132/NeuroReasoner-1-NR-1", # Default model
504
+ help="Enter the model ID from huggingface.co or a local path. Changing this requires reloading."
505
  )
506
 
507
  # --- Dynamic Model Type Detection ---
 
 
 
 
 
508
  # Attempt to load config to detect type without caching (lightweight check)
509
+ # This provides immediate feedback on the likely model type.
510
+ detected_type_display = "Unknown (Enter Model ID)"
511
+ initial_config = None
512
  try:
513
  if model_name and model_name.strip(): # Only attempt if input is not empty
514
+ with st.spinner("Detecting model type..."): # Small spinner for detection
515
+ # Use from_pretrained without loading weights (low_cpu_mem_usage=True helps)
516
+ initial_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True)
517
+ is_encoder_decoder_initial = getattr(initial_config, "is_encoder_decoder", False)
518
+ detected_type_display = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
519
+ # Store the actual detected type for use in the selectbox default index
520
+ actual_detected_type_for_index = "Seq2Seq" if is_encoder_decoder_initial else "Causal"
521
  else:
522
+ detected_type_display = "Unknown (Enter Model ID)"
523
+ actual_detected_type_for_index = "Auto" # Default to Auto index if no model name
524
+
525
  except Exception:
526
+ detected_type_display = "Unknown (Config Load Error)" # Indicate config load itself failed
527
+ actual_detected_type_for_index = "Auto" # Default to Auto index if config load fails
528
+
529
+
530
+ # Options match the strings used in the loading function
531
+ model_type_options = ["Auto", "Causal", "Seq2Seq"]
532
+ # Set the default index based on the detected type, falling back to "Auto"
533
+ try:
534
+ default_model_type_index = model_type_options.index(actual_detected_type_for_index) if actual_detected_type_for_index in model_type_options else 0 # Default to Auto (index 0)
535
+ except ValueError:
536
+ default_model_type_index = 0 # Should not happen if logic is correct, but safety fallback
537
+
538
 
539
  forced_model_type = st.selectbox(
540
  "Architecture Type",
541
  model_type_options,
542
+ index=default_model_type_index, # Use the detected type as the default selection
543
+ help=f"Detected: {detected_type_display}. 'Auto' uses the detected type. Select manually if detection is incorrect or overridden."
544
  )
545
 
546
  # --- Device Selection ---
547
+ # List available devices, prioritizing CUDA if available
548
  available_devices = ["cpu"]
549
  if torch.cuda.is_available():
550
+ available_devices.insert(0, "cuda") # Put cuda first if available
551
 
552
  device = st.selectbox(
553
  "Device",
554
  available_devices,
555
+ index=(0 if "cuda" in available_devices else 0), # Default to cuda if available, else cpu
556
  help="Select the hardware device for computation (GPU recommended)."
557
  )
558
 
559
  st.markdown("""
560
+ <small>💡 Changing model settings requires reloading the model.</small>
561
  """, unsafe_allow_html=True)
562
 
563
 
 
567
  st.markdown("Define how the AI generates reasoning steps and answers.")
568
 
569
  with st.expander("Basic Parameters", expanded=True):
570
+ # Number of Reasoning Chains slider
571
  num_chains = st.slider(
572
  "Number of Reasoning Chains",
573
  min_value=1,
574
+ max_value=15, # Allow generating up to 15 chains
575
+ value=5, # Default to 5 chains for a good balance
576
+ help="How many independent reasoning chains to generate for analyzing the problem. More chains can improve Self-Consistency but take longer and use more memory."
577
  )
578
 
579
+ # Self-Consistency checkbox
580
+ self_consistency_enabled_gui = st.checkbox(
 
 
 
 
 
 
 
 
 
581
  "Enable Self-Consistency Voting",
582
+ value=True, # Default to enabled
583
+ help="When enabled, the system generates multiple chains and identifies the most common final answer as the consensus via majority voting. Requires 'Number of Reasoning Chains' > 1."
584
  )
585
 
586
  # Conditional warning if Self-Consistency is on but num_chains is 1
587
+ if self_consistency_enabled_gui and num_chains <= 1:
588
+ st.warning("Self-Consistency voting is most effective with 2 or more chains.")
589
 
590
 
591
  # Advanced Parameters
592
  with st.expander("🧪 Advanced Sampling Parameters"):
593
+ # Using standard parameter names from Hugging Face GenerationConfig
594
  max_new_tokens = st.slider(
595
+ "Max New Tokens per Chain",
596
+ 50, 2048, 768, # Min, Max, Default
597
+ help="Maximum number of new tokens to generate for *each* individual reasoning chain. Adjust based on expected reasoning complexity and answer length."
598
  )
599
  temperature = st.slider(
600
  "Temperature",
601
+ 0.0, 2.0, 0.8, # Min, Max, Default
602
+ step=0.05, # Allow finer control
603
+ help="Controls the randomness of sampling. 0.0 is deterministic (greedy). Higher values increase diversity in reasoning paths."
604
  )
605
  top_k = st.slider(
606
  "Top-k",
607
+ 0, 200, 50, # Min, Max, Default (increased max k)
608
  help="Filter to consider only the top_k most likely tokens at each step (0 disables). Used with sampling."
609
  )
610
  top_p = st.slider(
611
  "Top-p (Nucleus Sampling)",
612
+ 0.0, 1.0, 0.95, # Min, Max, Default
613
+ step=0.01, # Allow finer control
614
  help="Filter to consider tokens with cumulative probability below top_p (0.0 disables). Used with sampling."
615
  )
616
+ repetition_penalty = st.slider(
617
+ "Repetition Penalty",
618
+ 1.0, 2.0, 1.1, # Min, Max, Default
619
+ step=0.05, # Allow finer control
620
+ help="Penalizes repeated tokens or sequences. Higher values reduce repetition in the output."
621
+ )
622
+ no_repeat_ngram_size = st.slider(
623
+ "No-repeat Ngram Size",
624
+ 0, 10, 0, # Min, Max, Default (changed default to 0, often less needed with repetition penalty)
625
+ help="Avoids repeating sequences of N tokens. Set to 0 to disable. Can help prevent loops in reasoning."
626
+ )
627
  do_sample = st.checkbox(
628
  "Enable Sampling",
629
+ value=True, # Default to enabled
630
+ help="If checked, uses probabilistic sampling (controlled by Temperature, Top-k, Top-p). If unchecked, uses greedy decoding (deterministic)."
631
  )
632
  if not do_sample:
633
  st.info("Sampling disabled. Temperature, Top-k, and Top-p will be ignored.")
634
 
 
 
 
 
 
 
 
 
635
 
636
  st.markdown("---") # Visual separator
637
 
 
653
  with prompt_col:
654
  prompt = st.text_area(
655
  "📝 Enter your query or problem:",
656
+ height=180, # Increased height for better input experience
657
  placeholder="Example: If a train travels at 60 mph and a car at 40 mph, starting at the same time from cities 300 miles apart, how long until they meet? Think step-by-step.",
658
+ key="user_prompt" # Unique key for the widget
659
  )
660
 
661
  with button_col:
662
+ # Add some vertical space to align the button nicely with the text area
663
+ st.markdown("<div style='height: 3.25rem;'></div>", unsafe_allow_html=True) # Adjusted height
664
+ run_button = st.button("✨ Generate Reasoning", use_container_width=True, key="generate_button") # Unique key
665
 
666
  # Container for status updates and results
667
  results_container = st.container()
 
671
  if run_button:
672
  if not prompt or not prompt.strip():
673
  results_container.warning("Please enter a prompt to begin generation.")
674
+ # No need to st.stop() here, warning is sufficient
675
+ else:
676
+ # --- Prepare for Generation ---
677
+ # Load model and tokenizer (handles caching internally via safe_load_model_with_status)
678
+ # This happens only when the button is clicked and parameters might have changed
679
+ model, tokenizer, loaded_model_type = safe_load_model_with_status(model_name, device, forced_model_type)
680
+
681
+ if model is None or tokenizer is None:
682
+ # Error was already shown by safe_load_model_with_status
683
+ results_container.error("Model or tokenizer failed to load. Please check settings and traceback above.")
684
+ # No need to st.stop() here, error message is displayed
685
+
686
+ else: # Model and tokenizer loaded successfully
687
+ # --- Configure GenerationConfig ---
688
+ # Build the GenerationConfig object from sidebar parameters
689
+ # This config will be passed to the wrapper's __init__ as the base config
690
+ # and to the wrapper's generate() call as overrides.
691
+ # The wrapper's generate method will ultimately use these settings.
692
+ try:
693
+ # Base config for the wrapper's __init__ - defines default behavior
694
+ base_gen_config = GenerationConfig(
695
+ max_new_tokens=max_new_tokens, # Use sidebar value
696
+ temperature=temperature, # Use sidebar value
697
+ top_k=top_k, # Use sidebar value
698
+ top_p=top_p, # Use sidebar value
699
+ do_sample=do_sample, # Use sidebar value
700
+ repetition_penalty=repetition_penalty, # Use sidebar value
701
+ no_repeat_ngram_size=no_repeat_ngram_size, # Use sidebar value
702
+ eos_token_id=tokenizer.eos_token_id, # Always pass eos_token_id from tokenizer
703
+ pad_token_id=tokenizer.pad_token_id, # Always pass pad_token_id from tokenizer
704
+ # Other parameters like num_beams, etc., could be added here if exposed in sidebar
705
+ )
706
+
707
+
708
+ except Exception as e:
709
+ results_container.error(f"❌ Failed to create GenerationConfig from parameters: {e}")
710
+ st.exception(e)
711
+ st.stop() # Stop if config creation fails
712
+
713
+ # --- Instantiate Wrapper ---
714
+ # Use a status box for ongoing generation process
715
+ with results_container:
716
+ st.markdown("## ⏳ Generation Progress") # Use a clear header for the status section
717
+ generation_status = st.status("Initializing Chain-of-Thought wrapper...", expanded=True)
718
+ update_telemetry() # Update telemetry while status is active
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ try:
+ # Pass the loaded model, tokenizer, device, and relevant settings.
+ # The wrapper uses the generation_config passed here as its base defaults.
+ # Self-consistency settings from GUI are passed to wrapper's init attributes.
+ cot_wrapper = ChainOfThoughtWrapper(
+ model=model,
+ tokenizer=tokenizer,
+ # Pass the configured generation params as the base config for the wrapper
+ generation_config=base_gen_config,
+ device=device,
+ # Pass GUI self-consistency settings
+ # The wrapper uses self_consistency_enabled to decide if it *should* generate >1 chains
+ # when num_return_sequences is not explicitly passed to generate().
+ # We *are* explicitly passing num_return_sequences to generate(),
+ # so these init flags primarily inform internal wrapper behavior/logging.
+ self_consistency_enabled=self_consistency_enabled_gui,
+ consistency_rounds=num_chains # Inform the wrapper about intended rounds
+ # Pass other wrapper-specific init args if needed (e.g., custom tags)
+ # final_answer_tag="Final Answer:" # Example if different from default
+ )
+ generation_status.write("Wrapper initialized.")
+ update_telemetry()
+
+ except Exception as e:
+ generation_status.error(f"❌ Failed to initialize CoT wrapper: {e}")
+ st.exception(e)
+ # Clean up resources in case of failure
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
+ gc.collect()
+ st.stop() # Stop if wrapper initialization fails
+
+ # --- Generate ---
+ # Call the wrapper's generate method.
+ # Pass the original input text.
+ # Explicitly pass num_return_sequences to request the desired number of chains.
+ generation_status.update(label=f"⏳ Generating {num_chains} reasoning chain(s)...", state="running")
+ start_time = time.time()
 try:
+ # The wrapper's generate takes input_text and optional config/num_return_sequences overrides.
+ # We rely on the wrapper's internal config logic and pass the desired number of sequences.
+ outputs = cot_wrapper.generate(
+ input_text=prompt,
+ # Optional: pass a GenerationConfig override for this specific call if needed, e.g.:
+ # generation_config=GenerationConfig(temperature=temperature + 0.1),
+ # Pass the requested number of chains directly to the generate method:
+ num_return_sequences=num_chains,
+ )
+ # Expected `outputs` dict structure: {'full_texts': [...], 'reasoning_steps': [...], 'final_answers': [...], 'generation_scores': [...]}
+
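# Illustrative only (a sketch, not part of this commit): the minimal shape the
# display code below assumes for `outputs`. Values are invented, and the
# 'generation_scores' key may be absent depending on the wrapper version.
#
#   outputs = {
#       'full_texts': ["Step 1: Add 2 and 2.\nFinal_Answer: 4"],
#       'reasoning_steps': [["Add 2 and 2."]],
#       'final_answers': ["4"],
#   }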
+ except Exception as e:
+ generation_status.error(f"❌ Generation failed: {e}")
+ st.exception(e) # Display the full exception traceback
+ # Clean up resources after potential OOM or other errors
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
+ gc.collect() # Python garbage collection
+ # No need to st.stop() here; the error message is displayed
+ outputs = None # Ensure outputs is None if generation fails
+
+ elapsed_time = time.time() - start_time
+
+ # Update final status based on success or failure
+ if outputs is not None:
+ generation_status.update(label=f"✨ Generation complete in {elapsed_time:.2f}s", state="complete")
+ else:
+ generation_status.update(label=f"❌ Generation failed after {elapsed_time:.2f}s", state="error")
+
+ update_telemetry() # Final telemetry update after generation attempt
+
+ # --- Process and Display Results ---
+ if outputs is not None: # Only display if generation was attempted and returned results
+ with results_container:
+ st.markdown("## 📚 Reasoning Output")
+
+ # Display the Self-Consistency consensus first if enabled and results are available.
+ # The voting logic is implemented here in the GUI using the wrapper's output.
+ final_answers_list = outputs.get('final_answers', [])
+
+ if self_consistency_enabled_gui and final_answers_list:
+ # Perform the actual voting
+ consensus_answer, answer_distribution_dict = perform_self_consistency_voting(final_answers_list)
+ # Convert the dict to a Counter so the display can sort by vote count
+ answer_distribution = Counter(answer_distribution_dict)
+
+ st.markdown('<div class="consensus-answer">', unsafe_allow_html=True)
+ st.write("💡 **Consensus Answer (Self-Consistency):**")
+ if consensus_answer:
+ st.write(f'<p>{consensus_answer}</p>', unsafe_allow_html=True)
+ # Note the winner's vote count: look up the winning normalized answer in the distribution
+ winner_count = answer_distribution.get(normalize_answer(consensus_answer), 0)
+ st.write(f"*(Based on {winner_count} {'vote' if winner_count == 1 else 'votes'} out of {len(final_answers_list)} chains)*")
+ else:
+ st.write("<p>[Could not determine consensus - no valid answers found]</p>", unsafe_allow_html=True)
+ st.write(f"*(Examined {len(final_answers_list)} chains)*")
+
+ st.markdown('</div>', unsafe_allow_html=True)
+
+ # Display the answer distribution if there is more than one unique answer after normalization
+ if len(answer_distribution) > 1:
+ st.markdown("###### Answer Distribution:")
+ # Display the distribution sorted by vote count
+ for ans, count in answer_distribution.most_common():
+ # Show each normalized answer and its count
+ st.write(f"- '{ans}' ({count} {'vote' if count == 1 else 'votes'})")
+ elif len(answer_distribution) == 1 and consensus_answer:
+ st.info(f"All {len(final_answers_list)} valid chains agreed on the normalized answer: '{consensus_answer}'.")
+ else:
+ st.warning(f"Could not build an answer distribution: {len(answer_distribution)} unique normalized answer(s) from {len(final_answers_list)} chains.")
+
+ st.markdown("---") # Separator after consensus section
+
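# A minimal sketch (not part of this commit) of the contract the two helpers
# used above are assumed to satisfy; the real perform_self_consistency_voting
# and normalize_answer are presumably defined earlier in this file and may
# normalize more aggressively.
from collections import Counter

def normalize_answer_sketch(ans: str) -> str:
    # Lowercase, trim, and drop trailing punctuation so "4." and "4" vote together.
    return ans.strip().lower().rstrip(".!? ")

def perform_self_consistency_voting_sketch(answers):
    # Count normalized, non-empty answers; return (winner, {normalized_answer: count}).
    counts = Counter(normalize_answer_sketch(a) for a in answers if a and a.strip())
    if not counts:
        return "", {}
    winner, _ = counts.most_common(1)[0]
    return winner, dict(counts)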
+ # Display individual chains
+ full_texts = outputs.get('full_texts', [])
+ reasoning_steps_list = outputs.get('reasoning_steps', [])
+ final_answers_list_raw = outputs.get('final_answers', []) # Keep raw answers for display
+
+ if not full_texts:
+ st.warning("No reasoning chains were generated or parsed successfully.")
+ else:
+ st.markdown(f"### Individual Chains ({len(full_texts)} generated)")
+ # Iterate and display each chain in an expander
+ # Ensure lists are iterable and have consistent length, padding with placeholders if necessary
+ max_len_outputs = len(full_texts)
+ # Ensure reasoning_steps_list and final_answers_list_raw match the length of full_texts
+ reasoning_steps_list = (reasoning_steps_list if isinstance(reasoning_steps_list, list) else []) + [[]] * (max_len_outputs - len(reasoning_steps_list))
+ final_answers_list_raw = (final_answers_list_raw if isinstance(final_answers_list_raw, list) else []) + ["[N/A - Parsing Failed]"] * (max_len_outputs - len(final_answers_list_raw))
+
+ for idx, (text, steps, ans_raw) in enumerate(zip(full_texts, reasoning_steps_list, final_answers_list_raw), 1):
+ # Use try-except for displaying each chain just in case of unexpected data format
+ try:
+ # Expander for each chain, starting collapsed
+ with st.expander(f"Chain {idx}", expanded=False):
+ # Display full generated text
+ st.markdown('<div class="output-label">Full Generated Text (Cleaned):</div>', unsafe_allow_html=True)
+ st.text_area(f"chain_text_area_{idx}", text if isinstance(text, str) else "[Invalid Text Data]", height=250, label_visibility="collapsed", help="The complete generated output for this chain after cleaning artifacts.", key=f"chain_text_{idx}") # Added key
+
+ # Display parsed reasoning steps
+ st.markdown('<div class="output-label">Reasoning Steps Parsed:</div>', unsafe_allow_html=True)
+ if steps and isinstance(steps, list) and len(steps) > 0:
+ # Display steps as a list with strong emphasis on step number
+ for i, step in enumerate(steps, 1):
+ if isinstance(step, str) and step.strip():
+ st.markdown(f"**Step {i}:** {step.strip()}")
+ elif not isinstance(step, str):
+ st.warning(f"Step {i} has invalid format in Chain {idx}.")
+ elif isinstance(steps, list) and len(steps) == 0:
+ st.info("No specific steps were extracted for this chain.")
+ else:
+ st.warning(f"Reasoning steps data is invalid or missing for Chain {idx}.")
+
+ # Display parsed final answer (raw)
+ st.markdown('<div class="output-label">Final Answer Parsed:</div>', unsafe_allow_html=True)
+ display_answer = ans_raw if isinstance(ans_raw, str) and ans_raw.strip() else "[No answer extracted]"
+ st.write(f"**{display_answer}**")
+
+ # Optional: add a separator between chain sections within the expander if desired
+ # st.markdown("---")
+
+ except Exception as chain_e:
+ st.error(f"Error displaying content for Chain {idx}: {chain_e}")
+ st.exception(chain_e)
+
+ # Final separator after all chains
+ st.markdown("---")
+ st.info(f"Displayed details for {len(full_texts)} generated chains.")
+
+ # Clean up GPU memory after generation and display are complete
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect() # Python garbage collection
+ # st.write("GPU memory cache cleared and garbage collected.") # Optional status message
chain_of_thought_wrapper.py CHANGED
@@ -3,47 +3,60 @@
 import re
 import torch
 import logging
- from transformers import PreTrainedModel, AutoTokenizer, GenerationConfig, GenerationMixin
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
 from typing import Optional, List, Tuple, Dict, Union, Any
 import gc # Import garbage collector for cleanup
- import time

 # --- Logging Setup ---
- # Configure logging for the module
- logging.basicConfig(level=logging.INFO) # Default logging level
 logger = logging.getLogger(__name__)
- # Prevent duplicate handlers if imported multiple times
 if not logger.handlers:
 handler = logging.StreamHandler()
 formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 handler.setFormatter(formatter)
 logger.addHandler(handler)
- logger.propagate = False # Prevent logs from going to root logger multiple times
-
 # --- Default Configuration Values ---
- # These defaults provide sensible starting points for the wrapper's behavior.
- DEFAULT_MAX_LENGTH = 1024 # Default maximum length of the generated output sequence.
- DEFAULT_REASONING_LIMIT = 10 # Limit on the number of steps to extract during parsing (currently unused in parse logic, but good to keep as a concept).
- DEFAULT_CONSISTENCY_ROUNDS = 3 # Default number of chains to generate for self-consistency (used in __init__, passed via GUI num_chains).
- DEFAULT_COMPLEXITY_KEYWORDS = ["explain", "step by step", "plan", "analyze", "reasoning", "logic"] # Keywords to potentially trigger CoT (currently unused, CoT is always on).
- DEFAULT_FINAL_ANSWER_TAG = "Final_Answer:" # The specific tag expected before the final answer.

 # --- Regex Pattern for Parsing Steps ---
 # This pattern is used to identify and extract individual reasoning steps from
- # the generated text. It's designed to be flexible, capturing:
- # - "Step N:"
- # - "Step N."
- # - "Step N-"
- # - "N:"
- # - "N."
- # - "N-"
- # Where N is one or more digits, case-insensitive for "Step".
 DEFAULT_STEP_PATTERN = re.compile(
- r"^(?:Step\s*\d+[:.)-]|\d+[:.)-])\s*(.*)", re.IGNORECASE
 )

 class ChainOfThoughtWrapper:
 """
@@ -53,15 +66,20 @@ class ChainOfThoughtWrapper:
 template into the prompt. It handles model generation and parses the
 output to extract reasoning steps and a final answer. It is designed
 to generate multiple sequences for potential Self-Consistency voting
- (voting logic is expected to be handled by the calling application,
- like the Streamlit GUI).

 Key Features:
- - Forces CoT via prompt injection.
- - Parses structured reasoning steps and final answer from output.
- - Supports generating multiple chains for Self-Consistency analysis.
- - Compatible with Hugging Face PreTrainedModels or objects implementing `.generate()`.
- - Handles device placement and merges GenerationConfig.
 """
@@ -71,143 +89,179 @@ class ChainOfThoughtWrapper:
 generation_config: Optional[GenerationConfig] = None,
 device: Optional[str] = None,
 max_length: int = DEFAULT_MAX_LENGTH,
- reasoning_steps_limit: int = DEFAULT_REASONING_LIMIT, # Parameter included as per provided code
- self_consistency: bool = False, # Parameter included as per provided code (__init__ attribute)
- consistency_rounds: int = DEFAULT_CONSISTENCY_ROUNDS, # Parameter included as per provided code (__init__ attribute)
- complexity_keywords: Optional[List[str]] = None, # Parameter included as per provided code
 final_answer_tag: str = DEFAULT_FINAL_ANSWER_TAG,
- # self_consistency_enabled: bool = False # Removed based on GUI interaction
 ):
 """
- Initializes the ChainOfThoughtWrapper.

 Args:
 model (Union[PreTrainedModel, GenerationMixin, Any]): The language model.
- Must have a `.generate()` method.
 tokenizer (AutoTokenizer): The corresponding tokenizer.
 generation_config (Optional[GenerationConfig]): A default generation configuration.
- Values here can be overridden by the `generate()` call.
 device (Optional[str]): The device to load the model onto ('cpu' or 'cuda').
 Defaults to 'cuda' if available, otherwise 'cpu'.
 max_length (int): The maximum total length of the input + generated sequence.
- reasoning_steps_limit (int): Conceptual limit for parsed steps (currently not enforced in _parse).
- self_consistency (bool): Flag indicating if self-consistency is intended (informs the `consistency_rounds` attribute).
- consistency_rounds (int): The number of chains to generate if self-consistency is active.
- The actual number generated is controlled by `num_return_sequences` in `generate()` or `generation_config`.
- complexity_keywords (Optional[List[str]]): List of keywords to potentially trigger CoT (currently unused).
 final_answer_tag (str): The specific string marker expected before the final answer.
 """
- # Determine and set the device
 self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
- logger.info("Initializing wrapper on device: %s", self.device)

- # Move the model to the specified device
 try:
 self.model = model.to(self.device)
- self.model.eval() # Set model to evaluation mode for consistent behavior
 logger.info("Model moved to %s and set to eval mode.", self.device)
 except Exception as e:
 logger.error("Failed to move model to device %s: %s", self.device, e)
- raise # Re-raise the exception after logging

 self.tokenizer = tokenizer

- # Set core parameters
 self.max_length = max_length
 self.reasoning_steps_limit = reasoning_steps_limit
- self.self_consistency = self_consistency # Attribute stored; actual generation count controlled elsewhere
- self.consistency_rounds = max(1, consistency_rounds) if self_consistency else 1 # Attribute stored
- self.complexity_keywords = complexity_keywords or list(DEFAULT_COMPLEXITY_KEYWORDS) # Ensure it's a mutable list
 self.final_answer_tag = final_answer_tag
- # Compile regex pattern for final answer extraction
 self.final_answer_pattern = re.compile(
 re.escape(final_answer_tag) + r"\s*(.*)", re.IGNORECASE | re.DOTALL
 )
- logger.debug("Final answer pattern compiled: %s", self.final_answer_pattern.pattern)
- logger.debug("Step pattern: %s", DEFAULT_STEP_PATTERN.pattern)
-
- # Attempt to find the underlying Hugging Face model and its config.
- # This is useful for accessing standard attributes like eos_token_id, etc.
- self._hf_model, self._hf_config = self._find_hf_model_and_config(self.model)
-
- # Fall back to tokenizer settings if the HF config isn't found
- if self._hf_config is None:
- logger.warning("Underlying HF model config not found. Relying on tokenizer for eos/pad tokens and vocab size.")
- # Create a pseudo-config with essential tokenizer info
- class PseudoConfig:
- def __init__(self, tok):
- self.eos_token_id = tok.eos_token_id
- # Use eos_token_id as pad_token_id if pad_token_id is None (common for GPT-like models)
- self.pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
- # Fallback if both are None (less common but possible)
- if self.pad_token_id is None:
- logger.warning("Tokenizer pad_token_id and eos_token_id are both None. Generation might be unstable without padding.")
- # Assign an arbitrary value or handle externally if this happens in practice
- # For now, keep it None; generation might fail or behave unexpectedly
- pass # Keep pad_token_id as None
-
- self.vocab_size = len(tok) # Vocabulary size from tokenizer
-
- def __getattr__(self, name):
- # Allow accessing other attributes, returning None if not found.
- # This prevents errors if generation_config tries to read something unexpected.
- logger.debug("Accessing undefined attribute '%s' on PseudoConfig. Returning None.", name)
- return None
-
- self._hf_config = PseudoConfig(self.tokenizer)
- logger.debug("Created PseudoConfig: eos_token_id=%s, pad_token_id=%s, vocab_size=%s",
- self._hf_config.eos_token_id, self._hf_config.pad_token_id, self._hf_config.vocab_size)
- else:
- logger.info("Found underlying HF model config.")
- logger.debug("HF Config: eos_token_id=%s, pad_token_id=%s, vocab_size=%s",
- getattr(self._hf_config, 'eos_token_id', None),
- getattr(self._hf_config, 'pad_token_id', None),
- getattr(self._hf_config, 'vocab_size', None))

- # --- Setup Generation Config ---
- # Start with a base config, either provided or a default one
 if generation_config:
- # Use from_dict and to_dict for safe merging/copying of GenerationConfig
- self.generation_config = GenerationConfig.from_dict(generation_config.to_dict())
- logger.info("Initialized with provided GenerationConfig.")
 else:
- # Create a default GenerationConfig using info from the HF config or tokenizer fallback
- self.generation_config = GenerationConfig(
- eos_token_id=self._hf_config.eos_token_id,
- pad_token_id=self._hf_config.pad_token_id,
- max_length=self.max_length, # Set max_length from wrapper param
- # Add other common defaults if not provided
- do_sample=True,
- temperature=0.7,
- top_p=0.95,
- top_k=50,
- num_return_sequences=1, # Default to 1 sequence
- no_repeat_ngram_size=0, # Default to no ngram repetition prevention
 )
- logger.info("Initialized with default GenerationConfig.")
-
- # Ensure the underlying HF model (if found) is set to return dict outputs from generate.
- # This is necessary for accessing scores, hidden states, etc., and for consistency.
- # Use a check, as some custom models might not have this attribute on their config.
- if hasattr(self._hf_config, 'return_dict_in_generate'):
- try:
- setattr(self._hf_config, 'return_dict_in_generate', True)
- logger.debug("Set _hf_config.return_dict_in_generate = True.")
- except Exception as e:
- logger.warning("Failed to set return_dict_in_generate on _hf_config: %s", e)
 else:
- logger.debug("_hf_config does not have a return_dict_in_generate attribute.")

- logger.info("ChainOfThoughtWrapper initialization complete on device: %s", self.device)
- logger.debug("Initial GenerationConfig: %s", self.generation_config.to_dict())

 def _find_hf_model_and_config(self, obj: Any) -> Tuple[Optional[PreTrainedModel], Optional[Any]]:
 """
 Recursively searches for an underlying Hugging Face PreTrainedModel
- and its configuration within a potentially wrapped object.

 Args:
 obj (Any): The object to inspect (could be the model itself or a wrapper).
@@ -216,14 +270,21 @@ class ChainOfThoughtWrapper:
 Tuple[Optional[PreTrainedModel], Optional[Any]]: The found HF model instance and its config.
 Returns (None, None) if not found.
 """
 logger.debug("Searching for HF model in object of type: %s", type(obj))
 # If the object is directly a PreTrainedModel and has a config
- if isinstance(obj, PreTrainedModel) and hasattr(obj, 'config'):
 logger.debug("Found HF PreTrainedModel directly.")
- return obj, obj.config

 # Check common attribute names where the base model might be stored
- potential_attrs = ('model', 'base_model', 'transformer', 'hf_model')
 for attr_name in potential_attrs:
 m = getattr(obj, attr_name, None)
 if m is not None:
@@ -231,22 +292,25 @@ class ChainOfThoughtWrapper:
 # Recursively search within the attribute
 found_model, found_config = self._find_hf_model_and_config(m)
 if found_model or found_config:
 return found_model, found_config

- # If no PreTrainedModel found, check if the object itself has a 'config' attribute
 if hasattr(obj, 'config'):
- logger.debug("Found config attribute on object, but no PreTrainedModel.")
- return None, obj.config

- logger.debug("No HF PreTrainedModel or config found.")
 return None, None

 def _inject_cot(self, prompt: str) -> str:
 """
- Injects the prescriptive Chain-of-Thought template into the user's prompt.
-
- This method defines the expected format the model should follow for reasoning.

 Args:
 prompt (str): The original user prompt.
@@ -254,178 +318,239 @@ class ChainOfThoughtWrapper:
 Returns:
 str: The prompt with the CoT template appended.
 """
- # The template strongly guides the model to produce step-by-step reasoning
- # followed by a specific tag for the final answer.
- cot_prompt = (
- f"{prompt.strip()}\n\n" # Use strip() to clean user prompt
- "Let's analyze this problem logically, breaking it down step by step to reach the precise final answer.\n\n" # Enhanced instruction
- "Reasoning Process:\n\n" # Clearer heading for steps
- "Step 1: " # Start the first step explicitly
- # More steps are not needed here; the model learns to continue the pattern
- )
- logger.debug("Injected CoT template. Full prompt starts with: %s...", cot_prompt[:100].replace('\n', '\\n'))
- return cot_prompt

- @torch.no_grad()
 def generate(
 self,
- input_ids: torch.LongTensor,
- attention_mask: Optional[torch.LongTensor] = None,
- generation_config: Optional[GenerationConfig] = None,
- num_return_sequences: int = 1, # This argument controls how many sequences are generated
- **kwargs: Any # Allows passing arbitrary generation parameters
 ) -> Dict[str, Any]:
 """
- Generates text using the wrapped model, enforcing Chain-of-Thought.
-
- This method prepares the input by injecting the CoT template, calls the
- underlying model's generate method, and then parses the raw outputs
- to extract structured reasoning steps and final answers.

 Args:
- input_ids (torch.LongTensor): Tokenized input prompt (batch size 1 expected).
- Shape [1, sequence_length].
- attention_mask (Optional[torch.LongTensor]): Attention mask for the input.
- Shape [1, sequence_length].
- generation_config (Optional[GenerationConfig]): Specific generation config
- for this call. Overrides defaults.
- num_return_sequences (int): The number of independent sequences to generate.
- This is crucial for Self-Consistency.
- Comes from the GUI's 'num_chains'.
- **kwargs (Any): Additional keyword arguments passed to the model's `generate` method.

 Returns:
- Dict[str, Any]: A dictionary containing:
- - 'sequences' (torch.LongTensor): The raw generated token sequences.
- - 'full_texts' (List[str]): The complete decoded text for each sequence.
- - 'reasoning_steps' (List[List[str]]): List of parsed reasoning steps for each sequence.
- - 'final_answers' (List[str]): List of parsed final answers for each sequence.
- - 'consensus_answer' (Optional[str]): The consensus answer if self-consistency is active and possible (handled by the calling code).
 """
- # Ensure input is on the correct device
- input_ids = input_ids.to(self.device)
- if attention_mask is not None:
- attention_mask = attention_mask.to(self.device)
-
- # Decode the original prompt text for CoT injection.
- # Assume batch size is 1 for the input prompt tensor [1, sequence_length].
- if input_ids.size(0) != 1:
- logger.warning("Batch size > 1 detected for input_ids (%d). CoT injection assumes batch size 1. Using the first item.", input_ids.size(0))
- prompt_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-
- # --- Inject CoT Template ---
- # This is the core step that forces the model into a reasoning mode.
- cot_prompt = self._inject_cot(prompt_text)
- logger.debug("Injected CoT prompt. Encoding...")
-
- # --- Prepare Generation Configuration ---
- # Merge the wrapper's default config with the call-specific config and kwargs.
- # The num_return_sequences from the function argument takes precedence here.
- cfg = GenerationConfig.from_dict(self.generation_config.to_dict()) # Start with wrapper's default
- if generation_config:
- cfg.update(**generation_config.to_dict()) # Update with call-specific config
- logger.debug("Updated GenerationConfig with call-specific config.")
-
- # Explicitly set num_return_sequences from the function argument
- cfg.num_return_sequences = num_return_sequences
- logger.info("Generating %d sequence(s).", cfg.num_return_sequences)
-
- # Update with any remaining keyword arguments passed to generate()
- for k, v in kwargs.items():
- if hasattr(cfg, k):
- setattr(cfg, k, v)
- logger.debug("Updating GenerationConfig kwarg: %s=%s", k, v)
- else:
- # Allow passing arbitrary kwargs to model.generate if the underlying method supports them.
- # These won't be part of the GenerationConfig object itself unless it's a supported param,
- # but the model's generate method might accept extra args.
- # Log a warning if it's not a standard GenerationConfig parameter.
- if k not in GenerationConfig().__dict__: # Check if it's NOT a standard param
- logger.debug("Passing non-standard kwarg '%s' to model.generate.", k)
- # We pass all kwargs to model.generate below anyway.
-
- logger.debug("Final GenerationConfig for call: %s", cfg.to_dict())
-
- # --- Encode the CoT Prompt ---
- # The input's max length should be the total max_length minus max_new_tokens,
- # to leave space for the generation.
- # Ensure padding and truncation are handled.
 try:
- enc = self.tokenizer(
- cot_prompt,
- return_tensors='pt',
- padding='longest', # Pad to the longest sequence in the batch (always 1 here)
- truncation=True, # Crucially, truncate if the prompt is too long
- max_length=self.max_length - cfg.max_new_tokens # Leave room for generation
 ).to(self.device)
- logger.debug("Encoded CoT prompt. Input shape: %s", enc['input_ids'].shape)
 except Exception as e:
- logger.error("Failed to encode CoT prompt: %s", e)
- raise # Re-raise the exception after logging

- # --- Generate Text ---
- # Call the underlying model's generate method with the prepared input and config.
- # The torch.no_grad() context is already applied to the whole method.
 try:
- logger.info("Calling model.generate()...")
- start_time = time.time() # Measure generation time
- out = self.model.generate(
- input_ids=enc['input_ids'],
- attention_mask=enc['attention_mask'],
- generation_config=cfg,
- **kwargs # Pass through any extra kwargs
 )
- elapsed_time = time.time() - start_time
- logger.info("model.generate() finished in %.2f seconds.", elapsed_time)
- logger.debug("Raw output shape: %s", out.shape)
 except Exception as e:
 logger.error("Model generation failed: %s", e)
- # Attempt to clean up GPU memory in case of OOM or other errors
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect() # Trigger Python garbage collection
- raise # Re-raise the exception after logging

- # --- Decode and Parse Outputs ---
- # Decode the generated token sequences back into text.
- logger.debug("Decoding and parsing outputs...")
- decoded_outputs = self.tokenizer.batch_decode(out, skip_special_tokens=True)
-
- # Process each decoded output to extract steps and final answer
- parsed_results = [self._parse(text, cot_prompt) for text in decoded_outputs]
-
- # Separate the parsed components into lists
- all_steps = [r[0] for r in parsed_results]
- all_finals = [r[1] for r in parsed_results]
- all_full_texts = [r[2] for r in parsed_results] # The 'body' after removing the template
-
- logger.info("Generated and parsed %d sequence(s).", len(decoded_outputs))
-
- # --- Return Results ---
- # The calling code (e.g., the GUI) is responsible for implementing
- # Self-Consistency voting based on the list of 'final_answers' provided here.
 return {
- 'sequences': out, # Return raw sequences in case they are needed
- 'full_texts': all_full_texts, # Text body after template removal
- 'reasoning_steps': all_steps,
- 'final_answers': all_finals,
- # 'consensus_answer' is not computed here; it is done externally.
- # Keeping the structure consistent with the GUI's expectation.
- 'consensus_answer': None # Placeholder, computed externally
 }

 def _parse(self, text: str, cot_prompt: str) -> Tuple[List[str], str, str]:
 """
 Parses the generated text to extract reasoning steps and the final answer.
-
- Applies regex patterns to find lines matching the step format and the
- final answer tag. Includes cleanup for stray model artifacts.

 Args:
 text (str): The raw text output from the model for a single chain.
@@ -435,229 +560,368 @@ class ChainOfThoughtWrapper:
 Tuple[List[str], str, str]: A tuple containing:
 - A list of extracted reasoning step strings.
 - The extracted final answer string.
- - The full body of the generated text (after removing the prompt).
 """
- logger.debug("Parsing generated text...")

- # Remove the exact injected prompt from the beginning of the text.
 # This isolates the model's generated continuation.
 body = text
 if text.startswith(cot_prompt):
- body = text[len(cot_prompt):].strip()
- logger.debug("Removed CoT prompt (%d characters) from beginning.", len(cot_prompt))
 else:
- logger.warning("Generated text does not start with the injected CoT prompt. Parsing entire text.")
- body = text.strip() # Just strip whitespace if the template wasn't followed
-
- # --- Cleanup stray model artifacts ---
- # Remove common problematic tags or partial JSON structures that models sometimes emit.
- # This makes the raw output cleaner before step/answer extraction.
- logger.debug("Cleaning stray artifacts...")
- body = re.sub(r"<init>.*?</init>", "", body, flags=re.DOTALL)
- body = re.sub(r"<final_output>.*?</final_output>", "", body, flags=re.DOTALL)
- # Note: Removing all {} might be aggressive if the model uses them naturally.
- # Keeping it as per the provided code, but be aware this could remove desired output.
- # Consider making this optional or more specific if needed.
- body = re.sub(r"\{.*?\}", "", body, flags=re.DOTALL)
- logger.debug("Artifact cleanup complete.")
-
- lines = [l.strip() for l in body.splitlines() if l.strip()] # Split into non-empty, stripped lines
 steps = [] # List to store extracted steps
 final_answer = "" # Variable to store the final answer

- # --- Extract Steps and Final Answer ---
 # Iterate through lines and apply regex patterns.
- found_final_answer_line = False
- for i, line in enumerate(lines):
 # Check for reasoning step pattern
- step_match = DEFAULT_STEP_PATTERN.match(line)
 if step_match:
- # If a step is found, add the captured group (the text after the number/tag)
- steps.append(step_match.group(1).strip())
- logger.debug("Extracted step %d: '%s'", len(steps), steps[-1][:50])
- # Stop adding steps if we've reached a defined limit (though the limit isn't currently enforced after parsing)
- # if len(steps) >= self.reasoning_steps_limit:
- # logger.debug("Reached reasoning steps limit (%d). Stopping step extraction.", self.reasoning_steps_limit)
- # # Continue iterating to potentially find the final answer after the limit
- # # break # DO NOT break if we still need to find the final answer tag after the limit
-
- else:
- # If it's not a step, check for the final answer tag
- final_answer_match = self.final_answer_pattern.search(line)
- if final_answer_match:
- # If the final answer tag is found, extract the text following it
- final_answer = final_answer_match.group(1).strip()
- logger.debug("Extracted final answer tagged: '%s'", final_answer[:50])
- found_final_answer_line = True
- # Once the final answer tag is found, we can stop processing lines for *this specific pattern*.
- # The provided code breaks the loop entirely here; keeping the break to match the original logic.
- break # Stop processing lines after finding the tagged answer
-
- # --- Fallback for Final Answer ---
- # If the specific final answer tag was not found, assume the last non-step line
- # is the intended final answer. This is a heuristic fallback.
- if not found_final_answer_line:
- logger.debug("Final answer tag not found. Applying fallback heuristic.")
- # Find the last line that is not a step
 last_non_step_line = ""
- for line in reversed(lines): # Iterate backwards
- if line.strip() and not DEFAULT_STEP_PATTERN.match(line):
 last_non_step_line = line.strip()
- logger.debug("Fallback: Last non-step line found: '%s'", last_non_step_line[:50])
- break # Found the last non-step line

 if last_non_step_line:
 # Check if the last non-step line *contains* the final answer tag,
- # even if it didn't *start* with it or was the last line processed.
- # This handles cases where the tag might be mid-line or in a different format.
 fa_match_fallback = self.final_answer_pattern.search(last_non_step_line)
 if fa_match_fallback:
 final_answer = fa_match_fallback.group(1).strip()
- logger.debug("Fallback found tagged answer in last non-step line: '%s'", final_answer[:50])
 else:
- # If no tag in the last non-step line, just use the line itself
 final_answer = last_non_step_line
- logger.debug("Fallback using last non-step line as answer: '%s'", final_answer[:50])
 else:
- # If no non-empty lines were found, the final answer is empty
 final_answer = ""
- logger.debug("No lines found in body. Final answer is empty.")

- logger.debug("Parsing complete. Steps found: %d, Final Answer: '%s'", len(steps), final_answer[:50])

- return steps, final_answer, body # Return steps, final answer, and the cleaned body text

 def resize_token_embeddings(self, new_size: int):
 """
- Resizes the model's token embeddings, useful after adding new tokens
- to the tokenizer (like a custom PAD token).

- Only works if the underlying model object has a `resize_token_embeddings` method.

 Args:
 new_size (int): The new size of the vocabulary/embedding layer.
- Should match the size of the tokenizer's vocabulary.
 """
- # Find the actual HF model if wrapped
- hf_model_instance, _ = self._find_hf_model_and_config(self.model)

- if hasattr(hf_model_instance, 'resize_token_embeddings'):
 try:
 old_size = hf_model_instance.get_input_embeddings().weight.size(0)
 if new_size != old_size:
 hf_model_instance.resize_token_embeddings(new_size)
- logger.info("Resized model token embeddings from %d to %d.", old_size, new_size)
 # Update model config's vocab size if available
 if hasattr(hf_model_instance, 'config') and hasattr(hf_model_instance.config, 'vocab_size'):
- hf_model_instance.config.vocab_size = new_size
- logger.debug("Updated model config vocab_size to %d.", new_size)
 else:
 logger.info("Embedding size is already %d, no resizing needed.", new_size)
 except Exception as e:
 logger.error("Failed to resize token embeddings: %s", e)
- # Attempt cleanup
 if torch.cuda.is_available(): torch.cuda.empty_cache()
 gc.collect()
 else:
- logger.error("Cannot resize token embeddings: the underlying model object does not have a 'resize_token_embeddings' method.")

- # Example Usage (Illustrative - requires a real HF model and tokenizer)
 if __name__ == "__main__":
 print("--- ChainOfThoughtWrapper Example Usage ---")
- print("This block requires a Hugging Face model to run.")
- print("Loading a small dummy model for demonstration...")

 # You would replace this with your actual model loading logic
 try:
 # Use a tiny, fast model for a quick test
- model_id = "hf-internal-testing/tiny-random-gpt2"
 device = "cuda" if torch.cuda.is_available() else "cpu"

 logger.info(f"Attempting to load model {model_id} on {device}...")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)

- # Ensure pad token is set for generation (common requirement)
 if tokenizer.pad_token_id is None:
 if tokenizer.eos_token_id is not None:
 tokenizer.pad_token_id = tokenizer.eos_token_id
 else:
- # Add a pad token if neither eos nor pad exists
 tokenizer.add_special_tokens({'pad_token': '[PAD]'})
- model.resize_token_embeddings(len(tokenizer)) # Resize embeddings after adding token
 tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
- logger.warning("Added and set [PAD] token, resized embeddings.")

 # Instantiate the wrapper
- # Simulate parameters that would come from the GUI
- simulated_gen_config = GenerationConfig(
- max_new_tokens=100,
- temperature=0.8,
- do_sample=True,
- num_return_sequences=2, # Simulate asking for 2 chains
 pad_token_id=tokenizer.pad_token_id, # Pass pad_token_id explicitly
 eos_token_id=tokenizer.eos_token_id, # Pass eos_token_id explicitly
 )

 cot_wrapper = ChainOfThoughtWrapper(
 model=model,
 tokenizer=tokenizer,
- generation_config=simulated_gen_config,
 device=device,
- self_consistency=True, # Simulate SC enabled
- consistency_rounds=2, # Simulate consistency rounds setting
 )

 # Prepare input prompt
- prompt_text = "What is 2 + 2? Think step-by-step."
- input_enc = tokenizer(prompt_text, return_tensors='pt').to(device)
-
 logger.info(f"Generating reasoning for prompt: '{prompt_text}'")

 # Generate outputs
- # The num_return_sequences from simulated_gen_config will be used here
 outputs = cot_wrapper.generate(
- input_ids=input_enc['input_ids'],
- attention_mask=input_enc['attention_mask']
 )

- # Process results (including simulated Self-Consistency voting logic)
- print("\n--- Generation Results ---")
- for i, (full_text, steps, final_answer) in enumerate(zip(outputs['full_texts'], outputs['reasoning_steps'], outputs['final_answers'])):
- print(f"\n--- Chain {i+1} ---")
- print("Full Text:")
- print(full_text)
- print("\nReasoning Steps:")
- if steps:
- for j, step in enumerate(steps):
- print(f" Step {j+1}: {step}")
- else:
- print(" [No steps parsed]")
- print("\nFinal Answer:")
- print(f" {final_answer or '[No final answer parsed]'}")

 # --- Simulate Self-Consistency Voting (as would be done in the GUI) ---
- print("\n--- Self-Consistency Voting ---")
- final_answers = [ans for ans in outputs['final_answers'] if ans.strip()] # Filter empty answers
- if final_answers:
- answer_counts = Counter(final_answers)
- most_common_answer, count = answer_counts.most_common(1)[0]
- print(f"Raw Answers Submitted for Voting: {final_answers}")
- print(f"Answer Counts: {dict(answer_counts)}")
- print(f"Consensus Answer: '{most_common_answer}' (Voted by {count} chain(s))")
 else:
- print("No valid final answers found for voting.")

 except Exception as e:
- logger.error("Example usage failed: %s", e)
 import traceback
 traceback.print_exc() # Print detailed traceback for the example failure

- print("\n--- Example Usage End ---")
 import re
 import torch
 import logging
+ from transformers import (
+ PreTrainedModel,
+ AutoTokenizer,
+ GenerationConfig,
+ GenerationMixin,
+ AutoModelForCausalLM # Needed for example usage
+ )
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
 from typing import Optional, List, Tuple, Dict, Union, Any
 import gc # Import garbage collector for cleanup
+ import time # Imported for potential timing/logging
+ from collections import Counter # Needed for example voting

 # --- Logging Setup ---
+ # Configure logging for the module. This helps in debugging and understanding wrapper behavior.
+ # Set the level to DEBUG temporarily to see the detailed logs added below.
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
+ # Ensure the logger doesn't add handlers multiple times if the script is imported repeatedly
 if not logger.handlers:
 handler = logging.StreamHandler()
 formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 handler.setFormatter(formatter)
 logger.addHandler(handler)
+ # Avoid propagation to the root logger, preventing duplicate messages
+ logger.propagate = False

 # --- Default Configuration Values ---
+ # These defaults provide sensible starting points for the wrapper's behavior,
+ # based on common practices and the audit recommendations.
+ DEFAULT_MAX_LENGTH = 2048 # Increased default max length to accommodate longer CoT
+ DEFAULT_REASONING_LIMIT = 15 # A conceptual limit for extracted steps (not strictly enforced by parsing logic)
+ DEFAULT_CONSISTENCY_ROUNDS = 5 # Default number of chains for self-consistency, increased based on typical research
+ DEFAULT_COMPLEXITY_KEYWORDS = ["explain", "step by step", "plan", "analyze", "reasoning", "logic"] # Keywords (currently unused as CoT is always on)
+ DEFAULT_FINAL_ANSWER_TAG = "Final_Answer:" # Explicit tag to signal the final answer

 # --- Regex Pattern for Parsing Steps ---
 # This pattern is used to identify and extract individual reasoning steps from
+ # the generated text. It's designed to be flexible, capturing common step formats
+ # like "Step N:", "N.", etc., case-insensitive for "Step".
+ # Captures the text *after* the step marker.
 DEFAULT_STEP_PATTERN = re.compile(
+ r"^(?:Step\s*\d+[:.)-]\s*|\d+[:.)-]\s*)(.*)", re.IGNORECASE
 )
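# Illustrative check (a sketch, not part of this commit): how the definitions
# above are expected to behave, using only `re` and the constants in this file.
assert DEFAULT_STEP_PATTERN.match("Step 3: add the totals").group(1) == "add the totals"
assert DEFAULT_STEP_PATTERN.match("2. subtract five").group(1) == "subtract five"
assert DEFAULT_STEP_PATTERN.match("Stepwise methods vary") is None  # needs a digit after "Step"
# The final-answer tag is compiled the same way the wrapper does in __init__:
_fa_sketch = re.compile(re.escape(DEFAULT_FINAL_ANSWER_TAG) + r"\s*(.*)", re.IGNORECASE | re.DOTALL)
assert _fa_sketch.search("final_answer: 42").group(1) == "42"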
+ # --- Common Artifact Cleanup Regex ---
+ # Regex patterns to remove common problematic tokens or structures models sometimes emit,
+ # which are not part of the desired reasoning or answer. Based on audit suggestion.
+ ARTIFACT_PATTERNS = [
+ re.compile(r"<init>.*?</init>", re.DOTALL), # Example: DeepSeek R1 init tags
+ re.compile(r"<final_output>.*?</final_output>", re.DOTALL), # Example: DeepSeek R1 final output tags
+ # re.compile(r"\{.*?\}", re.DOTALL), # Removing all {} might be too aggressive; dropped on re-evaluation.
+ # Add other specific artifact patterns here as needed for observed model outputs
+ ]
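# A minimal sketch (not part of this commit) of how these patterns are meant
# to be applied before parsing; the wrapper presumably does the equivalent
# inside _parse via self._artifact_patterns.
def _strip_artifacts_sketch(body: str) -> str:
    # Delete every span matched by an artifact pattern.
    for pattern in ARTIFACT_PATTERNS:
        body = pattern.sub("", body)
    return body

# e.g. _strip_artifacts_sketch("<init>boot</init>Step 1: count items") -> "Step 1: count items"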

 class ChainOfThoughtWrapper:
 """
 template into the prompt. It handles model generation and parses the
 output to extract reasoning steps and a final answer. It is designed
 to generate multiple sequences for potential Self-Consistency voting
+ (voting logic is intended for the calling application, e.g., a GUI).
+
+ It incorporates enhancements based on a detailed audit, focusing on
+ prompting, decoding, parsing robustness, cross-model compatibility,
+ reliability mitigation, and efficiency, while adhering to the "always-on CoT"
+ principle.

 Key Features:
+ - Forces CoT via a structured, adaptive prompt template.
+ - Parses structured reasoning steps and uses robust logic to find the final answer.
+ - Supports generating multiple chains for Self-Consistency analysis via GenerationConfig.
+ - Handles common cross-model compatibility issues (e.g., pad tokens, device placement).
+ - Merges user-provided GenerationConfig with sensible defaults.
+ - Includes basic cleanup for common model output artifacts.
 """

 def __init__(
 generation_config: Optional[GenerationConfig] = None,
 device: Optional[str] = None,
 max_length: int = DEFAULT_MAX_LENGTH,
+ reasoning_steps_limit: int = DEFAULT_REASONING_LIMIT,
+ self_consistency_enabled: bool = False, # Control if multiple chains are generated
+ consistency_rounds: int = DEFAULT_CONSISTENCY_ROUNDS,
+ complexity_keywords: Optional[List[str]] = None, # Currently unused as CoT is always on
 final_answer_tag: str = DEFAULT_FINAL_ANSWER_TAG,
+ # Optional prompt customization for advanced users
+ cot_instruction: str = "Let's analyze this problem logically, breaking it down step by step to reach the precise final answer.",
+ reasoning_header: str = "Reasoning Process:",
+ step_prefix: str = "Step ", # e.g., "Step 1: " - the model will ideally continue this
+ # Optional reliability controls (simple, prompt-based)
+ emphasize_factual: bool = True,
+ allow_uncertainty_phrase: Optional[str] = "If information is insufficient or you are unsure, state that clearly.",
+ # Optional parsing flexibility
+ strip_artifact_patterns: List[re.Pattern] = ARTIFACT_PATTERNS,
 ):
 """
+ Initializes the ChainOfThoughtWrapper with enhanced configurations.

 Args:
 model (Union[PreTrainedModel, GenerationMixin, Any]): The language model.
+ Must have a .generate() method.
 tokenizer (AutoTokenizer): The corresponding tokenizer.
 generation_config (Optional[GenerationConfig]): A default generation configuration.
+ Values here can be overridden by the generate() call.
 device (Optional[str]): The device to load the model onto ('cpu' or 'cuda').
 Defaults to 'cuda' if available, otherwise 'cpu'.
 max_length (int): The maximum total length of the input + generated sequence.
+ This should be large enough for the prompt, reasoning, and answer.
+ reasoning_steps_limit (int): Conceptual limit for parsed steps. Not strictly enforced by current parsing.
+ self_consistency_enabled (bool): If True, enable multi-chain generation for self-consistency.
+ consistency_rounds (int): The number of chains to generate if `self_consistency_enabled` is True.
+ The actual number of sequences is controlled by `num_return_sequences`
+ in the final `GenerationConfig`.
+ complexity_keywords (Optional[List[str]]): List of keywords (unused with always-on CoT).
 final_answer_tag (str): The specific string marker expected before the final answer.
+ cot_instruction (str): The core instruction phrase for CoT.
+ reasoning_header (str): The header text before the reasoning steps.
+ step_prefix (str): The prefix for the first step.
+ emphasize_factual (bool): If True, add prompt text emphasizing factual reasoning.
+ allow_uncertainty_phrase (Optional[str]): If provided, add a phrase prompting the model to state uncertainty.
+ strip_artifact_patterns (List[re.Pattern]): List of regex patterns to remove from model output before parsing.
 """
+ # --- Device Handling ---
+ # Determine and set the device. Log the chosen device.
 self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info("Initializing ChainOfThoughtWrapper on device: %s", self.device)

+ # --- Model and Tokenizer Loading and Configuration ---
+ # Move the model to the specified device and set to evaluation mode.
+ # Includes error handling for device transfer.
 try:
 self.model = model.to(self.device)
+ self.model.eval() # Set model to evaluation mode (disables dropout, etc.)
 logger.info("Model moved to %s and set to eval mode.", self.device)
 except Exception as e:
 logger.error("Failed to move model to device %s: %s", self.device, e)
+ raise # Re-raise the exception if device transfer fails

 self.tokenizer = tokenizer

+ # Attempt to find the underlying Hugging Face model instance and its config.
+ # This helps reliably access attributes like `config.vocab_size`, `resize_token_embeddings`, etc.
+ self._hf_model_instance, self._hf_config = self._find_hf_model_and_config(self.model)
+
+ # Handle models/tokenizers without a defined pad_token_id.
+ # This is crucial for batch generation (like `num_return_sequences`).
+ # If the tokenizer doesn't have a pad_token, try to use the eos_token.
+ # If neither exists, add a special token and resize embeddings.
+ # The wrapper's `resize_token_embeddings` method is called here if a new token is added.
+ if self.tokenizer.pad_token_id is None:
+ if self.tokenizer.eos_token_id is not None:
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+ logger.warning("Tokenizer pad_token_id is None, using eos_token_id (%s) as pad_token_id.", self.tokenizer.eos_token_id)
+ else:
+ # Fallback: add a new pad token if neither exists
+ logger.warning("Tokenizer pad_token_id and eos_token_id are both None. Adding a [PAD] token.")
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids('[PAD]')
+ logger.info("Added new [PAD] token with ID %s.", self.tokenizer.pad_token_id)
+ # Resize model embeddings if we added a new token AND we found a base HF model instance
+ if self._hf_model_instance:
+ self.resize_token_embeddings(len(self.tokenizer)) # Call the instance method
+ logger.info("Resized model embeddings to accommodate new PAD token.")
+ else:
+ logger.warning("Could not resize model embeddings after adding PAD token; underlying HF model instance not found.")
+ logger.warning("Ensure the model can handle a larger vocabulary if batching is used.")
+
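# A quick sketch (not part of this commit) of the fallback above for a
# GPT-2-style tokenizer, which ships with an eos token but no pad token
# ("gpt2" is just an example model id):
#
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token_id is None      -> True
#   tok.eos_token_id is not None  -> True
#   # so the wrapper reuses eos as pad, enabling batched multi-chain generation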
+ # --- Configuration Attributes ---
 self.max_length = max_length
 self.reasoning_steps_limit = reasoning_steps_limit
+ # The actual number of sequences to generate is controlled by `num_return_sequences` in the final `GenerationConfig`.
+ # We store `consistency_rounds` to potentially inform this value.
+ self.self_consistency_enabled = self_consistency_enabled
+ self.consistency_rounds = max(1, consistency_rounds) if self_consistency_enabled else 1
+
+ # --- Prompt Template Components ---
+ self.complexity_keywords = complexity_keywords or list(DEFAULT_COMPLEXITY_KEYWORDS) # Store keywords (currently unused for logic)
 self.final_answer_tag = final_answer_tag
+ self._cot_instruction = cot_instruction # Customizable CoT instruction
+ self._reasoning_header = reasoning_header # Customizable reasoning header
+ self._step_prefix = step_prefix # Customizable step prefix (e.g., "Step ")
+
+ # --- Reliability/Hallucination Mitigation Prompt Components ---
+ self._emphasize_factual = emphasize_factual
+ self._allow_uncertainty_phrase = allow_uncertainty_phrase
+
+ # --- Parsing Attributes and Compiled Regex ---
+ # Compile the regex pattern for final answer extraction based on the specified tag.
+ # re.escape handles potential special characters in the tag; re.DOTALL matches across newlines.
 self.final_answer_pattern = re.compile(
 re.escape(final_answer_tag) + r"\s*(.*)", re.IGNORECASE | re.DOTALL
 )
+ self._step_pattern = DEFAULT_STEP_PATTERN # Use the default compiled step pattern
+ self._artifact_patterns = strip_artifact_patterns # Patterns for cleaning model output

+ logger.debug("Final answer pattern compiled: %s", self.final_answer_pattern.pattern)
+ logger.debug("Step pattern: %s", self._step_pattern.pattern)

+ # --- Base Generation Config Setup ---
+ # Create or copy the base GenerationConfig. This config holds the default
+ # generation parameters that will be used unless overridden during a generate() call.
+ # Use .from_dict(.to_dict()) for a clean copy if a config was provided.
 if generation_config:
+ self.base_generation_config = GenerationConfig.from_dict(generation_config.to_dict())
+ logger.info("Initialized with provided base GenerationConfig.")
 else:
+ # Create a default GenerationConfig if none was provided.
+ # Incorporate parameters known to work well for CoT based on the audit (temperature, top_p, top_k).
+ # Ensure pad_token_id and eos_token_id are set from the tokenizer (or the fallback).
+ self.base_generation_config = GenerationConfig(
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ max_length=self.max_length, # Max total length
+ do_sample=True, # Always sample for diversity (essential for multi-chain)
+ temperature=0.7, # Balanced randomness
+ top_p=0.95, # Nucleus sampling
+ top_k=50, # Top-k sampling cutoff
+ num_return_sequences=1, # Default to 1 sequence (overridden by the generate call if self-consistency is on)
+ # Add a mild repetition penalty, useful for longer CoT
+ repetition_penalty=1.1, # Discourage immediate repetition
+ no_repeat_ngram_size=0, # Default to no n-gram repetition prevention
 )
+ logger.info("Initialized with default base GenerationConfig.")
+
+ # Ensure the base config uses the determined pad_token_id.
+ # This might be redundant if the tokenizer already has it, but it ensures consistency.
+ self.base_generation_config.pad_token_id = self.tokenizer.pad_token_id
+ logger.debug("Base GenerationConfig pad_token_id set to %s.", self.base_generation_config.pad_token_id)
+
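# A minimal sketch (not part of this commit) of why from_dict(to_dict()) is
# used above: it yields an independent copy, so per-call overrides can never
# mutate the stored base config.
_base_sketch = GenerationConfig(temperature=0.7, top_p=0.95)
_call_cfg_sketch = GenerationConfig.from_dict(_base_sketch.to_dict())
_call_cfg_sketch.temperature = 1.0        # per-call override
assert _base_sketch.temperature == 0.7    # the base config is untouched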
241
+ # Check if the underlying HF model (if found) supports returning scores, useful for CISC.
242
+ # We set this on the model's config if possible, as `generate` reads from there.
243
+ if self._hf_model_instance and hasattr(self._hf_model_instance.config, 'return_dict_in_generate'):
244
+ try:
245
+ # Set these attributes directly on the model's config object
246
+ self._hf_model_instance.config.return_dict_in_generate = True
247
+ self._hf_model_instance.config.output_scores = True # Also request scores
248
+ logger.debug("Set underlying HF model config to return dict in generate and output scores.")
249
+ except Exception as e:
250
+ logger.warning("Failed to set return_dict_in_generate/output_scores on HF model config: %s", e)
251
  else:
252
+ logger.debug("Underlying HF model instance or config does not support setting return_dict_in_generate/output_scores.")
253
 
254
 
255
+ logger.info("ChainOfThoughtWrapper initialization complete.")
256
+ logger.debug("Final Base GenerationConfig: %s", self.base_generation_config.to_dict())
257
 
     def _find_hf_model_and_config(self, obj: Any) -> Tuple[Optional[PreTrainedModel], Optional[Any]]:
         """
         Recursively searches for an underlying Hugging Face PreTrainedModel
+        and its configuration within a potentially wrapped or custom object.
+        This helps in accessing standard HF attributes like `config` or
+        methods like `resize_token_embeddings`.
 
         Args:
             obj (Any): The object to inspect (could be the model itself or a wrapper).
 
         Returns:
             Tuple[Optional[PreTrainedModel], Optional[Any]]: The found HF model instance and its config.
             Returns (None, None) if not found.
         """
+        # Guard against infinite recursion when wrapped objects reference each other
+        if getattr(obj, '_searching_hf_model', False):
+            logger.debug("Preventing infinite recursion in _find_hf_model_and_config for object type: %s", type(obj))
+            return None, None
+        setattr(obj, '_searching_hf_model', True)
+
         logger.debug("Searching for HF model in object of type: %s", type(obj))
         # If the object is directly a PreTrainedModel and has a config
+        if isinstance(obj, PreTrainedModel):
             logger.debug("Found HF PreTrainedModel directly.")
+            setattr(obj, '_searching_hf_model', False)  # Reset flag
+            return obj, getattr(obj, 'config', None)  # Return config if it exists
 
         # Check common attribute names where the base model might be stored
+        potential_attrs = ('model', 'base_model', 'transformer', '_original_model', 'module')
         for attr_name in potential_attrs:
             m = getattr(obj, attr_name, None)
             if m is not None:
                 # Recursively search within the attribute
                 found_model, found_config = self._find_hf_model_and_config(m)
                 if found_model or found_config:
+                    setattr(obj, '_searching_hf_model', False)  # Reset flag before returning
                     return found_model, found_config
 
+        # If no PreTrainedModel was found through attributes, fall back to any 'config' attribute
         if hasattr(obj, 'config'):
+            logger.debug("Found config attribute on object, but no PreTrainedModel instance.")
+            setattr(obj, '_searching_hf_model', False)  # Reset flag
+            return None, obj.config
 
+        logger.debug("No underlying HF PreTrainedModel instance or config found.")
+        setattr(obj, '_searching_hf_model', False)  # Reset flag
         return None, None
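+    # Illustrative example with a hypothetical wrapper class (not part of this module):
+    #   class MyWrapper:
+    #       def __init__(self, hf_model): self.model = hf_model
+    # _find_hf_model_and_config(MyWrapper(gpt2_model)) would descend through the
+    # 'model' attribute and return (gpt2_model, gpt2_model.config).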
 
     def _inject_cot(self, prompt: str) -> str:
         """
+        Injects the structured Chain-of-Thought template into the user's prompt.
+        The template guides the model's response format and adds reliability
+        prompts according to the wrapper's settings.
 
         Args:
             prompt (str): The original user prompt.
 
         Returns:
             str: The prompt with the CoT template appended.
         """
+        # Start with the cleaned original prompt
+        injected_prompt = f"{prompt.strip()}\n\n"
+
+        # Add the core CoT instruction phrase
+        injected_prompt += self._cot_instruction + "\n"
+
+        # Add reliability-focused instructions if enabled
+        if self._emphasize_factual:
+            injected_prompt += "Think through the problem step-by-step using only factual information and logical deduction. Do not assume any facts that are not given.\n"
+        if self._allow_uncertainty_phrase:
+            injected_prompt += self._allow_uncertainty_phrase + "\n"
+
+        # Add the structured template for reasoning steps and the final answer tag
+        injected_prompt += f"\n{self._reasoning_header}\n\n"
+        injected_prompt += f"{self._step_prefix}1: "  # Explicitly start the first step to anchor the format
+
+        logger.debug("Injected CoT template. Full prompt starts with: %s...", injected_prompt[:200].replace('\n', '\\n'))
+        return injected_prompt
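+    # Illustrative shape of the injected prompt (exact wording depends on
+    # self._cot_instruction, self._reasoning_header, and self._step_prefix):
+    #
+    #   <user prompt>
+    #
+    #   <CoT instruction>
+    #   Think through the problem step-by-step using only factual information...
+    #   <uncertainty phrase, if configured>
+    #
+    #   <reasoning header>
+    #
+    #   Step 1: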
 
     @torch.no_grad()  # Disable gradient calculation during generation for efficiency
     def generate(
         self,
+        input_text: str,
+        generation_config: Optional[GenerationConfig] = None,  # Optional per-call override config
+        num_return_sequences: Optional[int] = None,  # Explicitly request N sequences (chains)
     ) -> Dict[str, Any]:
         """
+        Generates text using the wrapped model with Chain-of-Thought injection.
+        Handles tokenization, prompt injection, generation, and parsing, and
+        generates multiple sequences efficiently via `num_return_sequences`.
 
         Args:
+            input_text (str): The user's input text/question.
+            generation_config (Optional[GenerationConfig]): Generation parameters that
+                override the base config for this call.
+            num_return_sequences (Optional[int]): Number of independent sequences (chains)
+                to generate. If None, falls back to `consistency_rounds` when
+                self-consistency is enabled, otherwise 1 (any value in the merged
+                config is overridden).
 
         Returns:
+            Dict[str, Any]: A dictionary containing the generation results:
+                - 'sequences': The raw generated token IDs.
+                - 'full_texts': Raw, cleaned text outputs (prompt and artifacts stripped) for each chain.
+                - 'reasoning_steps': Lists of extracted reasoning steps for each chain.
+                - 'final_answers': Extracted final answer strings for each chain.
+                - 'generation_scores': Scores, if requested and available (for CISC externally).
         """
+        logger.info("Received generate call with input text starting: '%s...'", input_text[:100])
+
+        # 1) Inject the CoT prompt into the original input text
+        cot_prompt_text = self._inject_cot(input_text)
+
+        # 2) Tokenize the full CoT prompt. return_tensors="pt" yields PyTorch tensors;
+        #    truncation=True keeps the input within max_length (which bounds the
+        #    input sequence here).
         try:
+            encoded_input = self.tokenizer(
+                cot_prompt_text,
+                return_tensors="pt",
+                padding="longest",  # Pad to the longest sequence in the batch (only 1 here, but good practice)
+                truncation=True,
+                max_length=self.max_length,  # Truncate if the prompt itself is too long
             ).to(self.device)
+            logger.debug("Input text tokenized. Input IDs shape: %s, on device: %s", encoded_input['input_ids'].shape, encoded_input['input_ids'].device)
         except Exception as e:
+            logger.error("Failed to tokenize input text: %s", e)
+            # Attempt cleanup before re-raising
+            if torch.cuda.is_available(): torch.cuda.empty_cache()
+            gc.collect()
+            raise  # Re-raise the tokenization error
+
+        # 3) Build the final GenerationConfig for this call: start from the base
+        #    config, then merge any provided overrides (.from_dict(.to_dict()) gives
+        #    a safe copy).
+        cfg = GenerationConfig.from_dict(self.base_generation_config.to_dict())
+
+        if generation_config is not None:
+            logger.debug("Merging provided generation_config overrides...")
+            cfg.update(**generation_config.to_dict())
+            logger.debug("Merged user-provided GenerationConfig.")
+
+        # Resolve num_return_sequences for this call. An explicit argument overrides
+        # anything in the base or override configs; otherwise fall back to
+        # consistency_rounds (when self-consistency is enabled) or 1.
+        if num_return_sequences is not None:
+            cfg.num_return_sequences = num_return_sequences
+            logger.debug("Using num_return_sequences from function argument: %s", cfg.num_return_sequences)
+        elif self.self_consistency_enabled:
+            cfg.num_return_sequences = self.consistency_rounds
+            logger.debug("num_return_sequences argument is None; using consistency_rounds (%s) because self-consistency is enabled.", cfg.num_return_sequences)
+        else:
+            cfg.num_return_sequences = 1
+            logger.debug("num_return_sequences argument is None and self-consistency disabled; defaulting to 1.")
+
+        # Reconcile length limits with the wrapper's max_length. In generate(),
+        # max_length bounds the *total* sequence (input + new tokens) while
+        # max_new_tokens bounds only the newly generated tokens; prefer
+        # max_new_tokens when set, otherwise cap max_length.
+        input_length = encoded_input['input_ids'].shape[1]
+        if cfg.max_new_tokens is None:
+            if cfg.max_length is not None:
+                # Only tighten cfg.max_length if it was set in the base/override config
+                cfg.max_length = min(self.max_length, cfg.max_length)
+            else:
+                # If neither max_new_tokens nor max_length were set, use the wrapper's max_length
+                cfg.max_length = self.max_length
+            logger.debug("max_new_tokens not set in config. Using total max_length: %s (Input length: %s)", cfg.max_length, input_length)
+        else:
+            # With max_new_tokens set, the effective total is input_length + max_new_tokens;
+            # trim max_new_tokens if that would exceed the wrapper's overall limit.
+            effective_total_length = input_length + cfg.max_new_tokens
+            if effective_total_length > self.max_length:
+                logger.warning("Effective total length (input %d + new %d = %d) exceeds wrapper max_length (%d). Adjusting max_new_tokens.",
+                               input_length, cfg.max_new_tokens, effective_total_length, self.max_length)
+                cfg.max_new_tokens = max(0, self.max_length - input_length)
+                logger.warning("Adjusted max_new_tokens to %d.", cfg.max_new_tokens)
+
445
+ # Ensure pad_token_id and eos_token_id are correctly set in the final config
446
+ # Use tokenizer's IDs as the source of truth
447
+ cfg.pad_token_id = self.tokenizer.pad_token_id
448
+ cfg.eos_token_id = self.tokenizer.eos_token_id
449
+
450
+ logger.debug("Final GenerationConfig for this call after resolving overrides and num_return_sequences: %s", cfg.to_dict())
451
+
+        # --- Debugging: inspect inputs immediately before generation (logging added
+        # to help diagnose device-side CUDA errors) ---
+        logger.debug("-" * 30 + " Inputs to model.generate " + "-" * 30)
+        logger.debug("  Input Text Snippet: '%s...'", input_text[:100])
+        logger.debug("  CoT Prompt Text Snippet: '%s...'", cot_prompt_text[:200].replace('\n', '\\n'))
+        logger.debug("  Input IDs shape: %s, dtype: %s, device: %s", encoded_input["input_ids"].shape, encoded_input["input_ids"].dtype, encoded_input["input_ids"].device)
+        if encoded_input.get("attention_mask", None) is not None:
+            logger.debug("  Attention Mask shape: %s, dtype: %s, device: %s", encoded_input["attention_mask"].shape, encoded_input["attention_mask"].dtype, encoded_input["attention_mask"].device)
+            # Log a snippet of the attention mask (first batch item, first 20 tokens)
+            if encoded_input["attention_mask"].numel() > 0:
+                logger.debug("  Attention Mask snippet (first 20): %s", encoded_input["attention_mask"][0, :20].tolist())
+                # Sanity-check that the mask contains only 0s and 1s; this won't catch
+                # every CUDA error but helps narrow down malformed inputs.
+                if not torch.all((encoded_input["attention_mask"] == 0) | (encoded_input["attention_mask"] == 1)):
+                    logger.error("!!! Attention mask contains values other than 0 or 1 !!!")
+        else:
+            logger.warning("!!! No attention mask provided to model.generate !!!")
+        logger.debug("  GenerationConfig.pad_token_id: %s", cfg.pad_token_id)
+        logger.debug("  GenerationConfig.eos_token_id: %s", cfg.eos_token_id)
+        logger.debug("  GenerationConfig.num_return_sequences: %s", cfg.num_return_sequences)
+        logger.debug("-" * 30 + " End Inputs to model.generate " + "-" * 30)
+        # --- End Debugging ---
 
+        # 4) Generate text via the model's generate method, passing input_ids, the
+        #    attention mask, and the fully resolved GenerationConfig.
         try:
+            generation_output = self.model.generate(
+                input_ids=encoded_input["input_ids"],
+                attention_mask=encoded_input.get("attention_mask", None),
+                generation_config=cfg,  # Pass the fully configured GenerationConfig
+                # Request dict output and scores so callers can implement
+                # confidence-weighted voting (CISC) externally.
+                return_dict_in_generate=True,
+                output_scores=True,
             )
+            generated_sequences = generation_output.sequences
+            # When scores were requested and returned, generation_output.scores holds
+            # them; callers can use these for CISC voting.
+            generation_scores = generation_output.scores if hasattr(generation_output, 'scores') else None
+            logger.info("Generation complete. Generated %d sequence(s).", len(generated_sequences))
+            if generation_scores:
+                logger.debug("Generation scores available (%d score tensors).", len(generation_scores))
         except Exception as e:
             logger.error("Model generation failed: %s", e)
+            import traceback
+            logger.error(traceback.format_exc())  # Log the full traceback
+
+            # Attempt cleanup even on failure; this may re-trigger a sticky CUDA error,
+            # but it is the right place to try to free GPU memory tied to the model.
+            if torch.cuda.is_available():
+                try:
+                    torch.cuda.empty_cache()
+                    logger.debug("Attempted torch.cuda.empty_cache() after generation failure.")
+                except Exception as cache_e:
+                    logger.error("Error during cuda empty_cache after generation failure: %s", cache_e)
+            gc.collect()
+            logger.debug("Attempted gc.collect() after generation failure.")
+
+            raise  # Re-raise the generation error
 
+        # 5) Decode and parse the generated sequences
+        if not isinstance(generated_sequences, (list, torch.Tensor)) or len(generated_sequences) == 0:
+            logger.warning("No sequences generated. Returning empty results.")
+            return {
+                "sequences": [],
+                "full_texts": [],
+                "reasoning_steps": [],
+                "final_answers": [],
+                "generation_scores": None,
+            }
+
+        decoded_outputs = self.tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)
+        logger.debug("Batch decoding complete.")
+        parsed_results = [self._parse(text, cot_prompt_text) for text in decoded_outputs]
+        logger.debug("Parsing complete for %d sequences.", len(parsed_results))
+
+        # Unpack the parsed results
+        all_steps = [result[0] for result in parsed_results]
+        all_final_answers = [result[1] for result in parsed_results]
+        full_generated_bodies = [result[2] for result in parsed_results]
+
+        # 6) Construct and return the results dictionary. Self-consistency voting is
+        #    handled by the caller; the wrapper supplies the chains and parsed answers.
         return {
+            "sequences": generated_sequences,         # Raw sequences (token IDs)
+            "full_texts": full_generated_bodies,      # Cleaned generated text bodies
+            "reasoning_steps": all_steps,             # Parsed reasoning steps per chain
+            "final_answers": all_final_answers,       # Parsed final answer per chain
+            "generation_scores": generation_scores,   # Scores if requested and available (for CISC)
         }
 
     def _parse(self, text: str, cot_prompt: str) -> Tuple[List[str], str, str]:
         """
         Parses the generated text to extract reasoning steps and the final answer.
+        Handles different output formats, strips artifacts, and provides fallback
+        logic for locating the answer.
 
         Args:
             text (str): The raw text output from the model for a single chain.
+            cot_prompt (str): The exact injected CoT prompt, used to strip the
+                prompt prefix from the output.
 
         Returns:
             Tuple[List[str], str, str]: A tuple containing:
                 - A list of extracted reasoning step strings.
                 - The extracted final answer string.
+                - The full body of the generated text (prompt and artifacts removed).
         """
+        logger.debug("Starting parsing for a single generated text chunk...")
+
+        # 1) Remove the exact injected prompt from the beginning of the text.
         # This isolates the model's generated continuation.
         body = text
         if text.startswith(cot_prompt):
+            body = text[len(cot_prompt):]  # Remove the prefix
+            logger.debug("Removed exact CoT prompt (%d characters) from beginning.", len(cot_prompt))
+        else:
+            logger.warning("Generated text does not start with the injected CoT prompt. Parsing the entire text after stripping leading whitespace.")
+            body = text.lstrip()
+
+        # 2) Apply artifact cleanup patterns
+        logger.debug("Applying artifact cleanup patterns...")
+        original_body_len = len(body)
+        cleaned_body = body  # Start from the body after prompt removal
+        for pattern in self._artifact_patterns:
+            cleaned_body = pattern.sub("", cleaned_body)
+        if len(cleaned_body) < original_body_len:
+            logger.debug("Artifact cleanup removed %d characters.", original_body_len - len(cleaned_body))
         else:
+            logger.debug("No artifacts found matching patterns.")
+
+        # Strip the cleaned body and split it into non-empty, stripped lines
+        cleaned_body = cleaned_body.strip()
+        body_lines = [l.strip() for l in cleaned_body.splitlines() if l.strip()]
 
         steps = []  # List to store extracted steps
         final_answer = ""  # Variable to store the final answer
+        found_final_answer_tagged = False  # Whether the explicit tag was found
+
+        # 3) Extract the final answer (primary method: explicit tag).
+        #    Locate the first line containing the final answer tag; whatever the
+        #    pattern captures after the tag on that line is the answer.
+        logger.debug("Attempting to extract the final answer using explicit tag '%s'...", self.final_answer_tag)
+        final_answer_line_index = -1
+        for idx, line in enumerate(body_lines):
+            final_answer_match = self.final_answer_pattern.search(line)
+            if final_answer_match:
+                final_answer = final_answer_match.group(1).strip()
+                final_answer_line_index = idx
+                found_final_answer_tagged = True
+                logger.debug("Extracted final answer using explicit tag: '%s'", final_answer[:100])
+                break  # Use the first occurrence of the tag
+
+        # Collect reasoning steps from the lines preceding the tagged answer
+        # (or from all lines when no tag was found), up to the configured limit.
+        logger.debug("Collecting reasoning steps...")
+        for i, line in enumerate(body_lines):
+            if final_answer_line_index != -1 and i >= final_answer_line_index:
+                logger.debug("Stopped collecting steps at line index %d because the final answer tag was found on line %d.", i, final_answer_line_index)
+                break
             # Check for reasoning step pattern
+            step_match = self._step_pattern.match(line)
             if step_match:
+                step_text = step_match.group(1).strip()
+                if step_text:  # Only add non-empty steps
+                    steps.append(step_text)
+                if len(steps) >= self.reasoning_steps_limit:
+                    logger.debug("Reached reasoning steps limit (%d). Stopping step extraction.", self.reasoning_steps_limit)
+                    break  # Stop collecting steps once the limit is reached
+
+        # 4) Fallback for the final answer (tag not found)
+        if not found_final_answer_tagged:
+            logger.debug("Explicit final answer tag not found. Applying fallback heuristics.")
+
+            # Fallback: assume the last non-step line is the answer. Iterate
+            # backwards through the cleaned, non-empty lines to find it.
             last_non_step_line = ""
+            for line in reversed(body_lines):
+                if line and not self._step_pattern.match(line):
                     last_non_step_line = line.strip()
+                    logger.debug("Fallback: Identified last non-step line: '%s'", last_non_step_line[:100])
+                    break  # Found the last non-step line; stop searching backwards
 
             if last_non_step_line:
                 # Check if the last non-step line *contains* the final answer tag,
+                # even if the tag was embedded mid-line rather than at the start.
                 fa_match_fallback = self.final_answer_pattern.search(last_non_step_line)
                 if fa_match_fallback:
                     final_answer = fa_match_fallback.group(1).strip()
+                    logger.debug("Fallback found tagged answer in last non-step line: '%s'", final_answer[:100])
                 else:
+                    # No tag in the last non-step line; use the line itself as the answer
                     final_answer = last_non_step_line
+                    logger.debug("Fallback using last non-step line as answer: '%s'", final_answer[:100])
             else:
+                # No non-empty, non-step lines were found; the final answer stays empty
                 final_answer = ""
+                logger.debug("Fallback: No non-empty or non-step lines found in body. Final answer is empty.")
+
+        # 5) Basic post-parsing cleanup on the final answer, which helps normalize
+        #    answers for voting.
+        if final_answer:
+            # Trim common trailing punctuation (periods, commas, etc.)
+            final_answer = re.sub(r'[.,;:]+$', '', final_answer).strip()
+            # Strip leading preambles like "Answer:" that the tag match may have
+            # missed (case-insensitive)
+            final_answer = re.sub(r'^\s*(?:Answer|Result|Output|Final Answer)\s*[:\-]?\s*', '', final_answer, flags=re.IGNORECASE).strip()
+            logger.debug("Applied basic post-parsing cleanup to final answer: '%s'", final_answer[:100])
+
+        # Safeguard: drop any step that is exactly the final answer line (e.g., when
+        # the tag was missing but an answer-looking line also matched the step
+        # pattern). Steps that merely *contain* the answer text are kept, since a
+        # step that legitimately derives the answer should not be discarded.
+        if final_answer and steps:
+            steps = [step for step in steps if step.strip() != final_answer.strip()]
 
+        logger.info("Parsing complete. Steps found: %d, Final Answer: '%s'", len(steps), final_answer[:100])
+
+        # Return the extracted steps, the final answer, and the cleaned generated body text
+        return steps, final_answer, cleaned_body
 
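+    # Worked example (illustrative, assuming the default step pattern captures the
+    # text after "Step N:"): given the cleaned body
+    #   "Step 1: 60 mph * 2.5 h\nStep 2: = 150 miles\nFinal Answer: 150 miles"
+    # _parse would return (["60 mph * 2.5 h", "= 150 miles"], "150 miles", <body>).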
 
     def resize_token_embeddings(self, new_size: int):
         """
+        Resizes the model's token embeddings to match a new vocabulary size,
+        useful after adding new tokens (like a custom PAD token) to the tokenizer.
+        This is required whenever the tokenizer size changes and the model is then
+        used for generation or training.
 
+        Only works if the underlying model object is a PreTrainedModel or exposes
+        a `resize_token_embeddings` method.
 
         Args:
             new_size (int): The new size of the vocabulary/embedding layer.
+                Typically `len(self.tokenizer)`.
         """
+        # Use the stored HF model instance found during initialization
+        hf_model_instance = self._hf_model_instance
 
+        if hf_model_instance and hasattr(hf_model_instance, 'resize_token_embeddings'):
             try:
                 old_size = hf_model_instance.get_input_embeddings().weight.size(0)
                 if new_size != old_size:
+                    logger.info("Attempting to resize model token embeddings from %d to %d.", old_size, new_size)
+                    # Ensure the model is on the correct device before resizing
+                    hf_model_instance.to(self.device)
                     hf_model_instance.resize_token_embeddings(new_size)
+                    logger.info("Successfully resized token embeddings.")
                     # Update model config's vocab size if available
                     if hasattr(hf_model_instance, 'config') and hasattr(hf_model_instance.config, 'vocab_size'):
+                        hf_model_instance.config.vocab_size = new_size
+                        logger.debug("Updated underlying model config vocab_size to %d.", new_size)
+                    # Free memory after a potentially memory-intensive operation
+                    if torch.cuda.is_available(): torch.cuda.empty_cache()
+                    gc.collect()
                 else:
                     logger.info("Embedding size is already %d, no resizing needed.", new_size)
             except Exception as e:
                 logger.error("Failed to resize token embeddings: %s", e)
+                # Attempt cleanup even on failure
                 if torch.cuda.is_available(): torch.cuda.empty_cache()
                 gc.collect()
+                # Not re-raised by default: a failure here is only critical if the
+                # new tokens are actually used for generation; re-raise if needed.
         else:
+            logger.warning("Cannot resize token embeddings: no underlying HF model instance with a 'resize_token_embeddings' method was found.")
 
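+# Usage sketch (illustrative): after adding a new [PAD] token to the tokenizer,
+# bring the model's embedding matrix back in sync with the vocabulary:
+#   tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+#   wrapper.resize_token_embeddings(len(tokenizer))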
 
+# Example Usage (Illustrative)
 if __name__ == "__main__":
     print("--- ChainOfThoughtWrapper Example Usage ---")
+    print("This block demonstrates loading a small HF model and using the wrapper.")
+    print("Setting logging level to DEBUG to see detailed wrapper logs.")
+    logger.setLevel(logging.DEBUG)  # Set logger to DEBUG for the example
 
     # You would replace this with your actual model loading logic
     try:
         # Use a tiny, fast model for a quick test
+        # NOTE: distilgpt2 might still hit CUDA errors with num_return_sequences > 1
+        # if there are underlying driver/CUDA/PyTorch compatibility issues or subtle
+        # model-specific padding bugs in HF transformers for this architecture.
+        # If this example still fails, try a different simple causal model such as 'gpt2'.
+        model_id = "distilbert/distilgpt2"  # A small, fast distilled GPT-2 variant
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
         logger.info(f"Attempting to load model {model_id} on {device}...")
 
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
+        # Ensure a pad token is set for generation robustness (a common requirement
+        # for GPT-like models). Handle this *before* loading the model if possible,
+        # or make sure embeddings are resized afterwards.
         if tokenizer.pad_token_id is None:
             if tokenizer.eos_token_id is not None:
                 tokenizer.pad_token_id = tokenizer.eos_token_id
+                logger.warning("Tokenizer pad_token_id is None, using eos_token_id (%s) as pad_token_id.", tokenizer.eos_token_id)
             else:
+                # Add a pad token if neither eos nor pad exists. This *must* be done
+                # before loading the model or resizing embeddings.
+                logger.warning("Tokenizer pad_token_id and eos_token_id are both None. Adding a [PAD] token.")
                 tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                 tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
+                logger.info("Added new [PAD] token with ID %s.", tokenizer.pad_token_id)
+                # Note: the wrapper resizes embeddings during initialization if a
+                # compatible HF model instance is found.
 
+        # Load model
+        model = AutoModelForCausalLM.from_pretrained(model_id)
 
         # Instantiate the wrapper
+        # Simulate parameters that would come from a GUI or config; this
+        # GenerationConfig overrides some of the wrapper's base defaults.
+        simulated_base_gen_config = GenerationConfig(
+            max_new_tokens=128,   # Limit generated tokens
+            temperature=0.85,     # Slightly higher temperature for diversity across chains
+            do_sample=True,       # Crucial for sampling-based generation
+            # num_return_sequences is intentionally NOT set here; the wrapper sets it per generate() call
             pad_token_id=tokenizer.pad_token_id,  # Pass pad_token_id explicitly
             eos_token_id=tokenizer.eos_token_id,  # Pass eos_token_id explicitly
+            repetition_penalty=1.1  # Apply a mild repetition penalty
         )
 
+        # Enable the self-consistency flags at init; these inform the wrapper's
+        # default behavior when generate() arguments are None.
         cot_wrapper = ChainOfThoughtWrapper(
             model=model,
             tokenizer=tokenizer,
+            generation_config=simulated_base_gen_config,  # Base config overrides
             device=device,
+            self_consistency_enabled=True,     # Simulate SC enabled
+            consistency_rounds=5,              # Simulate the consistency-rounds setting
+            final_answer_tag="Final Answer:",  # Use a slightly different tag for the demo
+            emphasize_factual=True,            # Keep factual emphasis on for the demo
+            allow_uncertainty_phrase="If you cannot determine a definitive answer, state that.",
         )
 
+        # Prepare an input prompt that encourages steps and a clear answer
+        prompt_text = "If a train travels at 60 mph for 2.5 hours, how far does it travel? Calculate step-by-step."
 
         logger.info(f"Generating reasoning for prompt: '{prompt_text}'")
 
+        # Generate outputs, explicitly passing num_return_sequences to the generate
+        # call (e.g., from a GUI slider)
+        num_chains_to_generate = 3  # Simulate the GUI num_chains slider set to 3
+        logger.info(f"Calling wrapper.generate() requesting {num_chains_to_generate} chains.")
+
+        start_time = time.time()
         outputs = cot_wrapper.generate(
+            input_text=prompt_text,
+            # No explicit generation_config override here; the wrapper's base config
+            # is used, but you *could* pass e.g. generation_config=GenerationConfig(temperature=1.0)
+            num_return_sequences=num_chains_to_generate,  # The desired number of chains
         )
+        end_time = time.time()
+        logger.info(f"Generation of {len(outputs.get('sequences', []))} sequences took {end_time - start_time:.2f} seconds.")
 
+        # --- Process results (including simulated self-consistency voting logic) ---
+        print("\n" + "="*50)
+        print("--- Generation Results ---")
+        print("="*50)
 
+        full_texts = outputs.get('full_texts', [])
+        reasoning_steps = outputs.get('reasoning_steps', [])
+        final_answers_raw = outputs.get('final_answers', [])  # Raw answers from the wrapper
 
+        if not full_texts:
+            print("No chains were generated or parsed.")
+        else:
+            for i, (full_text, steps, final_answer_raw) in enumerate(zip(full_texts, reasoning_steps, final_answers_raw)):
+                print(f"\n--- Chain {i+1} ---")
+                print("Full Text (Cleaned):")
+                print(full_text)
+                print("\nReasoning Steps Parsed:")
+                steps = steps if isinstance(steps, list) else []  # Guard against non-list values
+                for j, step in enumerate(steps):
+                    if isinstance(step, str) and step.strip():
+                        print(f"  Step {j+1}: {step.strip()}")
+                    elif not isinstance(step, str):
+                        print(f"  [Step {j+1} has invalid format]")
+                if not steps:
+                    print("  [No steps parsed]")
+                print("\nFinal Answer Parsed (Raw):")
+                display_raw_answer = final_answer_raw if isinstance(final_answer_raw, str) and final_answer_raw.strip() else "[No final answer parsed]"
+                print(f"  '{display_raw_answer}'")
 
         # --- Simulate Self-Consistency Voting (as would be done in GUI) ---
+        print("\n" + "="*50)
+        print("--- Simple Self-Consistency Voting Simulation ---")
+        print("="*50)
+
+        if final_answers_raw:
+            # Perform the actual voting using the helper functions
+            consensus_answer, answer_distribution_dict = perform_self_consistency_voting(final_answers_raw)
+            answer_distribution = Counter(answer_distribution_dict)  # Convert to a Counter for display
+
+            print(f"Raw Answers Submitted for Voting: {final_answers_raw}")
+            print(f"Normalized Answers for Voting: {list(answer_distribution_dict.keys())}")  # Unique normalized answers
+            print(f"Answer Counts: {dict(answer_distribution)}")
+
+            if consensus_answer:
+                print(f"\nConsensus Answer: '{consensus_answer}'")
+                # Count the votes for the winning normalized answer
+                winner_count = answer_distribution.get(normalize_answer(consensus_answer), 0)
+                print(f"(Voted by {winner_count} chain(s) out of {len(final_answers_raw)})")
+
+                # Optional: flag ties (a real voter would use more sophisticated tie-breaking)
+                if len(answer_distribution) > 1 and answer_distribution.most_common(2)[0][1] == answer_distribution.most_common(2)[1][1]:
+                    print("Note: There is a tie for the most common normalized answer.")
+            else:
+                print("No valid final answers found for voting.")
         else:
+            print("No final answers were parsed from any chain for voting.")
 
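+        # perform_self_consistency_voting and normalize_answer are helpers defined
+        # earlier in this module; conceptually (a minimal sketch, not necessarily
+        # the exact implementation):
+        #   def normalize_answer(a): return a.strip().lower()
+        #   def perform_self_consistency_voting(answers):
+        #       counts = Counter(normalize_answer(a) for a in answers if a and a.strip())
+        #       return (counts.most_common(1)[0][0], dict(counts)) if counts else ("", {})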
 
     except Exception as e:
+        logger.error("An error occurred during the example usage: %s", e)
         import traceback
         traceback.print_exc()  # Print detailed traceback for the example failure
 
+    print("\n--- Example Usage End ---")
+    # Attempt final cleanup
+    if torch.cuda.is_available():
+        try:
+            torch.cuda.empty_cache()
+            print("GPU memory cache cleared.")
+        except Exception as cleanup_e:
+            print(f"Error during final cuda empty_cache: {cleanup_e}")
+    gc.collect()
+    print("Garbage collected.")