# Supernova25million / chat_enhanced.py
# algorythmtechnologies's picture
# Upload folder using huggingface_hub
# 8174855 verified
import argparse
import json
import os
from typing import Optional
import torch
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.tools import ToolOrchestrator, ToolCall
# Absolute path to the company/brand profile text shipped alongside this module.
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")
def load_brand_text() -> str:
    """Read the brand profile file and return its text, stripped of surrounding whitespace."""
    with open(BRAND_PATH, "r", encoding="utf-8") as brand_file:
        contents = brand_file.read()
    return contents.strip()
def should_return_brand(prompt: str) -> bool:
    """Return True when the prompt looks like a question about the company or brand.

    Matching is case-insensitive substring search against a fixed list of
    brand-related phrases.
    """
    lowered = prompt.lower()
    brand_markers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
        "who built you",
        "who created you",
    )
    for marker in brand_markers:
        if marker in lowered:
            return True
    return False
def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
) -> str:
    """Autoregressively sample up to ``max_new_tokens`` tokens from ``model``.

    Args:
        model: Language model; called as ``model(ids)`` and expected to return
            ``(logits, _)`` with logits shaped (batch, seq, vocab).
        tok: Tokenizer exposing ``encode(..., return_tensors="pt")`` and ``decode``.
        prompt: Seed text to condition on.
        max_new_tokens: Maximum number of tokens appended to the prompt.
        temperature: Softmax temperature; clamped to at least 1e-6.
        top_k: When given and > 0, sampling is restricted to the k highest logits.

    Returns:
        The decoded string: the prompt followed by the sampled continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    ids = tok.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Condition only on the most recent n_positions tokens (context limit).
            window = ids if ids.size(1) < model.cfg.n_positions else ids[:, -model.cfg.n_positions:]
            logits, _ = model(window)
            step_logits = logits[:, -1, :] / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                # Threshold at the k-th largest logit; everything below is excluded.
                kth = torch.topk(step_logits, min(top_k, step_logits.size(-1)))[0][:, [-1]]
                step_logits = step_logits.masked_fill(step_logits < kth, -float("Inf"))
            probs = torch.softmax(step_logits, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, sampled], dim=1)
    return tok.decode(ids[0].tolist())
class SupernovaChat:
    """Interactive chat wrapper around SupernovaModel with tool routing.

    Brand questions are answered from the static brand profile, tool-suitable
    queries are dispatched to the ToolOrchestrator, and everything else falls
    through to direct model generation.
    """

    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None):
        """Build the model, tokenizer, and tool orchestrator.

        Args:
            config_path: Path to the JSON model config file.
            checkpoint_path: Optional checkpoint path containing a
                ``model_state_dict`` entry; silently skipped when missing.
        """
        self.cfg = ModelConfig.from_json_file(config_path)
        self.tok = load_gpt2_tokenizer()
        # Initialize model
        self.model = SupernovaModel(self.cfg)
        # Load checkpoint if provided
        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load without weights_only=True can execute
            # arbitrary code from an untrusted checkpoint; pass weights_only=True
            # once the project's minimum torch version supports it.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {checkpoint_path}")
        # SECURITY: prefer the SERPER_API_KEY environment variable. The literal
        # fallback preserves prior behavior, but a key committed to source
        # control is compromised and should be rotated and removed.
        serper_api_key = os.environ.get(
            "SERPER_API_KEY", "06f4918f3ea721d9742f940fb7c7ba1ac44e7c14"
        )
        self.tools = ToolOrchestrator(serper_api_key=serper_api_key)
        # Conversation turns accumulated for future context use.
        self.conversation_history = []

    def respond(self, user_input: str) -> str:
        """Generate a response to user input, using tools when appropriate.

        Resolution order: brand profile -> tool call -> tool-error fallback
        generation -> plain prompted generation. Generation failures are
        caught and reported as apology strings rather than raised.
        """
        # Check for brand queries first
        if should_return_brand(user_input):
            return load_brand_text()
        # Check if we should use tools
        tool_call = self.tools.route_query(user_input)
        if tool_call:
            # Execute the tool call
            tool_call = self.tools.execute_tool_call(tool_call)
            if tool_call.result:
                # Format the response with tool results
                if tool_call.tool == "math_engine":
                    response = f"I'll solve this mathematical problem for you:\n\n{tool_call.result}\n\nThe calculation shows the step-by-step solution above."
                elif tool_call.tool == "serper":
                    response = f"Based on current information I found:\n\n{tool_call.result}"
                else:
                    response = tool_call.result
                return response
            elif tool_call.error:
                # Tool failed, fall back to model generation with error context
                fallback_prompt = f"The user asked: {user_input}\n\nI couldn't access external tools ({tool_call.error}), but I can still help based on my training. Here's what I know:\n\n"
                try:
                    return generate(self.model, self.tok, fallback_prompt, max_new_tokens=300)
                except Exception as e:
                    return f"I apologize, but I'm having trouble accessing both external tools and my language model. Error: {str(e)}"
        # No tools needed, use direct generation
        try:
            # Create a comprehensive prompt that encourages broad knowledge use
            enhanced_prompt = f"""You are Supernova, an AI assistant built by AlgoRythm Technologies. You have broad knowledge across all subjects including science, mathematics, history, literature, technology, medicine, law, arts, and more. Provide helpful, accurate, and comprehensive responses.

User: {user_input}

Supernova: """
            response = generate(self.model, self.tok, enhanced_prompt, max_new_tokens=400)
            # Extract just the Supernova response part
            if "Supernova: " in response:
                response = response.split("Supernova: ", 1)[1]
            return response.strip()
        except Exception as e:
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"

    def chat_loop(self):
        """Interactive chat loop; blocks until the user quits or interrupts."""
        print("🌟 Supernova AI Assistant - Built by AlgoRythm Technologies")
        print("Enhanced with free SymPy mathematical computation and Serper web search")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.\n")
        while True:
            try:
                user_input = input("\nYou: ").strip()
                if user_input.lower() in ['quit', 'exit', 'bye', 'q']:
                    print("\nSupernova: Goodbye! It was great helping you today.")
                    break
                if not user_input:
                    continue
                print("\nSupernova: ", end="")
                response = self.respond(user_input)
                print(response)
            except KeyboardInterrupt:
                # Ctrl-C ends the session cleanly instead of dumping a traceback.
                print("\n\nSupernova: Goodbye!")
                break
            except Exception as e:
                # Keep the loop alive on per-turn failures; report and continue.
                print(f"\nError: {e}")
def main():
    """CLI entry point: answer one prompt or start the interactive chat loop."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Chat with Tool Integration")
    parser.add_argument("--config", required=True, help="Path to model config file")
    parser.add_argument("--checkpoint", help="Path to model checkpoint (optional)")
    parser.add_argument("--prompt", help="Single prompt mode (instead of chat loop)")
    args = parser.parse_args()

    # Build the chat system from the parsed options.
    chat = SupernovaChat(config_path=args.config, checkpoint_path=args.checkpoint)

    if args.prompt:
        # Single prompt mode: answer once and exit.
        print(chat.respond(args.prompt))
        return
    # Otherwise run the interactive loop.
    chat.chat_loop()
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()