Specific-Cognito committed on
Commit
fc7a401
·
verified ·
1 Parent(s): bcf2cb0

Update helion_orchestrator.py

Browse files
Files changed (1) hide show
  1. helion_orchestrator.py +211 -7
helion_orchestrator.py CHANGED
@@ -22,6 +22,107 @@ logging.basicConfig(
22
  logger = logging.getLogger(__name__)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @dataclass
26
  class HelionConfig:
27
  """Central configuration for all Helion operations."""
@@ -50,6 +151,11 @@ class HelionConfig:
50
  enable_safeguards: bool = True
51
  enable_tools: bool = False
52
 
 
 
 
 
 
53
  # HuggingFace
54
  hf_token: Optional[str] = None
55
  push_to_hub: bool = False
@@ -74,8 +180,13 @@ class HelionOrchestrator:
74
  self.tokenizer = None
75
  self.safeguards = None
76
  self.tool_system = None
 
77
 
78
  self.session_log = []
 
 
 
 
79
 
80
  # ==================== Model Loading ====================
81
 
@@ -169,6 +280,14 @@ class HelionOrchestrator:
169
  except ImportError:
170
  logger.warning("Tools module not found")
171
 
 
 
 
 
 
 
 
 
172
  def unload_model(self):
173
  """Unload model to free memory."""
174
  if self.model:
@@ -185,7 +304,9 @@ class HelionOrchestrator:
185
  max_tokens: Optional[int] = None,
186
  temperature: Optional[float] = None,
187
  system_prompt: Optional[str] = None,
188
- use_safeguards: bool = True
 
 
189
  ) -> Dict[str, Any]:
190
  """
191
  Generate response from prompt.
@@ -196,6 +317,8 @@ class HelionOrchestrator:
196
  temperature: Sampling temperature
197
  system_prompt: Optional system prompt
198
  use_safeguards: Apply safeguard checks
 
 
199
 
200
  Returns:
201
  Dict with response and metadata
@@ -206,10 +329,20 @@ class HelionOrchestrator:
206
  max_tokens = max_tokens or self.config.max_tokens
207
  temperature = temperature or self.config.temperature
208
 
 
 
 
 
 
209
  # Build messages
210
  messages = []
211
  if system_prompt:
212
  messages.append({"role": "system", "content": system_prompt})
 
 
 
 
 
213
  messages.append({"role": "user", "content": prompt})
214
 
215
  # Check with safeguards
@@ -245,12 +378,17 @@ class HelionOrchestrator:
245
  skip_special_tokens=True
246
  ).strip()
247
 
 
 
 
 
248
  result = {
249
  "response": response_text,
250
  "blocked": False,
251
  "prompt_tokens": input_ids.shape[1],
252
  "completion_tokens": output.shape[1] - input_ids.shape[1],
253
- "total_tokens": output.shape[1]
 
254
  }
255
 
256
  self._log_event("generation", {"prompt": prompt[:100], "tokens": result["total_tokens"]})
@@ -259,6 +397,8 @@ class HelionOrchestrator:
259
  def chat(
260
  self,
261
  messages: List[Dict[str, str]],
 
 
262
  **kwargs
263
  ) -> Dict[str, Any]:
264
  """
@@ -266,6 +406,8 @@ class HelionOrchestrator:
266
 
267
  Args:
268
  messages: List of message dicts
 
 
269
  **kwargs: Generation parameters
270
 
271
  Returns:
@@ -274,6 +416,15 @@ class HelionOrchestrator:
274
  if not self.model:
275
  raise RuntimeError("Model not loaded")
276
 
 
 
 
 
 
 
 
 
 
277
  # Similar to generate but maintains conversation
278
  input_ids = self.tokenizer.apply_chat_template(
279
  messages,
@@ -297,6 +448,11 @@ class HelionOrchestrator:
297
  skip_special_tokens=True
298
  ).strip()
299
 
 
 
 
 
 
300
  return {"response": response, "blocked": False}
301
 
302
  def interactive_chat(self):
@@ -306,11 +462,20 @@ class HelionOrchestrator:
306
  return
307
 
308
  print("\n" + "="*60)
309
- print("Helion Interactive Chat")
310
- print("Commands: /quit, /clear, /save, /load, /help")
311
  print("="*60 + "\n")
312
 
313
  conversation = []
 
 
 
 
 
 
 
 
 
314
 
315
  while True:
316
  try:
@@ -322,12 +487,24 @@ class HelionOrchestrator:
322
  # Handle commands
