File size: 7,494 Bytes
8174855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import argparse
import json
import os
import re
from typing import Optional

import torch

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.tools import ToolOrchestrator, ToolCall

BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")


def load_brand_text() -> str:
    """Read the AlgoRythm company-profile text, stripped of surrounding whitespace."""
    with open(BRAND_PATH, "r", encoding="utf-8") as brand_file:
        contents = brand_file.read()
    return contents.strip()


def should_return_brand(prompt: str) -> bool:
    """Return True if *prompt* asks about the AlgoRythm brand or the assistant's identity.

    Each key phrase is matched case-insensitively on word boundaries, so a
    short key such as "vision" no longer fires on unrelated words like
    "television" or "revision" (the old bare-substring check did).
    """
    keys = [
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
        "who built you",
        "who created you"
    ]
    p = prompt.lower()
    # \b anchors prevent partial-word hits while still allowing punctuation
    # immediately before/after the phrase (e.g. "Who are you?").
    return any(re.search(r"\b" + re.escape(k) + r"\b", p) for k in keys)


def generate(

    model: SupernovaModel,

    tok,

    prompt: str,

    max_new_tokens: int = 200,

    temperature: float = 0.8,

    top_k: Optional[int] = 50,

) -> str:
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if input_ids.size(1) >= model.cfg.n_positions:
                input_cond = input_ids[:, -model.cfg.n_positions:]
            else:
                input_cond = input_ids
            
            logits, _ = model(input_cond)
            logits = logits[:, -1, :]
            logits = logits / max(1e-6, temperature)
            
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)
    
    return tok.decode(input_ids[0].tolist())


class SupernovaChat:
    """Interactive chat front-end for the Supernova model.

    Routes brand/identity questions to a canned company profile, routes
    math/search questions through a ``ToolOrchestrator``, and falls back
    to direct model generation for everything else.
    """

    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None):
        """Build the model/tokenizer/tool stack.

        Args:
            config_path: Path to a JSON file accepted by
                ``ModelConfig.from_json_file``.
            checkpoint_path: Optional path to a ``torch.save`` checkpoint
                containing a ``model_state_dict`` entry; silently skipped
                when the file does not exist.
        """
        self.cfg = ModelConfig.from_json_file(config_path)
        self.tok = load_gpt2_tokenizer()

        # Initialize model
        self.model = SupernovaModel(self.cfg)

        # Load checkpoint if provided
        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects — only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {checkpoint_path}")

        # Prefer the environment for the Serper credential. The literal
        # fallback preserves prior behavior, but it is a secret committed to
        # source — SECURITY: rotate this key and remove the fallback.
        serper_api_key = os.environ.get(
            "SERPER_API_KEY", "06f4918f3ea721d9742f940fb7c7ba1ac44e7c14"
        )
        self.tools = ToolOrchestrator(serper_api_key=serper_api_key)

        # Track conversation for context (currently unused by respond()).
        self.conversation_history = []

    def respond(self, user_input: str) -> str:
        """Generate a response to user input, using tools when appropriate.

        Resolution order: brand lookup → tool call (with model fallback on
        tool error) → direct model generation.
        """

        # Check for brand queries first
        if should_return_brand(user_input):
            return load_brand_text()

        # Check if we should use tools
        tool_call = self.tools.route_query(user_input)

        if tool_call:
            # Execute the tool call
            tool_call = self.tools.execute_tool_call(tool_call)

            if tool_call.result:
                # Format the response with tool results
                if tool_call.tool == "math_engine":
                    response = f"I'll solve this mathematical problem for you:\n\n{tool_call.result}\n\nThe calculation shows the step-by-step solution above."
                elif tool_call.tool == "serper":
                    response = f"Based on current information I found:\n\n{tool_call.result}"
                else:
                    response = tool_call.result

                return response

            elif tool_call.error:
                # Tool failed, fall back to model generation with error context
                fallback_prompt = f"The user asked: {user_input}\n\nI couldn't access external tools ({tool_call.error}), but I can still help based on my training. Here's what I know:\n\n"
                try:
                    return generate(self.model, self.tok, fallback_prompt, max_new_tokens=300)
                except Exception as e:
                    return f"I apologize, but I'm having trouble accessing both external tools and my language model. Error: {str(e)}"

        # No tools needed, use direct generation
        try:
            # Create a comprehensive prompt that encourages broad knowledge use
            enhanced_prompt = f"""You are Supernova, an AI assistant built by AlgoRythm Technologies. You have broad knowledge across all subjects including science, mathematics, history, literature, technology, medicine, law, arts, and more. Provide helpful, accurate, and comprehensive responses.



User: {user_input}



Supernova: """

            response = generate(self.model, self.tok, enhanced_prompt, max_new_tokens=400)

            # Extract just the Supernova response part (generate() echoes the prompt)
            if "Supernova: " in response:
                response = response.split("Supernova: ", 1)[1]

            return response.strip()

        except Exception as e:
            return f"I apologize, but I encountered an error while generating a response: {str(e)}"

    def chat_loop(self):
        """Interactive REPL: read user input, print responses, until quit."""
        print("🌟 Supernova AI Assistant - Built by AlgoRythm Technologies")
        print("Enhanced with free SymPy mathematical computation and Serper web search")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.\n")

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if user_input.lower() in ['quit', 'exit', 'bye', 'q']:
                    print("\nSupernova: Goodbye! It was great helping you today.")
                    break

                if not user_input:
                    continue

                print("\nSupernova: ", end="")
                response = self.respond(user_input)
                print(response)

            except KeyboardInterrupt:
                # Ctrl-C exits cleanly rather than crashing the loop.
                print("\n\nSupernova: Goodbye!")
                break
            except Exception as e:
                # Keep the loop alive on per-turn failures.
                print(f"\nError: {e}")


def main():
    """CLI entry point: run one prompt, or drop into the interactive chat loop."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Chat with Tool Integration")
    parser.add_argument("--config", required=True, help="Path to model config file")
    parser.add_argument("--checkpoint", help="Path to model checkpoint (optional)")
    parser.add_argument("--prompt", help="Single prompt mode (instead of chat loop)")
    args = parser.parse_args()

    # Build the full chat stack once, regardless of mode.
    chat = SupernovaChat(config_path=args.config, checkpoint_path=args.checkpoint)

    if args.prompt:
        # Single prompt mode: answer once and exit.
        print(chat.respond(args.prompt))
    else:
        # Interactive chat loop
        chat.chat_loop()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()