import argparse
import json
import os
from typing import Optional

import torch

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.tools import ToolOrchestrator, ToolCall

# Static company-profile text served verbatim for brand-related questions.
BRAND_PATH = os.path.join(
    os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt"
)


def load_brand_text() -> str:
    """Return the AlgoRythm Technologies brand-profile text, stripped."""
    with open(BRAND_PATH, "r", encoding="utf-8") as f:
        return f.read().strip()


def should_return_brand(prompt: str) -> bool:
    """Return True if *prompt* looks like a question about the company/brand.

    Matching is a simple case-insensitive substring check against a fixed
    keyword list.
    """
    p = prompt.lower()
    keys = [
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
        "who built you",
        "who created you",
    ]
    return any(k in p for k in keys)


def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
) -> str:
    """Autoregressively sample up to *max_new_tokens* tokens after *prompt*.

    Applies temperature scaling and optional top-k filtering at each step.

    Args:
        model: The Supernova language model (assumed to return a
            ``(logits, ...)`` tuple — confirm against SupernovaModel.forward).
        tok: Tokenizer with GPT-2-style ``encode``/``decode``.
        prompt: Text to condition on.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Softmax temperature; values <= 0 are clamped to 1e-6.
        top_k: If a positive int, restrict sampling to the k highest logits;
            ``None`` or 0 disables the filter.

    Returns:
        The decoded text of the prompt plus the sampled continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the running context to the model's position window.
            if input_ids.size(1) >= model.cfg.n_positions:
                input_cond = input_ids[:, -model.cfg.n_positions:]
            else:
                input_cond = input_ids

            logits, _ = model(input_cond)
            logits = logits[:, -1, :]
            # Clamp so temperature == 0 cannot divide by zero.
            logits = logits / max(1e-6, temperature)

            if top_k is not None and top_k > 0:
                # Mask every logit below the k-th largest before sampling.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)

    return tok.decode(input_ids[0].tolist())


class SupernovaChat:
    """Chat front-end combining the Supernova model with external tools."""

    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None):
        """Build the model, tokenizer, and tool orchestrator.

        Args:
            config_path: Path to the model config JSON file.
            checkpoint_path: Optional path to a training checkpoint; ignored
                if missing.
        """
        self.cfg = ModelConfig.from_json_file(config_path)
        self.tok = load_gpt2_tokenizer()

        # Initialize model
        self.model = SupernovaModel(self.cfg)

        # Load checkpoint if provided
        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (or pass weights_only=True on
            # torch >= 2.0 once the checkpoint format allows it).
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {checkpoint_path}")

        # SECURITY: prefer the SERPER_API_KEY environment variable. The
        # hardcoded fallback below is a committed secret kept only for
        # backward compatibility — it should be rotated and removed.
        serper_api_key = os.environ.get(
            "SERPER_API_KEY", "06f4918f3ea721d9742f940fb7c7ba1ac44e7c14"
        )
        self.tools = ToolOrchestrator(serper_api_key=serper_api_key)

        # Track conversation for context (currently never read by respond()).
        self.conversation_history = []

    def respond(self, user_input: str) -> str:
        """Generate a response to user input, using tools when appropriate.

        Resolution order: brand queries -> tool routing (math engine / web
        search) -> tool-failure fallback generation -> plain generation.
        """
        # Check for brand queries first
        if should_return_brand(user_input):
            return load_brand_text()

        # Check if we should use tools
        tool_call = self.tools.route_query(user_input)
        if tool_call:
            # Execute the tool call
            tool_call = self.tools.execute_tool_call(tool_call)

            if tool_call.result:
                # Format the response with tool results
                if tool_call.tool == "math_engine":
                    response = f"I'll solve this mathematical problem for you:\n\n{tool_call.result}\n\nThe calculation shows the step-by-step solution above."
                elif tool_call.tool == "serper":
                    response = f"Based on current information I found:\n\n{tool_call.result}"
                else:
                    response = tool_call.result
                return response
            elif tool_call.error:
                # Tool failed, fall back to model generation with error context
                fallback_prompt = f"The user asked: {user_input}\n\nI couldn't access external tools ({tool_call.error}), but I can still help based on my training. Here's what I know:\n\n"
                try:
                    return generate(self.model, self.tok, fallback_prompt, max_new_tokens=300)
                except Exception as e:
                    return f"I apologize, but I'm having trouble accessing both external tools and my language model. Error: {str(e)}"

        # No tools needed, use direct generation
        try:
            # Create a comprehensive prompt that encourages broad knowledge use
            enhanced_prompt = f"""You are Supernova, an AI assistant built by AlgoRythm Technologies. You have broad knowledge across all subjects including science, mathematics, history, literature, technology, medicine, law, arts, and more.
Provide helpful, accurate, and comprehensive responses.

User: {user_input}

Supernova: """
            response = generate(self.model, self.tok, enhanced_prompt, max_new_tokens=400)

            # Extract just the Supernova response part
            if "Supernova: " in response:
                response = response.split("Supernova: ", 1)[1]
            return response.strip()
        except Exception as e:
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"

    def chat_loop(self):
        """Interactive chat loop."""
        print("🌟 Supernova AI Assistant - Built by AlgoRythm Technologies")
        print("Enhanced with free SymPy mathematical computation and Serper web search")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.\n")

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if user_input.lower() in ['quit', 'exit', 'bye', 'q']:
                    print("\nSupernova: Goodbye! It was great helping you today.")
                    break

                if not user_input:
                    continue

                print("\nSupernova: ", end="")
                response = self.respond(user_input)
                print(response)
            except KeyboardInterrupt:
                print("\n\nSupernova: Goodbye!")
                break
            except Exception as e:
                print(f"\nError: {e}")


def main():
    """CLI entry point: single-prompt mode or interactive chat."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Chat with Tool Integration")
    parser.add_argument("--config", required=True, help="Path to model config file")
    parser.add_argument("--checkpoint", help="Path to model checkpoint (optional)")
    parser.add_argument("--prompt", help="Single prompt mode (instead of chat loop)")

    args = parser.parse_args()

    # Initialize chat system
    chat = SupernovaChat(
        config_path=args.config,
        checkpoint_path=args.checkpoint
    )

    if args.prompt:
        # Single prompt mode
        response = chat.respond(args.prompt)
        print(response)
    else:
        # Interactive chat loop
        chat.chat_loop()


if __name__ == "__main__":
    main()