323
  if user_input.startswith("/"):
324
  if user_input == "/quit":
 
 
325
  print("Goodbye!")
326
  break
327
  elif user_input == "/clear":
328
  conversation = []
329
  print("Conversation cleared.")
330
  continue
 
 
 
 
 
 
 
 
 
 
331
  elif user_input.startswith("/save"):
332
  self._save_conversation(conversation, user_input.split()[1] if len(user_input.split()) > 1 else None)
333
  continue
@@ -340,7 +517,11 @@ class HelionOrchestrator:
340
 
341
  conversation.append({"role": "user", "content": user_input})
342
 
343
- result = self.chat(conversation)
 
 
 
 
344
 
345
  if result.get("blocked"):
346
  print(f"🤖 Helion: {result['response']}")
@@ -675,6 +856,7 @@ CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8000"]
675
  "device": str(self.model.device) if self.model else None,
676
  "safeguards_enabled": self.safeguards is not None,
677
  "tools_enabled": self.tool_system is not None,
 
678
  "config": asdict(self.config),
679
  "session_events": len(self.session_log)
680
  }
@@ -682,8 +864,28 @@ CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8000"]
682
  if self.model:
683
  info["model_memory"] = torch.cuda.max_memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
684
 
 
 
 
 
685
  return info
686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  def _log_event(self, event_type: str, data: Dict[str, Any]):
688
  """Log orchestrator event."""
