|
|
import argparse
|
|
|
import json
|
|
|
import os
|
|
|
from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from supernova.config import ModelConfig
|
|
|
from supernova.model import SupernovaModel
|
|
|
from supernova.tokenizer import load_gpt2_tokenizer
|
|
|
from supernova.tools import ToolOrchestrator, ToolCall
|
|
|
|
|
|
# Path to the canned AlgoRythm Technologies company-profile text, resolved
# relative to this source file so lookups work from any working directory.
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")
|
|
|
|
|
|
|
|
|
def load_brand_text() -> str:
    """Read the company brand-profile file and return its stripped contents."""
    with open(BRAND_PATH, encoding="utf-8") as fh:
        return fh.read().strip()
|
|
|
|
|
|
|
|
|
def should_return_brand(prompt: str) -> bool:
    """Return True when the prompt looks like a question about the company/brand.

    Matching is a simple case-insensitive substring check against a fixed
    list of trigger phrases.
    """
    lowered = prompt.lower()
    triggers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
        "who built you",
        "who created you",
    )
    for trigger in triggers:
        if trigger in lowered:
            return True
    return False
|
|
|
|
|
|
|
|
|
def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
) -> str:
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt``.

    Args:
        model: The language model; put into eval mode before sampling.
        tok: Tokenizer with ``encode(..., return_tensors="pt")`` / ``decode``.
        prompt: Seed text to condition generation on.
        max_new_tokens: Maximum number of tokens to append.
        temperature: Softmax temperature (clamped to at least 1e-6).
        top_k: If set and positive, restrict sampling to the k highest logits.

    Returns:
        The decoded text, including the original prompt.
    """
    model.eval()
    device = next(model.parameters()).device
    ids = tok.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Condition only on the most recent n_positions tokens so the
            # sequence never exceeds the model's context window.
            if ids.size(1) >= model.cfg.n_positions:
                window = ids[:, -model.cfg.n_positions:]
            else:
                window = ids

            logits, _ = model(window)
            # Last position only, rescaled by temperature (floored at 1e-6
            # to avoid division by zero).
            step_logits = logits[:, -1, :] / max(1e-6, temperature)

            if top_k is not None and top_k > 0:
                # Mask everything below the k-th largest logit.
                topv, _ = torch.topk(step_logits, min(top_k, step_logits.size(-1)))
                step_logits = step_logits.masked_fill(
                    step_logits < topv[:, [-1]], -float("Inf")
                )

            probs = torch.softmax(step_logits, dim=-1)
            choice = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, choice], dim=1)

    return tok.decode(ids[0].tolist())
|
|
|
|
|
|
|
|
|
class SupernovaChat:
    """Interactive Supernova assistant combining the language model with tools.

    Routes branding questions to a canned profile, tool-suited queries to the
    ToolOrchestrator, and everything else to direct model generation.
    """

    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None):
        """Build model, tokenizer and tool orchestrator.

        Args:
            config_path: Path to a JSON ModelConfig file.
            checkpoint_path: Optional checkpoint containing a
                'model_state_dict' entry; silently skipped if missing.
        """
        self.cfg = ModelConfig.from_json_file(config_path)
        self.tok = load_gpt2_tokenizer()
        self.model = SupernovaModel(self.cfg)

        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles arbitrary objects; prefer
            # weights_only=True on torch >= 1.13 for untrusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {checkpoint_path}")

        # Security fix: the API key was previously hard-coded in source
        # (a leaked secret). Read it from the environment instead; the
        # orchestrator receives None when the variable is unset.
        serper_api_key = os.environ.get("SERPER_API_KEY")
        self.tools = ToolOrchestrator(serper_api_key=serper_api_key)

        # Reserved for future multi-turn context; currently unused.
        self.conversation_history = []

    def respond(self, user_input: str) -> str:
        """Generate a response to user input, using tools when appropriate."""
        # Branding/company questions get the canned profile text verbatim.
        if should_return_brand(user_input):
            return load_brand_text()

        # Try to route the query to an external tool (math, web search, ...).
        tool_call = self.tools.route_query(user_input)

        if tool_call:
            tool_call = self.tools.execute_tool_call(tool_call)

            if tool_call.result:
                # Wrap tool output in a tool-specific presentation.
                if tool_call.tool == "math_engine":
                    response = f"I'll solve this mathematical problem for you:\n\n{tool_call.result}\n\nThe calculation shows the step-by-step solution above."
                elif tool_call.tool == "serper":
                    response = f"Based on current information I found:\n\n{tool_call.result}"
                else:
                    response = tool_call.result
                return response

            elif tool_call.error:
                # Tool failed: fall back to the model, surfacing the error
                # inside the prompt so the reply can acknowledge it.
                fallback_prompt = f"The user asked: {user_input}\n\nI couldn't access external tools ({tool_call.error}), but I can still help based on my training. Here's what I know:\n\n"
                try:
                    return generate(self.model, self.tok, fallback_prompt, max_new_tokens=300)
                except Exception as e:
                    return f"I apologize, but I'm having trouble accessing both external tools and my language model. Error: {str(e)}"

        # No tool matched: answer directly with the language model.
        try:
            enhanced_prompt = f"""You are Supernova, an AI assistant built by AlgoRythm Technologies. You have broad knowledge across all subjects including science, mathematics, history, literature, technology, medicine, law, arts, and more. Provide helpful, accurate, and comprehensive responses.

User: {user_input}

Supernova: """

            response = generate(self.model, self.tok, enhanced_prompt, max_new_tokens=400)

            # Strip the echoed prompt: keep only text after the marker.
            if "Supernova: " in response:
                response = response.split("Supernova: ", 1)[1]

            return response.strip()

        except Exception as e:
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"

    def chat_loop(self):
        """Interactive chat loop."""
        print("🌟 Supernova AI Assistant - Built by AlgoRythm Technologies")
        print("Enhanced with free SymPy mathematical computation and Serper web search")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.\n")

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if user_input.lower() in ['quit', 'exit', 'bye', 'q']:
                    print("\nSupernova: Goodbye! It was great helping you today.")
                    break

                if not user_input:
                    continue

                print("\nSupernova: ", end="")
                response = self.respond(user_input)
                print(response)

            except KeyboardInterrupt:
                # Ctrl-C exits the loop cleanly instead of crashing.
                print("\n\nSupernova: Goodbye!")
                break
            except Exception as e:
                # Keep the REPL alive on per-turn errors.
                print(f"\nError: {e}")
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: build a SupernovaChat, then run one prompt or the loop."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Chat with Tool Integration")
    parser.add_argument("--config", required=True, help="Path to model config file")
    parser.add_argument("--checkpoint", help="Path to model checkpoint (optional)")
    parser.add_argument("--prompt", help="Single prompt mode (instead of chat loop)")
    args = parser.parse_args()

    chat = SupernovaChat(config_path=args.config, checkpoint_path=args.checkpoint)

    if args.prompt:
        # One-shot mode: answer a single prompt and exit.
        print(chat.respond(args.prompt))
    else:
        # Default: interactive REPL.
        chat.chat_loop()
|
|
|
|
|
|
|
|
|
# Script entry point guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()