Kompella Sri Aasrith Souri commited on
Commit
76a1306
·
2 Parent(s): 30ecce6 288c71b

Resolve merge conflicts in supernova/train.py

- Keep improved parameter formatting for TokenChunkDataset
- Use standard torch.cuda.amp.GradScaler initialization
- Implement proper validation with wikitext-2 validation split
- Maintain consistent code style and comments

Browse files
supernova/__init__.py CHANGED
@@ -1,6 +1,15 @@
1
- __version__ = "0.1.0"
2
-
3
- from .config import ModelConfig
4
- from .model import SupernovaModel
5
- from .tools import ToolOrchestrator, MathEngine, SerperAPI
6
- from .reasoning_engine import EnhancedReasoningEngine, ReasoningType, ReasoningStep
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.1.0"
2
+
3
+ from .config import ModelConfig
4
+ from .model import SupernovaModel
5
+ from .tokenizer import load_gpt2_tokenizer
6
+ from .data import TokenChunkDataset, load_sources_from_yaml, DataSource
7
+ from .tools import ToolOrchestrator, MathEngine, SerperAPI
8
+ from .reasoning_engine import EnhancedReasoningEngine, ReasoningType, ReasoningStep
9
+
10
# Public package surface: the names exported by `from supernova import *`.
__all__ = [
    'ModelConfig',
    'SupernovaModel',
    'load_gpt2_tokenizer',
    'TokenChunkDataset',
    'load_sources_from_yaml',
    'DataSource',
    'ToolOrchestrator',
    'MathEngine',
    'SerperAPI',
    'EnhancedReasoningEngine',
    'ReasoningType',
    'ReasoningStep',
]
supernova/data.py CHANGED
@@ -1,105 +1,121 @@
1
- import random
2
- from dataclasses import dataclass
3
- from typing import Dict, Iterable, Iterator, List, Optional, Tuple
4
-
5
- import torch
6
- from torch.utils.data import IterableDataset
7
- from datasets import load_dataset
8
- from transformers import PreTrainedTokenizerBase
9
- import yaml
10
-
11
-
12
- @dataclass
13
- class DataSource:
14
- name: str
15
- hf_path: str
16
- hf_name: Optional[str]
17
- split: str
18
- text_field: str
19
- weight: int = 1
20
- streaming: bool = True
21
-
22
-
23
- def load_sources_from_yaml(path: str) -> List[DataSource]:
24
- with open(path, "r", encoding="utf-8") as f:
25
- cfg = yaml.safe_load(f)
26
- srcs = []
27
- for s in cfg.get("sources", []):
28
- srcs.append(DataSource(
29
- name=s.get("name"),
30
- hf_path=s.get("hf_path"),
31
- hf_name=s.get("hf_name"),
32
- split=s.get("split", "train"),
33
- text_field=s.get("text_field", "text"),
34
- weight=int(s.get("weight", 1)),
35
- streaming=bool(s.get("streaming", True)),
36
- ))
37
- assert len(srcs) > 0, "No data sources configured"
38
- return srcs
39
-
40
-
41
- def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
42
- iters = []
43
- for s in sources:
44
- ds = load_dataset(s.hf_path, s.hf_name, split=s.split, streaming=s.streaming)
45
- iters.append(iter(ds))
46
- return iters
47
-
48
-
49
- def weighted_choice(weights: List[int]) -> int:
50
- total = sum(weights)
51
- r = random.randint(1, total)
52
- acc = 0
53
- for i, w in enumerate(weights):
54
- acc += w
55
- if r <= acc:
56
- return i
57
- return len(weights) - 1
58
-
59
-
60
- class TokenChunkDataset(IterableDataset):
61
- def __init__(
62
- self,
63
- tokenizer: PreTrainedTokenizerBase,
64
- sources: List[DataSource],
65
- seq_len: int,
66
- eos_token_id: Optional[int] = None,
67
- ):
68
- super().__init__()
69
- self.tok = tokenizer
70
- self.sources = sources
71
- self.seq_len = seq_len
72
- self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
73
- self.weights = [max(1, s.weight) for s in sources]
74
-
75
- def _iter_texts(self) -> Iterator[str]:
76
- iters = build_streams(self.sources)
77
- while True:
78
- i = weighted_choice(self.weights)
79
- try:
80
- row = next(iters[i])
81
- except StopIteration:
82
- # restart that iterator if streaming was False
83
- iters[i] = build_streams([self.sources[i]])[0]
84
- row = next(iters[i])
85
- text = row.get(self.sources[i].text_field, None)
86
- if isinstance(text, str) and len(text) > 0:
87
- yield text
88
-
89
- def _iter_token_ids(self) -> Iterator[int]:
90
- for text in self._iter_texts():
91
- ids = self.tok.encode(text)
92
- if self.eos_id is not None:
93
- ids.append(self.eos_id)
94
- for t in ids:
95
- yield t
96
-
97
- def __iter__(self):
98
- buf: List[int] = []
99
- for tok_id in self._iter_token_ids():
100
- buf.append(tok_id)
101
- while len(buf) >= self.seq_len + 1:
102
- x = torch.tensor(buf[: self.seq_len], dtype=torch.long)
103
- y = torch.tensor(buf[1 : self.seq_len + 1], dtype=torch.long)
104
- del buf[: self.seq_len]
105
- yield x, y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Iterable, Iterator, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch.utils.data import IterableDataset
7
+ from datasets import load_dataset
8
+ from transformers import PreTrainedTokenizerBase
9
+ import yaml
10
+
11
@dataclass
class DataSource:
    """Configuration for one Hugging Face dataset in the training mixture."""

    name: str                # human-readable label, used in warning messages
    hf_path: str             # dataset path on the HF hub
    hf_name: Optional[str]   # optional dataset config name (may be None)
    split: str               # which split to load (e.g. "train")
    text_field: str          # row key that holds the raw text
    weight: int = 1          # relative sampling weight in the mixture
    streaming: bool = True   # stream from the hub instead of downloading fully