689
  event = {
@@ -712,9 +914,11 @@ CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8000"]
712
  """Print chat help."""
713
  print("""
714
  Available Commands:
715
- /quit - Exit chat
716
- /clear - Clear conversation history
717
  /save [name] - Save conversation to file
 
 
718
  /help - Show this help message
719
  """)
720
 
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
class MemoryManager:
    """
    Persistent conversation memory for Helion.

    Keeps a rolling window of user/assistant exchanges per conversation id,
    persisted as JSON on disk so context can survive across sessions.
    """

    def __init__(self, memory_file: str = "helion_memory.json", window_size: int = 10):
        # Path of the JSON file backing the store.
        self.memory_file = Path(memory_file)
        # Maximum number of interactions retained per conversation.
        self.window_size = window_size
        # conversation_id -> list of {"timestamp", "user", "assistant"} dicts.
        self.conversations: Dict[str, List[Dict]] = {}
        self.load()

    def add_interaction(self, conversation_id: str, user_input: str, assistant_response: str):
        """
        Record one user/assistant exchange and persist the store.

        Args:
            conversation_id: Unique conversation identifier
            user_input: User's message
            assistant_response: Assistant's response
        """
        history = self.conversations.setdefault(conversation_id, [])
        history.append({
            "timestamp": datetime.now().isoformat(),
            "user": user_input,
            "assistant": assistant_response,
        })

        # Trim to the most recent window_size entries.
        if len(history) > self.window_size:
            self.conversations[conversation_id] = history[-self.window_size:]

        self.save()

    def get_context(self, conversation_id: str, max_length: int = 500) -> str:
        """
        Summarize recent interactions as a single context string.

        Args:
            conversation_id: Conversation ID
            max_length: Maximum context length in characters

        Returns:
            Context string ("" when the conversation is unknown)
        """
        history = self.conversations.get(conversation_id)
        if history is None:
            return ""

        # Walk newest-to-oldest, stopping once adding another snippet
        # would exceed the character budget.
        collected = []
        used = 0
        for entry in reversed(history):
            snippet = f"User: {entry['user'][:100]} | Assistant: {entry['assistant'][:100]}"
            if used + len(snippet) > max_length:
                break
            collected.append(snippet)
            used += len(snippet)

        collected.reverse()  # restore chronological order
        return " | ".join(collected)

    def get_conversation(self, conversation_id: str) -> List[Dict]:
        """Return the full stored history for a conversation (empty list if unknown)."""
        return self.conversations.get(conversation_id, [])

    def clear_conversation(self, conversation_id: str):
        """Drop a single conversation; persist only when something was removed."""
        try:
            del self.conversations[conversation_id]
        except KeyError:
            pass
        else:
            self.save()

    def clear_all(self):
        """Remove every stored conversation and persist the empty store."""
        self.conversations = {}
        self.save()

    def save(self):
        """Write the memory store to disk, logging (not raising) on failure."""
        try:
            self.memory_file.parent.mkdir(parents=True, exist_ok=True)
            with self.memory_file.open('w') as f:
                json.dump(self.conversations, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save memory: {e}")

    def load(self):
        """Populate the store from disk; fall back to an empty store on any error."""
        try:
            if self.memory_file.exists():
                with self.memory_file.open('r') as f:
                    self.conversations = json.load(f)
                logger.info(f"Loaded {len(self.conversations)} conversations from memory")
        except Exception as e:
            logger.warning(f"Failed to load memory: {e}")
            self.conversations = {}
126
  @dataclass
127
  class HelionConfig:
128
  """Central configuration for all Helion operations."""
 
151
  enable_safeguards: bool = True
152
  enable_tools: bool = False
153
 
154
+ # Memory settings
155
+ enable_memory: bool = True
156
+ memory_window: int = 10 # Remember last N conversations
157
+ memory_file: str = "helion_memory.json"
158
+
159
  # HuggingFace
160
  hf_token: Optional[str] = None
161
  push_to_hub: bool = False
 
180
  self.tokenizer = None
181
  self.safeguards = None
182
  self.tool_system = None
183
+ self.memory = None
184
 
185
  self.session_log = []
186
+
187
+ # Initialize memory if enabled
188
+ if self.config.enable_memory:
189
+ self._init_memory()
190
 
191
  # ==================== Model Loading ====================
192
 
 
280
  except ImportError:
281
  logger.warning("Tools module not found")
282
 
283
def _init_memory(self):
    """Create the MemoryManager from the orchestrator's config settings."""
    # The memory file lives under the configured output directory.
    memory_path = os.path.join(self.config.output_dir, self.config.memory_file)
    self.memory = MemoryManager(
        memory_file=memory_path,
        window_size=self.config.memory_window,
    )
    logger.info("Memory system initialized")
291
  def unload_model(self):
292
  """Unload model to free memory."""
293
  if self.model:
 
304
  max_tokens: Optional[int] = None,
305
  temperature: Optional[float] = None,
306
  system_prompt: Optional[str] = None,
307
+ use_safeguards: bool = True,
308
+ use_memory: bool = True,
309
+ conversation_id: Optional[str] = None
310
  ) -> Dict[str, Any]:
311
  """
312
  Generate response from prompt.
 
317
  temperature: Sampling temperature
318
  system_prompt: Optional system prompt
319
  use_safeguards: Apply safeguard checks
320
+ use_memory: Use conversation memory
321
+ conversation_id: Conversation identifier for memory
322
 
323
  Returns:
324
  Dict with response and metadata
 
329
  max_tokens = max_tokens or self.config.max_tokens
330
  temperature = temperature or self.config.temperature
331
 
332
+ # Retrieve memory context if enabled
333
+ memory_context = ""
334
+ if use_memory and self.memory and conversation_id:
335
+ memory_context = self.memory.get_context(conversation_id)
336
+
337
  # Build messages
338
  messages = []
339
  if system_prompt:
340
  messages.append({"role": "system", "content": system_prompt})
341
+
342
+ # Add memory context if available
343
+ if memory_context:
344
+ messages.append({"role": "system", "content": f"Previous context: {memory_context}"})
345
+
346
  messages.append({"role": "user", "content": prompt})
347
 
348
  # Check with safeguards
 
378
  skip_special_tokens=True
379
  ).strip()
380
 
381
+ # Store in memory if enabled
382
+ if use_memory and self.memory and conversation_id:
383
+ self.memory.add_interaction(conversation_id, prompt, response_text)
384
+
385
  result = {
386
  "response": response_text,
387
  "blocked": False,
388
  "prompt_tokens": input_ids.shape[1],
389
  "completion_tokens": output.shape[1] - input_ids.shape[1],
390
+ "total_tokens": output.shape[1],
391
+ "conversation_id": conversation_id
392
  }
393
 
394
  self._log_event("generation", {"prompt": prompt[:100], "tokens": result["total_tokens"]})
 
397
  def chat(
398
  self,
399
  messages: List[Dict[str, str]],
400
+ use_memory: bool = True,
401
+ conversation_id: Optional[str] = None,
402
  **kwargs
403
  ) -> Dict[str, Any]:
404
  """
 
406
 
407
  Args:
408
  messages: List of message dicts
409
+ use_memory: Use memory for context
410
+ conversation_id: Conversation ID for memory
411
  **kwargs: Generation parameters
412
 
413
  Returns:
 
416
  if not self.model:
417
  raise RuntimeError("Model not loaded")
418
 
419
+ # Add memory context if available
420
+ if use_memory and self.memory and conversation_id:
421
+ memory_context = self.memory.get_context(conversation_id)
422
+ if memory_context:
423
+ # Insert memory context before user messages
424
+ messages = [
425
+ {"role": "system", "content": f"Previous context: {memory_context}"}
426
+ ] + messages
427
+
428
  # Similar to generate but maintains conversation
429
  input_ids = self.tokenizer.apply_chat_template(
430
  messages,
 
448
  skip_special_tokens=True
449
  ).strip()
450
 
451
+ # Store in memory
452
+ if use_memory and self.memory and conversation_id:
453
+ user_message = messages[-1]["content"]
454
+ self.memory.add_interaction(conversation_id, user_message, response)
455
+
456
  return {"response": response, "blocked": False}
457
 
458
  def interactive_chat(self):
 
462
  return
463
 
464
  print("\n" + "="*60)
465
+ print("Helion Interactive Chat with Memory")
466
+ print("Commands: /quit, /clear, /save, /memory, /newconv, /help")
467
  print("="*60 + "\n")
468
 
469
  conversation = []
470
+ conversation_id = f"chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
471
+
472
+ # Show memory status
473
+ if self.memory:
474
+ print(f"💾 Memory: Enabled (ID: {conversation_id})")
475
+ # Check if there's previous context
476
+ prev_context = self.memory.get_context(conversation_id)
477
+ if prev_context:
478
+ print(f"📝 Retrieved previous context\n")
479
 
480
  while True:
481
  try:
 
487
  # Handle commands
488
  if user_input.startswith("/"):
489
  if user_input == "/quit":
490
+ if self.memory:
491
+ self.memory.save()
492
  print("Goodbye!")
493
  break
494
  elif user_input == "/clear":
495
  conversation = []
496
  print("Conversation cleared.")
497
  continue
498
+ elif user_input == "/memory":
499
+ self._show_memory(conversation_id)
500
+ continue
501
+ elif user_input == "/newconv":
502
+ if self.memory:
503
+ self.memory.save()
504
+ conversation = []
505
+ conversation_id = f"chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
506
+ print(f"New conversation started (ID: {conversation_id})")
507
+ continue
508
  elif user_input.startswith("/save"):
509
  self._save_conversation(conversation, user_input.split()[1] if len(user_input.split()) > 1 else None)
510
  continue
 
517
 
518
  conversation.append({"role": "user", "content": user_input})
519
 
520
+ result = self.chat(
521
+ conversation,
522
+ use_memory=True,
523
+ conversation_id=conversation_id
524
+ )
525
 
526
  if result.get("blocked"):
527
  print(f"🤖 Helion: {result['response']}")
 
856
  "device": str(self.model.device) if self.model else None,
857
  "safeguards_enabled": self.safeguards is not None,
858
  "tools_enabled": self.tool_system is not None,
859
+ "memory_enabled": self.memory is not None,
860
  "config": asdict(self.config),
861
  "session_events": len(self.session_log)
862
  }
 
