# Supernova chat CLI with tool integration (AlgoRythm Technologies).
import argparse
import json
import os
from typing import Optional
import torch
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.tools import ToolOrchestrator, ToolCall
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")
def load_brand_text() -> str:
    """Read the AlgoRythm company-profile text and return it stripped of surrounding whitespace."""
    with open(BRAND_PATH, "r", encoding="utf-8") as brand_file:
        contents = brand_file.read()
    return contents.strip()
def should_return_brand(prompt: str) -> bool:
    """Return True when the prompt looks like a branding/company question.

    Matching is case-insensitive substring search against a fixed trigger list.
    """
    triggers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
        "who built you",
        "who created you",
    )
    lowered = prompt.lower()
    for trigger in triggers:
        if trigger in lowered:
            return True
    return False
def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
    eos_token_id: Optional[int] = None,
) -> str:
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt``.

    Args:
        model: Causal LM whose forward returns ``(logits, aux)``; ``model.cfg.n_positions``
            is its maximum context length.
        tok: Tokenizer with ``encode(text, return_tensors="pt")`` and ``decode(ids)``.
        prompt: Text to condition on.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Softmax temperature; values <= 1e-6 are clamped to avoid
            division by zero (effectively near-greedy sampling).
        top_k: If set and > 0, restrict sampling to the k highest-probability tokens.
        eos_token_id: Optional end-of-sequence token id; when sampled, generation
            stops early (the EOS token is kept in the output, matching common
            decoder behavior). Defaults to None (previous behavior: never stop early).

    Returns:
        The decoded string containing the prompt followed by the sampled continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)
    # Hoist the loop-invariant context limit out of the sampling loop.
    n_positions = model.cfg.n_positions
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum window before each step.
            if input_ids.size(1) >= n_positions:
                input_cond = input_ids[:, -n_positions:]
            else:
                input_cond = input_ids
            logits, _ = model(input_cond)
            logits = logits[:, -1, :]
            # Clamp temperature so temperature=0 degrades to near-greedy, not a crash.
            logits = logits / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            # Early exit once the model emits the end-of-sequence token.
            if eos_token_id is not None and next_id.item() == eos_token_id:
                break
    return tok.decode(input_ids[0].tolist())
class SupernovaChat:
    """Interactive chat front-end around SupernovaModel with tool routing.

    Answers branding questions from a static profile file, routes math/search
    queries through a ToolOrchestrator, and otherwise generates replies with
    the language model.
    """

    def __init__(
        self,
        config_path: str,
        checkpoint_path: Optional[str] = None,
        serper_api_key: Optional[str] = None,
    ):
        """Build the model, tokenizer, and tool orchestrator.

        Args:
            config_path: Path to the model config JSON.
            checkpoint_path: Optional path to a model checkpoint to load.
            serper_api_key: Optional Serper API key; when omitted, the
                SERPER_API_KEY environment variable is used.
        """
        self.cfg = ModelConfig.from_json_file(config_path)
        self.tok = load_gpt2_tokenizer()
        # Initialize model
        self.model = SupernovaModel(self.cfg)
        # Load checkpoint if provided
        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load can execute arbitrary code when unpickling
            # an untrusted file — only load checkpoints from trusted sources
            # (consider weights_only=True on torch >= 1.13).
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {checkpoint_path}")
        # SECURITY FIX: the Serper API key was previously hardcoded in source.
        # Read it from an explicit argument or the environment instead so the
        # secret is never committed; the leaked key should be revoked.
        if serper_api_key is None:
            serper_api_key = os.environ.get("SERPER_API_KEY")
        self.tools = ToolOrchestrator(serper_api_key=serper_api_key)
        # Track conversation for context (recorded but not yet consumed anywhere).
        self.conversation_history = []

    def respond(self, user_input: str) -> str:
        """Generate a response to user input, using tools when appropriate.

        Resolution order: branding queries -> tool orchestrator -> direct
        model generation. Returns a user-facing string in every case
        (errors are reported as apologetic messages, never raised).
        """
        # Check for brand queries first
        if should_return_brand(user_input):
            return load_brand_text()
        # Check if we should use tools
        tool_call = self.tools.route_query(user_input)
        if tool_call:
            # Execute the tool call
            tool_call = self.tools.execute_tool_call(tool_call)
            if tool_call.result:
                # Format the response with tool results
                if tool_call.tool == "math_engine":
                    response = f"I'll solve this mathematical problem for you:\n\n{tool_call.result}\n\nThe calculation shows the step-by-step solution above."
                elif tool_call.tool == "serper":
                    response = f"Based on current information I found:\n\n{tool_call.result}"
                else:
                    response = tool_call.result
                return response
            elif tool_call.error:
                # Tool failed, fall back to model generation with error context
                fallback_prompt = f"The user asked: {user_input}\n\nI couldn't access external tools ({tool_call.error}), but I can still help based on my training. Here's what I know:\n\n"
                try:
                    return generate(self.model, self.tok, fallback_prompt, max_new_tokens=300)
                except Exception as e:
                    return f"I apologize, but I'm having trouble accessing both external tools and my language model. Error: {str(e)}"
        # No tools needed, use direct generation
        try:
            # Create a comprehensive prompt that encourages broad knowledge use
            enhanced_prompt = f"""You are Supernova, an AI assistant built by AlgoRythm Technologies. You have broad knowledge across all subjects including science, mathematics, history, literature, technology, medicine, law, arts, and more. Provide helpful, accurate, and comprehensive responses.
User: {user_input}
Supernova: """
            response = generate(self.model, self.tok, enhanced_prompt, max_new_tokens=400)
            # Extract just the Supernova response part (generate echoes the prompt).
            if "Supernova: " in response:
                response = response.split("Supernova: ", 1)[1]
            return response.strip()
        except Exception as e:
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"

    def chat_loop(self):
        """Interactive chat loop; exits on quit commands or Ctrl-C."""
        print("🌟 Supernova AI Assistant - Built by AlgoRythm Technologies")
        print("Enhanced with free SymPy mathematical computation and Serper web search")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.\n")
        while True:
            try:
                user_input = input("\nYou: ").strip()
                if user_input.lower() in ['quit', 'exit', 'bye', 'q']:
                    print("\nSupernova: Goodbye! It was great helping you today.")
                    break
                if not user_input:
                    continue
                print("\nSupernova: ", end="")
                response = self.respond(user_input)
                print(response)
            except KeyboardInterrupt:
                print("\n\nSupernova: Goodbye!")
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors; report and continue.
                print(f"\nError: {e}")
def main():
    """CLI entry point: parse arguments, then run one-shot or interactive mode."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Chat with Tool Integration")
    parser.add_argument("--config", required=True, help="Path to model config file")
    parser.add_argument("--checkpoint", help="Path to model checkpoint (optional)")
    parser.add_argument("--prompt", help="Single prompt mode (instead of chat loop)")
    args = parser.parse_args()

    # Build the chat system from the parsed CLI options.
    chat = SupernovaChat(config_path=args.config, checkpoint_path=args.checkpoint)

    if args.prompt:
        # One-shot mode: answer a single prompt and exit.
        print(chat.respond(args.prompt))
    else:
        # Otherwise drop into the interactive REPL.
        chat.chat_loop()


if __name__ == "__main__":
    main()