"""Command-line generation driver for a SupernovaModel.

Usage:
    python <this file> --config path/to/config.json --prompt "..."

If the prompt looks like a question about the company (branding keywords),
a canned company-profile text is printed instead of sampling from the model.
Otherwise the model is constructed from the given config (random weights
unless a checkpoint is loaded elsewhere) and sampled autoregressively.
"""

import argparse
import json
import os
from typing import Optional

import torch

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer

# Canned company-profile text shipped alongside this script.
BRAND_PATH = os.path.join(
    os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt"
)


def load_brand_text() -> str:
    """Return the branding profile text with surrounding whitespace stripped."""
    with open(BRAND_PATH, "r", encoding="utf-8") as f:
        return f.read().strip()


def should_return_brand(prompt: str) -> bool:
    """Return True when *prompt* contains any company-branding keyword.

    Matching is case-insensitive substring containment, so e.g.
    "Tell me about AlgoRythm Tech" triggers the canned response.
    """
    p = prompt.lower()
    keys = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
    )
    return any(k in p for k in keys)


def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
) -> str:
    """Autoregressively sample up to *max_new_tokens* tokens after *prompt*.

    Args:
        model: The language model; assumed to expose ``cfg.n_positions``
            (max context length) and return ``(logits, _)`` when called.
        tok: Tokenizer with HF-style ``encode(..., return_tensors="pt")``
            and ``decode(ids)`` methods.
        prompt: Seed text to continue.
        max_new_tokens: Number of tokens to sample.
        temperature: Softmax temperature; clamped below by 1e-6 to avoid
            division by zero when 0 is passed.
        top_k: If a positive int, restrict sampling to the k most likely
            tokens; ``None`` (or <= 0) disables the filter.

    Returns:
        The decoded string for prompt + sampled continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's position window so
            # positional embeddings stay in range.
            if input_ids.size(1) >= model.cfg.n_positions:
                input_cond = input_ids[:, -model.cfg.n_positions:]
            else:
                input_cond = input_ids

            logits, _ = model(input_cond)
            logits = logits[:, -1, :]  # distribution over the next token only
            logits = logits / max(1e-6, temperature)

            if top_k is not None and top_k > 0:
                # Mask everything below the k-th largest logit to -inf.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)

    return tok.decode(input_ids[0].tolist())


def main(config_path: str, prompt: str) -> None:
    """Answer branding prompts from the canned profile; otherwise generate.

    Args:
        config_path: Path to a JSON model-config file readable by
            ``ModelConfig.from_json_file``.
        prompt: User prompt text.
    """
    # Branding prompts need no model at all — answer before paying the cost
    # of loading the config/tokenizer and constructing the network.
    if should_return_brand(prompt):
        print(load_brand_text())
        return

    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()
    # Construct model (random weights unless you load a checkpoint);
    # output will be gibberish until trained.
    model = SupernovaModel(cfg)
    out = generate(model, tok, prompt)
    print(out)


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Generate text with SupernovaModel.")
    ap.add_argument("--config", required=True, help="Path to model config JSON.")
    ap.add_argument("--prompt", required=True, help="Prompt text to continue.")
    args = ap.parse_args()
    main(args.config, args.prompt)