864
  if self.model:
865
  info["model_memory"] = torch.cuda.max_memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
866
 
867
+ if self.memory:
868
+ info["total_conversations"] = len(self.memory.conversations)
869
+ info["total_interactions"] = sum(len(conv) for conv in self.memory.conversations.values())
870
+
871
  return info
872
 
873
def _show_memory(self, conversation_id: str):
    """Print a summary of the stored memory for *conversation_id*."""
    # Guard: memory may be disabled via config.
    if not self.memory:
        print("Memory not enabled")
        return

    context = self.memory.get_context(conversation_id)
    interactions = self.memory.get_conversation(conversation_id)

    bar = "=" * 60
    print(f"\n{bar}")
    print(f"Memory for Conversation: {conversation_id}")
    print(bar)
    print(f"Total interactions: {len(interactions)}")
    # Long contexts are shown truncated to the first 200 characters.
    if len(context) > 200:
        print(f"\nContext summary:\n{context[:200]}...")
    else:
        print(f"\nContext:\n{context}")
    print(f"{bar}\n")
889
  def _log_event(self, event_type: str, data: Dict[str, Any]):
890
  """Log orchestrator event."""
891
  event = {
 
914
  """Print chat help."""
915
  print("""
916
  Available Commands:
917
+ /quit - Exit chat and save memory
918
+ /clear - Clear current conversation
919
  /save [name] - Save conversation to file
920
+ /memory - Show memory for this conversation
921
+ /newconv - Start a new conversation (saves current)
922
  /help - Show this help message
923
  """)
924