20
+
21
def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config file and return the configured data sources.

    The file must contain a top-level ``sources:`` list; each entry maps
    directly onto the :class:`DataSource` fields, with the same defaults
    (``split="train"``, ``text_field="text"``, ``weight=1``,
    ``streaming=True``).

    Raises:
        ValueError: if the file defines no sources.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    srcs = []
    for s in cfg.get("sources", []):
        srcs.append(DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        ))
    # `assert` statements are stripped under `python -O`, so validate with an
    # explicit raise to guarantee the check always runs.
    if not srcs:
        raise ValueError("No data sources configured")
    return srcs
37
+
38
def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one (possibly streaming) HF dataset iterator per source, in order."""
    return [
        iter(load_dataset(src.hf_path, src.hf_name, split=src.split, streaming=src.streaming))
        for src in sources
    ]
44
+
45
def weighted_choice(weights: List[int]) -> int:
    """Sample an index with probability proportional to its weight.

    Draws a single uniform integer in ``[1, sum(weights)]`` and walks the
    weights, subtracting each until the draw is exhausted.
    """
    draw = random.randint(1, sum(weights))
    for idx, w in enumerate(weights):
        draw -= w
        if draw <= 0:
            return idx
    # Unreachable for positive weights; kept as a defensive fallback.
    return len(weights) - 1
54
+
55
class TokenChunkDataset(IterableDataset):
    """Infinite iterable dataset that mixes weighted text sources and yields
    fixed-length next-token-prediction chunks.

    Each item is an ``(x, y)`` pair of ``torch.long`` tensors of shape
    ``(seq_len,)``, where ``y`` is ``x`` shifted one token to the left.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        sources: List["DataSource"],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        # Fall back to the tokenizer's own EOS id when none is supplied.
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp weights to >= 1 so weighted_choice always has a positive total.
        self.weights = [max(1, s.weight) for s in sources]

    def _iter_texts(self) -> Iterator[str]:
        """Yield raw text rows forever, sampling sources by weight."""
        iters = build_streams(self.sources)
        while True:
            i = weighted_choice(self.weights)
            try:
                row = next(iters[i])
            except StopIteration:
                # Exhausted (typically non-streaming) source: rebuild its
                # iterator. Reuse build_streams instead of duplicating the
                # load_dataset call so loading stays in one place.
                # Note: StopIteration is a subclass of Exception, so a single
                # `except Exception` covers both the restart and re-exhaustion.
                try:
                    iters[i] = build_streams([self.sources[i]])[0]
                    row = next(iters[i])
                except Exception as e:
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    continue  # Skip this iteration and try next source
            text = row.get(self.sources[i].text_field, None)
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> list:
        """Tokenize *text*, returning [] instead of raising on bad input."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Flatten the text stream into a single stream of token ids,
        appending EOS (when known) between documents."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if self.eos_id is not None:
                ids.append(self.eos_id)
            for t in ids:
                yield t

    def __iter__(self):
        # Accumulate token ids and emit (input, shifted-target) windows.
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                del buf[:self.seq_len]
                yield x, y

    def __len__(self):
        # The stream is unbounded; report a large nominal length so progress
        # bars and components that require __len__ keep working.
        return 1000000
121
+
supernova/reasoning_engine.py CHANGED
@@ -1,315 +1,320 @@
1
- """
2
- Enhanced Reasoning Engine for Supernova AI
3
- Provides sophisticated problem-solving capabilities through structured reasoning,
4
- multi-tool coordination, and knowledge synthesis.
5
- """
6
-
7
- import re
8
- import json
9
- from typing import List, Dict, Any, Optional, Tuple
10
- from dataclasses import dataclass
11
- from enum import Enum
12
-
13
- from .tools import ToolOrchestrator, ToolCall
14
-
15
-
16
- class ReasoningType(Enum):
17
- ANALYTICAL = "analytical"
18
- CREATIVE = "creative"
19
- COMPARATIVE = "comparative"
20
- CAUSAL = "causal"
21
- SEQUENTIAL = "sequential"
22
- EVALUATIVE = "evaluative"
23
-
24
-
25
- @dataclass
26
- class ReasoningStep:
27
- step_number: int
28
- description: str
29
- reasoning_type: ReasoningType
30
- tool_needed: Optional[str] = None
31
- query: Optional[str] = None
32
- result: Optional[str] = None
33
- confidence: float = 0.8
34
-
35
-
36
- @dataclass
37
- class KnowledgeDomain:
38
- domain: str
39
- confidence: float
40
- sources: List[str]
41
- key_facts: List[str]
42
-
43
-
44
- class EnhancedReasoningEngine:
45
- """Advanced reasoning engine that mimics sophisticated AI reasoning patterns."""
46
-
47
- def __init__(self, tool_orchestrator: ToolOrchestrator):
48
- self.tools = tool_orchestrator
49
- self.conversation_context = []
50
- self.domain_expertise = {
51
- 'science': ['physics', 'chemistry', 'biology', 'mathematics', 'astronomy'],
52
- 'technology': ['programming', 'ai', 'computing', 'engineering', 'electronics'],
53
- 'humanities': ['history', 'literature', 'philosophy', 'psychology', 'sociology'],
54
- 'medicine': ['anatomy', 'pharmacology', 'diagnosis', 'treatment', 'research'],
55
- 'business': ['finance', 'management', 'economics', 'marketing', 'strategy'],
56
- 'arts': ['music', 'visual arts', 'design', 'architecture', 'performance']
57
- }
58
-
59
- def analyze_query_complexity(self, query: str) -> Dict[str, Any]:
60
- """Analyze the complexity and requirements of a user query."""
61
- complexity_indicators = {
62
- 'simple': ['what is', 'define', 'who is', 'when did'],
63
- 'moderate': ['how does', 'why does', 'explain', 'compare', 'analyze'],
64
- 'complex': ['evaluate', 'synthesize', 'create', 'design', 'solve for multiple', 'consider all factors']
65
- }
66
-
67
- domains_detected = []
68
- for domain, keywords in self.domain_expertise.items():
69
- if any(keyword in query.lower() for keyword in keywords):
70
- domains_detected.append(domain)
71
-
72
- complexity_level = 'simple'
73
- for level, indicators in complexity_indicators.items():
74
- if any(indicator in query.lower() for indicator in indicators):
75
- complexity_level = level
76
-
77
- requires_multi_step = any(phrase in query.lower() for phrase in [
78
- 'step by step', 'first...then', 'multiple', 'several', 'both', 'compare and contrast'
79
- ])
80
-
81
- return {
82
- 'complexity': complexity_level,
83
- 'domains': domains_detected,
84
- 'multi_step_needed': requires_multi_step,
85
- 'estimated_steps': min(5, len(domains_detected) + (2 if requires_multi_step else 1))
86
- }
87
-
88
- def decompose_complex_query(self, query: str, analysis: Dict[str, Any]) -> List[ReasoningStep]:
89
- """Break down complex queries into manageable reasoning steps."""
90
- steps = []
91
- step_num = 1
92
-
93
- # Step 1: Information Gathering
94
- if analysis['complexity'] in ['moderate', 'complex']:
95
- # Determine if we need current information
96
- if any(term in query.lower() for term in ['current', 'latest', 'recent', 'today', '2024', '2025']):
97
- steps.append(ReasoningStep(
98
- step_number=step_num,
99
- description="Gather current information from web sources",
100
- reasoning_type=ReasoningType.ANALYTICAL,
101
- tool_needed="serper",
102
- query=query
103
- ))
104
- step_num += 1
105
-
106
- # Check if mathematical computation is needed
107
- if any(term in query.lower() for term in ['calculate', 'compute', 'solve', 'derivative', 'integral']):
108
- steps.append(ReasoningStep(
109
- step_number=step_num,
110
- description="Perform mathematical computation",
111
- reasoning_type=ReasoningType.ANALYTICAL,
112
- tool_needed="math_engine",
113
- query=query
114
- ))
115
- step_num += 1
116
-
117
- # Step 2: Domain-specific analysis
118
- for domain in analysis['domains']:
119
- steps.append(ReasoningStep(
120
- step_number=step_num,
121
- description=f"Analyze from {domain} perspective",
122
- reasoning_type=ReasoningType.ANALYTICAL,
123
- tool_needed=None, # Will use model generation with domain context
124
- query=f"From a {domain} perspective: {query}"
125
- ))
126
- step_num += 1
127
-
128
- # Step 3: Synthesis and evaluation
129
- if analysis['complexity'] == 'complex':
130
- steps.append(ReasoningStep(
131
- step_number=step_num,
132
- description="Synthesize information and provide comprehensive analysis",
133
- reasoning_type=ReasoningType.EVALUATIVE,
134
- tool_needed=None,
135
- query=query
136
- ))
137
-
138
- return steps if steps else [ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL, query=query)]
139
-
140
- def execute_reasoning_chain(self, steps: List[ReasoningStep], model, tokenizer) -> List[ReasoningStep]:
141
- """Execute a chain of reasoning steps, using tools and model generation as needed."""
142
- results = []
143
- context_info = []
144
-
145
- for step in steps:
146
- if step.tool_needed:
147
- # Use appropriate tool
148
- tool_call = ToolCall(tool=step.tool_needed, query=step.query)
149
- executed_call = self.tools.execute_tool_call(tool_call)
150
-
151
- if executed_call.result:
152
- step.result = executed_call.result
153
- step.confidence = 0.9
154
- context_info.append(f"{step.description}: {executed_call.result}")
155
- else:
156
- step.result = f"Tool execution failed: {executed_call.error}"
157
- step.confidence = 0.3
158
- else:
159
- # Use model generation with enhanced context
160
- enhanced_context = self._build_enhanced_context(step, context_info)
161
- try:
162
- response = self._generate_with_context(model, tokenizer, enhanced_context, step.query)
163
- step.result = response
164
- step.confidence = 0.7
165
- context_info.append(f"{step.description}: {response}")
166
- except Exception as e:
167
- step.result = f"Generation failed: {str(e)}"
168
- step.confidence = 0.2
169
-
170
- results.append(step)
171
-
172
- return results
173
-
174
- def _build_enhanced_context(self, step: ReasoningStep, context_info: List[str]) -> str:
175
- """Build enhanced context for model generation."""
176
- context_parts = [
177
- "You are Supernova, an advanced AI assistant with deep expertise across multiple domains.",
178
- "Apply sophisticated reasoning and provide comprehensive, nuanced responses.",
179
- ""
180
- ]
181
-
182
- if context_info:
183
- context_parts.extend([
184
- "Previous analysis steps:",
185
- *[f"- {info}" for info in context_info],
186
- ""
187
- ])
188
-
189
- reasoning_guidance = {
190
- ReasoningType.ANALYTICAL: "Analyze systematically, consider multiple factors, and provide evidence-based insights.",
191
- ReasoningType.CREATIVE: "Think creatively, explore innovative solutions, and consider unconventional approaches.",
192
- ReasoningType.COMPARATIVE: "Compare different perspectives, weigh pros and cons, and identify key differences.",
193
- ReasoningType.CAUSAL: "Identify cause-and-effect relationships, trace underlying mechanisms, and explain why things happen.",
194
- ReasoningType.SEQUENTIAL: "Break down into logical steps, show progression, and maintain clear sequencing.",
195
- ReasoningType.EVALUATIVE: "Make judgments based on criteria, assess quality and effectiveness, and provide recommendations."
196
- }
197
-
198
- context_parts.extend([
199
- f"Reasoning approach: {reasoning_guidance.get(step.reasoning_type, 'Provide thorough analysis.')}",
200
- f"Focus area: {step.description}",
201
- ""
202
- ])
203
-
204
- return "\n".join(context_parts)
205
-
206
- def _generate_with_context(self, model, tokenizer, context: str, query: str, max_tokens: int = 400) -> str:
207
- """Generate response using the model with enhanced context."""
208
- full_prompt = f"{context}\nUser Query: {query}\n\nDetailed Response:"
209
-
210
- # Use the existing generate function (simplified version)
211
- model.eval()
212
- device = next(model.parameters()).device
213
- input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
214
-
215
- with torch.no_grad():
216
- for _ in range(max_tokens):
217
- if input_ids.size(1) >= model.cfg.n_positions:
218
- input_cond = input_ids[:, -model.cfg.n_positions:]
219
- else:
220
- input_cond = input_ids
221
-
222
- logits, _ = model(input_cond)
223
- logits = logits[:, -1, :] / 0.8 # temperature
224
-
225
- # Top-k sampling
226
- v, _ = torch.topk(logits, min(50, logits.size(-1)))
227
- logits[logits < v[:, [-1]]] = -float("Inf")
228
-
229
- probs = torch.softmax(logits, dim=-1)
230
- next_id = torch.multinomial(probs, num_samples=1)
231
- input_ids = torch.cat([input_ids, next_id], dim=1)
232
-
233
- response = tokenizer.decode(input_ids[0].tolist())
234
-
235
- # Extract the response part
236
- if "Detailed Response:" in response:
237
- response = response.split("Detailed Response:", 1)[1].strip()
238
-
239
- return response
240
-
241
- def synthesize_final_response(self, steps: List[ReasoningStep], original_query: str) -> str:
242
- """Synthesize all reasoning steps into a comprehensive final response."""
243
- successful_steps = [step for step in steps if step.result and step.confidence > 0.5]
244
-
245
- if not successful_steps:
246
- return "I apologize, but I encountered difficulties processing your request. Could you please rephrase or provide more specific details?"
247
-
248
- # Build comprehensive response
249
- response_parts = []
250
-
251
- # Add executive summary for complex queries
252
- if len(successful_steps) > 2:
253
- response_parts.append("Here's my comprehensive analysis:")
254
- response_parts.append("")
255
-
256
- # Include results from each step
257
- for step in successful_steps:
258
- if step.tool_needed in ['math_engine', 'serper']:
259
- # Tool results are already well-formatted
260
- response_parts.append(step.result)
261
- else:
262
- # Model-generated responses
263
- response_parts.append(step.result)
264
-
265
- response_parts.append("")
266
-
267
- # Add synthesis for multi-step responses
268
- if len(successful_steps) > 2:
269
- confidence_score = sum(step.confidence for step in successful_steps) / len(successful_steps)
270
-
271
- synthesis_parts = [
272
- "**Key Insights:**",
273
- "• Multiple perspectives have been considered",
274
- f"• Analysis confidence: {confidence_score:.1%}",
275
- "• Both current information and domain expertise were utilized"
276
- ]
277
-
278
- response_parts.extend(synthesis_parts)
279
-
280
- return "\n".join(response_parts).strip()
281
-
282
- def process_complex_query(self, query: str, model, tokenizer) -> str:
283
- """Main method to process complex queries with enhanced reasoning."""
284
- # Analyze query complexity and requirements
285
- analysis = self.analyze_query_complexity(query)
286
-
287
- # For simple queries, use direct processing
288
- if analysis['complexity'] == 'simple' and not analysis['multi_step_needed']:
289
- tool_call = self.tools.route_query(query)
290
- if tool_call:
291
- executed_call = self.tools.execute_tool_call(tool_call)
292
- if executed_call.result:
293
- return executed_call.result
294
-
295
- # Fall back to enhanced model generation
296
- context = self._build_enhanced_context(
297
- ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL),
298
- []
299
- )
300
- return self._generate_with_context(model, tokenizer, context, query)
301
-
302
- # For complex queries, use multi-step reasoning
303
- reasoning_steps = self.decompose_complex_query(query, analysis)
304
- executed_steps = self.execute_reasoning_chain(reasoning_steps, model, tokenizer)
305
-
306
- return self.synthesize_final_response(executed_steps, query)
307
-
308
-
309
- # Import torch and other needed modules here to avoid import issues
310
- import torch
311
- try:
312
- import sympy as sp
313
- import numpy as np
314
- except ImportError:
315
- pass
 
 
 
 
 
 
1
+ """
2
+ Enhanced Reasoning Engine for Supernova AI
3
+ Provides sophisticated problem-solving capabilities through structured reasoning,
4
+ multi-tool coordination, and knowledge synthesis.
5
+ """
6
import torch
# numpy and sympy are optional math extras: guard BOTH so importing this
# module never hard-fails when one is missing. (The merge left numpy as an
# unconditional import while sympy stayed guarded, which would crash the
# whole package import on machines without numpy.)
try:
    import numpy as np
except ImportError:
    np = None
try:
    import sympy as sp
except ImportError:
    sp = None
12
+ import re
13
+ import json
14
+ from typing import List, Dict, Any, Optional, Tuple
15
+ from dataclasses import dataclass
16
+ from enum import Enum
17
+
18
+ from .tools import ToolOrchestrator, ToolCall
19
+
20
+
21
class ReasoningType(Enum):
    """Kinds of reasoning the engine can apply to a single step."""

    ANALYTICAL = "analytical"
    CREATIVE = "creative"
    COMPARATIVE = "comparative"
    CAUSAL = "causal"
    SEQUENTIAL = "sequential"
    EVALUATIVE = "evaluative"
28
+
29
+
30
@dataclass
class ReasoningStep:
    """One step in a reasoning chain; result/confidence are filled in as the
    chain executes."""

    step_number: int
    description: str
    reasoning_type: ReasoningType
    tool_needed: Optional[str] = None  # tool name, or None for model generation
    query: Optional[str] = None
    result: Optional[str] = None       # populated by execute_reasoning_chain
    confidence: float = 0.8
39
+
40
+
41
@dataclass
class KnowledgeDomain:
    """A domain of expertise with a confidence score, provenance, and facts."""

    domain: str
    confidence: float
    sources: List[str]
    key_facts: List[str]
47
+
48
+
49
class EnhancedReasoningEngine:
    """Advanced reasoning engine that mimics sophisticated AI reasoning patterns."""

    def __init__(self, tool_orchestrator: ToolOrchestrator):
        # Tool router/executor shared with the rest of the system.
        self.tools = tool_orchestrator
        # NOTE(review): appended to nowhere in this class — reserved for
        # future multi-turn context, currently unused.
        self.conversation_context = []
        # Keyword lists used to detect which knowledge domains a query touches.
        self.domain_expertise = {
            'science': ['physics', 'chemistry', 'biology', 'mathematics', 'astronomy'],
            'technology': ['programming', 'ai', 'computing', 'engineering', 'electronics'],
            'humanities': ['history', 'literature', 'philosophy', 'psychology', 'sociology'],
            'medicine': ['anatomy', 'pharmacology', 'diagnosis', 'treatment', 'research'],
            'business': ['finance', 'management', 'economics', 'marketing', 'strategy'],
            'arts': ['music', 'visual arts', 'design', 'architecture', 'performance']
        }

    def analyze_query_complexity(self, query: str) -> Dict[str, Any]:
        """Analyze the complexity and requirements of a user query.

        Returns a dict with keys:
            'complexity': 'simple' | 'moderate' | 'complex'
            'domains': list of matched domain names
            'multi_step_needed': bool
            'estimated_steps': int, capped at 5
        """
        complexity_indicators = {
            'simple': ['what is', 'define', 'who is', 'when did'],
            'moderate': ['how does', 'why does', 'explain', 'compare', 'analyze'],
            'complex': ['evaluate', 'synthesize', 'create', 'design', 'solve for multiple', 'consider all factors']
        }

        domains_detected = []
        for domain, keywords in self.domain_expertise.items():
            if any(keyword in query.lower() for keyword in keywords):
                domains_detected.append(domain)

        # NOTE(review): this loop does not break, so when indicators from
        # several levels match, the LAST matching level in dict order wins
        # ('complex' beats 'moderate' beats 'simple') — confirm intended.
        complexity_level = 'simple'
        for level, indicators in complexity_indicators.items():
            if any(indicator in query.lower() for indicator in indicators):
                complexity_level = level

        requires_multi_step = any(phrase in query.lower() for phrase in [
            'step by step', 'first...then', 'multiple', 'several', 'both', 'compare and contrast'
        ])

        return {
            'complexity': complexity_level,
            'domains': domains_detected,
            'multi_step_needed': requires_multi_step,
            'estimated_steps': min(5, len(domains_detected) + (2 if requires_multi_step else 1))
        }

    def decompose_complex_query(self, query: str, analysis: Dict[str, Any]) -> List[ReasoningStep]:
        """Break down complex queries into manageable reasoning steps.

        Always returns at least one step (a direct-response fallback).
        """
        steps = []
        step_num = 1

        # Step 1: Information Gathering
        if analysis['complexity'] in ['moderate', 'complex']:
            # Determine if we need current information
            if any(term in query.lower() for term in ['current', 'latest', 'recent', 'today', '2024', '2025']):
                steps.append(ReasoningStep(
                    step_number=step_num,
                    description="Gather current information from web sources",
                    reasoning_type=ReasoningType.ANALYTICAL,
                    tool_needed="serper",
                    query=query
                ))
                step_num += 1

            # Check if mathematical computation is needed
            if any(term in query.lower() for term in ['calculate', 'compute', 'solve', 'derivative', 'integral']):
                steps.append(ReasoningStep(
                    step_number=step_num,
                    description="Perform mathematical computation",
                    reasoning_type=ReasoningType.ANALYTICAL,
                    tool_needed="math_engine",
                    query=query
                ))
                step_num += 1

        # Step 2: Domain-specific analysis (one step per detected domain)
        for domain in analysis['domains']:
            steps.append(ReasoningStep(
                step_number=step_num,
                description=f"Analyze from {domain} perspective",
                reasoning_type=ReasoningType.ANALYTICAL,
                tool_needed=None,  # Will use model generation with domain context
                query=f"From a {domain} perspective: {query}"
            ))
            step_num += 1

        # Step 3: Synthesis and evaluation
        if analysis['complexity'] == 'complex':
            steps.append(ReasoningStep(
                step_number=step_num,
                description="Synthesize information and provide comprehensive analysis",
                reasoning_type=ReasoningType.EVALUATIVE,
                tool_needed=None,
                query=query
            ))

        return steps if steps else [ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL, query=query)]

    def execute_reasoning_chain(self, steps: List[ReasoningStep], model, tokenizer) -> List[ReasoningStep]:
        """Execute a chain of reasoning steps, using tools and model generation as needed.

        Mutates each step in place (result, confidence) and returns the list.
        Confidence encodes outcome: 0.9 tool success, 0.7 model success,
        0.3 tool failure, 0.2 generation failure.
        """
        results = []
        context_info = []  # accumulated "<description>: <result>" lines fed to later steps

        for step in steps:
            if step.tool_needed:
                # Use appropriate tool
                tool_call = ToolCall(tool=step.tool_needed, query=step.query)
                executed_call = self.tools.execute_tool_call(tool_call)

                if executed_call.result:
                    step.result = executed_call.result
                    step.confidence = 0.9
                    context_info.append(f"{step.description}: {executed_call.result}")
                else:
                    step.result = f"Tool execution failed: {executed_call.error}"
                    step.confidence = 0.3
            else:
                # Use model generation with enhanced context
                enhanced_context = self._build_enhanced_context(step, context_info)
                try:
                    response = self._generate_with_context(model, tokenizer, enhanced_context, step.query)
                    step.result = response
                    step.confidence = 0.7
                    context_info.append(f"{step.description}: {response}")
                except Exception as e:
                    step.result = f"Generation failed: {str(e)}"
                    step.confidence = 0.2

            results.append(step)

        return results

    def _build_enhanced_context(self, step: ReasoningStep, context_info: List[str]) -> str:
        """Build enhanced context for model generation.

        Combines a system preamble, prior step results, and per-reasoning-type
        guidance into a newline-joined prompt prefix.
        """
        context_parts = [
            "You are Supernova, an advanced AI assistant with deep expertise across multiple domains.",
            "Apply sophisticated reasoning and provide comprehensive, nuanced responses.",
            ""
        ]

        if context_info:
            context_parts.extend([
                "Previous analysis steps:",
                *[f"- {info}" for info in context_info],
                ""
            ])

        reasoning_guidance = {
            ReasoningType.ANALYTICAL: "Analyze systematically, consider multiple factors, and provide evidence-based insights.",
            ReasoningType.CREATIVE: "Think creatively, explore innovative solutions, and consider unconventional approaches.",
            ReasoningType.COMPARATIVE: "Compare different perspectives, weigh pros and cons, and identify key differences.",
            ReasoningType.CAUSAL: "Identify cause-and-effect relationships, trace underlying mechanisms, and explain why things happen.",
            ReasoningType.SEQUENTIAL: "Break down into logical steps, show progression, and maintain clear sequencing.",
            ReasoningType.EVALUATIVE: "Make judgments based on criteria, assess quality and effectiveness, and provide recommendations."
        }

        context_parts.extend([
            f"Reasoning approach: {reasoning_guidance.get(step.reasoning_type, 'Provide thorough analysis.')}",
            f"Focus area: {step.description}",
            ""
        ])

        return "\n".join(context_parts)

    def _generate_with_context(self, model, tokenizer, context: str, query: str, max_tokens: int = 400) -> str:
        """Generate response using the model with enhanced context.

        Simple autoregressive sampling loop: temperature 0.8, top-k 50,
        up to *max_tokens* new tokens (no early-stop on EOS).
        NOTE(review): assumes model(input) returns (logits, ...) and that
        model.cfg.n_positions is the context window — defined elsewhere in
        the project; confirm against SupernovaModel.
        """
        full_prompt = f"{context}\nUser Query: {query}\n\nDetailed Response:"

        # Use the existing generate function (simplified version)
        model.eval()
        device = next(model.parameters()).device
        input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            for _ in range(max_tokens):
                # Crop the context to the model's maximum position count.
                if input_ids.size(1) >= model.cfg.n_positions:
                    input_cond = input_ids[:, -model.cfg.n_positions:]
                else:
                    input_cond = input_ids

                logits, _ = model(input_cond)
                logits = logits[:, -1, :] / 0.8  # temperature

                # Top-k sampling
                v, _ = torch.topk(logits, min(50, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_id], dim=1)

        response = tokenizer.decode(input_ids[0].tolist())

        # Extract the response part (everything after the prompt's final marker)
        if "Detailed Response:" in response:
            response = response.split("Detailed Response:", 1)[1].strip()

        return response

    def synthesize_final_response(self, steps: List[ReasoningStep], original_query: str) -> str:
        """Synthesize all reasoning steps into a comprehensive final response.

        Only steps with a result and confidence > 0.5 contribute; failed
        chains produce an apology/rephrase prompt instead.
        """
        successful_steps = [step for step in steps if step.result and step.confidence > 0.5]

        if not successful_steps:
            return "I apologize, but I encountered difficulties processing your request. Could you please rephrase or provide more specific details?"

        # Build comprehensive response
        response_parts = []

        # Add executive summary for complex queries
        if len(successful_steps) > 2:
            response_parts.append("Here's my comprehensive analysis:")
            response_parts.append("")

        # Include results from each step
        for step in successful_steps:
            # NOTE(review): both branches append step.result identically;
            # the split is kept only for future per-source formatting.
            if step.tool_needed in ['math_engine', 'serper']:
                # Tool results are already well-formatted
                response_parts.append(step.result)
            else:
                # Model-generated responses
                response_parts.append(step.result)

            response_parts.append("")

        # Add synthesis for multi-step responses
        if len(successful_steps) > 2:
            # Mean confidence over contributing steps, reported as a percentage.
            confidence_score = sum(step.confidence for step in successful_steps) / len(successful_steps)

            synthesis_parts = [
                "**Key Insights:**",
                "• Multiple perspectives have been considered",
                f"• Analysis confidence: {confidence_score:.1%}",
                "• Both current information and domain expertise were utilized"
            ]

            response_parts.extend(synthesis_parts)

        return "\n".join(response_parts).strip()

    def process_complex_query(self, query: str, model, tokenizer) -> str:
        """Main method to process complex queries with enhanced reasoning.

        Simple queries are routed straight to a tool or a single model call;
        anything else goes through decompose -> execute -> synthesize.
        """
        # Analyze query complexity and requirements
        analysis = self.analyze_query_complexity(query)

        # For simple queries, use direct processing
        if analysis['complexity'] == 'simple' and not analysis['multi_step_needed']:
            tool_call = self.tools.route_query(query)
            if tool_call:
                executed_call = self.tools.execute_tool_call(tool_call)
                if executed_call.result:
                    return executed_call.result

            # Fall back to enhanced model generation
            context = self._build_enhanced_context(
                ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL),
                []
            )
            return self._generate_with_context(model, tokenizer, context, query)

        # For complex queries, use multi-step reasoning
        reasoning_steps = self.decompose_complex_query(query, analysis)
        executed_steps = self.execute_reasoning_chain(reasoning_steps, model, tokenizer)

        return self.synthesize_final_response(executed_steps, query)
312
+
313
+
314
# NOTE: torch, numpy and sympy are already imported (with ImportError guards)
# at the top of this module after the merge, so the duplicate tail imports
# that previously lived here were removed — they were redundant at best and
# could mask the header's `sp = None` sentinel.
supernova/train.py CHANGED
@@ -1,4 +1,3 @@
1
- # train.py (improved)
2
  import argparse
3
  import json
4
  import math
@@ -15,11 +14,11 @@ from transformers import get_cosine_schedule_with_warmup
15
  from .config import ModelConfig
16
  from .model import SupernovaModel
17
  from .tokenizer import load_gpt2_tokenizer
18
- from .data import load_sources_from_yaml, TokenChunkDataset
19
 
20
- # -----------------------
21
  # Utilities
22
- # -----------------------
23
  def compute_grad_norm(model: nn.Module) -> float:
24
  total = 0.0
25
  for p in model.parameters():
@@ -61,9 +60,9 @@ class EMA:
61
  p.data.copy_(self.backup[name])
62
  del self.backup
63
 
64
- # -----------------------
65
  # Training loop
66
- # -----------------------
67
  def train(
68
  config_path: str,
69
  data_config_path: str,
@@ -145,12 +144,12 @@ def train(
145
  seq_len=seq_len,
146
  eos_token_id=tok.eos_token_id
147
  )
148
-
149
  sampler = DistributedSampler(ds) if ddp else None
 
 
150
  dl = DataLoader(
151
  ds,
152
  batch_size=batch_size,
153
- shuffle=(sampler is None),
154
  sampler=sampler,
155
  num_workers=num_workers,
156
  pin_memory=pin_memory,
@@ -158,7 +157,7 @@ def train(
158
  drop_last=True,
159
  )
160
 
161
- # optimizer with simple parameter grouping example to avoid weight decay on norms/bias
162
  def param_groups(model):
163
  decay, no_decay = [], []
164
  for n, p in model.named_parameters():
@@ -174,25 +173,17 @@ def train(
174
  ]
175
 
176
  optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
177
-
178
- # scheduler
179
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
180
-
181
  # AMP scaler
182
- if device.type == "cuda":
183
- scaler = torch.amp.GradScaler('cuda', enabled=True)
184
- else:
185
- scaler = torch.amp.GradScaler('cpu', enabled=False)
186
 
187
  # EMA
188
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
189
 
190
- # logging + checkpoint dir
191
  os.makedirs(out_dir, exist_ok=True)
192
  writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
193
 
194
- # validation dataset (simple split: user should provide a separate validation YAML ideally)
195
- # TODO: Implement a proper validation dataset pipeline. For now, we use a small random subset of training data.
196
  val_ds = None
197
  val_dl = None
198
 
@@ -202,7 +193,6 @@ def train(
202
  if resume_from and os.path.exists(resume_from):
203
  ckpt = torch.load(resume_from, map_location=device)
204
  model_state = ckpt["model_state_dict"]
205
- # if ddp, load into module
206
  target = model.module if ddp else model
207
  target.load_state_dict(model_state)
208
  optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {}))
@@ -221,12 +211,12 @@ def train(
221
  running_loss = 0.0
222
  t0 = time.time()
223
  no_improve_steps = 0
224
- early_stop_patience = 10_000 # you can tune this
225
 
226
  # training loop
227
  while step < max_steps:
228
  if sampler is not None:
229
- sampler.set_epoch(step) # shuffle differently per epoch for DDP
230
 
231
  for batch in dl:
232
  x, y = batch
@@ -243,7 +233,6 @@ def train(
243
  running_loss += loss.item()
244
 
245
  if micro % grad_accum == 0:
246
- # gradient clipping
247
  if clip_grad_norm is not None:
248
  scaler.unscale_(optimizer)
249
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
@@ -255,7 +244,6 @@ def train(
255
 
256
  if ema:
257
  ema.update(model if not ddp else model.module)
258
-
259
  step += 1
260
 
261
  # logging
@@ -275,11 +263,20 @@ def train(
275
  # periodic validation
276
  if validate_every and step % validate_every == 0:
277
  if val_dl is None:
278
- # quick in-memory val split: take first N batches (user should replace with real val)
279
- # NOTE: for production, create a dedicated validation dataset.
280
- val_sources = sources[: max(1, len(sources) // 20)]
281
- if not val_sources:
282
- val_sources = sources[:1] # fallback to at least one source
 
 
 
 
 
 
 
 
 
283
  val_ds = TokenChunkDataset(
284
  tokenizer=tok,
285
  sources=val_sources,
@@ -289,7 +286,6 @@ def train(
289
  val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
290
 
291
  model.eval()
292
- # optionally swap in EMA weights for evaluation
293
  if ema:
294
  ema.store(model if not ddp else model.module)
295
  ema.copy_to(model if not ddp else model.module)
@@ -310,12 +306,10 @@ def train(
310
  writer.add_scalar("val/loss", mean_val, step)
311
  print(f"[eval] step={step} val_loss={mean_val:.6f}")
312
 
313
- # restore weights
314
  if ema:
315
  ema.restore(model if not ddp else model.module)
316
  model.train()
317
 
318
- # early stop / best model saving
319
  if mean_val < best_val_loss:
320
  best_val_loss = mean_val
321
  no_improve_steps = 0
@@ -331,7 +325,7 @@ def train(
331
  }
332
  if not ddp or local_rank == 0:
333
  atomic_save(ckpt, best_path)
334
- print(f"Saved best checkpoint to {best_path}")
335
  else:
336
  no_improve_steps += validate_every
337
  if no_improve_steps >= early_stop_patience:
@@ -378,7 +372,6 @@ def train(
378
  if writer:
379
  writer.close()
380
 
381
-
382
  if __name__ == "__main__":
383
  ap = argparse.ArgumentParser()
384
  ap.add_argument("--config", required=True)
 
 
1
  import argparse
2
  import json
3
  import math
 
14
  from .config import ModelConfig
15
  from .model import SupernovaModel
16
  from .tokenizer import load_gpt2_tokenizer
17
+ from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
18
 
19
+ # ------------------------------
20
  # Utilities
21
+ # ------------------------------
22
  def compute_grad_norm(model: nn.Module) -> float:
23
  total = 0.0
24
  for p in model.parameters():
 
60
  p.data.copy_(self.backup[name])
61
  del self.backup
62
 
63
+ # ------------------------------
64
  # Training loop
65
+ # ------------------------------
66
  def train(
67
  config_path: str,
68
  data_config_path: str,
 
144
  seq_len=seq_len,
145
  eos_token_id=tok.eos_token_id
146
  )
 
147
  sampler = DistributedSampler(ds) if ddp else None
148
+
149
+ # NOTE: NO shuffle for IterableDataset!
150
  dl = DataLoader(
151
  ds,
152
  batch_size=batch_size,
 
153
  sampler=sampler,
154
  num_workers=num_workers,
155
  pin_memory=pin_memory,
 
157
  drop_last=True,
158
  )
159
 
160
+ # optimizer
161
  def param_groups(model):
162
  decay, no_decay = [], []
163
  for n, p in model.named_parameters():
 
173
  ]
174
 
175
  optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
 
 
176
  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
 
177
  # AMP scaler
178
+ scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
 
 
 
179
 
180
  # EMA
181
  ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
182
 
 
183
  os.makedirs(out_dir, exist_ok=True)
184
  writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
185
 
186
+ # validation
 
187
  val_ds = None
188
  val_dl = None
189
 
 
193
  if resume_from and os.path.exists(resume_from):
194
  ckpt = torch.load(resume_from, map_location=device)
195
  model_state = ckpt["model_state_dict"]
 
196
  target = model.module if ddp else model
197
  target.load_state_dict(model_state)
198
  optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {}))
 
211
  running_loss = 0.0
212
  t0 = time.time()
213
  no_improve_steps = 0
214
+ early_stop_patience = 10_000 # you can tune this
215
 
216
  # training loop
217
  while step < max_steps:
218
  if sampler is not None:
219
+ sampler.set_epoch(step) # shuffle differently per epoch for DDP
220
 
221
  for batch in dl:
222
  x, y = batch
 
233
  running_loss += loss.item()
234
 
235
  if micro % grad_accum == 0:
 
236
  if clip_grad_norm is not None:
237
  scaler.unscale_(optimizer)
238
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
 
244
 
245
  if ema:
246
  ema.update(model if not ddp else model.module)
 
247
  step += 1
248
 
249
  # logging
 
263
  # periodic validation
264
  if validate_every and step % validate_every == 0:
265
  if val_dl is None:
266
+ # Use a proper validation dataset with wikitext-2 validation split
267
+ # This provides more reliable validation than using training data subsets
268
+ val_sources = []
269
+ for source in sources[:min(3, len(sources))]:
270
+ val_source = DataSource(
271
+ name=f"{source.name}_val",
272
+ hf_path="wikitext",
273
+ hf_name="wikitext-2-v1",
274
+ split="validation",
275
+ text_field="text",
276
+ weight=1,
277
+ streaming=False
278
+ )
279
+ val_sources.append(val_source)
280
  val_ds = TokenChunkDataset(
281
  tokenizer=tok,
282
  sources=val_sources,
 
286
  val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
287
 
288
  model.eval()
 
289
  if ema:
290
  ema.store(model if not ddp else model.module)
291
  ema.copy_to(model if not ddp else model.module)
 
306
  writer.add_scalar("val/loss", mean_val, step)
307
  print(f"[eval] step={step} val_loss={mean_val:.6f}")
308
 
 
309
  if ema:
310
  ema.restore(model if not ddp else model.module)
311
  model.train()
312
 
 
313
  if mean_val < best_val_loss:
314
  best_val_loss = mean_val
315
  no_improve_steps = 0
 
325
  }
326
  if not ddp or local_rank == 0:
327
  atomic_save(ckpt, best_path)
328
+ print(f"Saved best checkpoint to {best_path}")
329
  else:
330
  no_improve_steps += validate_every
331
  if no_improve_steps >= early_stop_patience:
 
372
  if writer:
373
  writer.close()
374
 
 
375
  if __name__ == "__main__":
376
  ap = argparse.ArgumentParser()
377
  ap.add_argument("--config", required=True)