# algorythmtechnologies — folder uploaded via huggingface_hub (commit 8174855, verified)
import argparse
import json
import os
from typing import Optional
import torch
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
# Absolute path to the bundled company-profile text, resolved relative to this file;
# served verbatim for brand-related prompts (see should_return_brand / load_brand_text).
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")
def load_brand_text() -> str:
    """Read the bundled brand-profile file and return its text, stripped of surrounding whitespace."""
    with open(BRAND_PATH, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()
def should_return_brand(prompt: str) -> bool:
    """Return True when the prompt looks like a question about the company/brand.

    Matching is case-insensitive substring search over a fixed trigger list.
    """
    lowered = prompt.lower()
    triggers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
    )
    for trigger in triggers:
        if trigger in lowered:
            return True
    return False
def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
    eos_token_id: Optional[int] = None,
) -> str:
    """Autoregressively sample a continuation of *prompt* from *model*.

    Args:
        model: the language model; must expose ``cfg.n_positions`` and return
            ``(logits, _)`` when called on a token-id tensor.
        tok: tokenizer with GPT-2-style ``encode``/``decode``.
        prompt: text to condition on.
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature: softmax temperature; clamped to >= 1e-6 to avoid
            division by zero.
        top_k: if set and > 0, restrict sampling to the k highest logits.
        eos_token_id: if set, stop as soon as this token is sampled
            (backward-compatible: default None keeps the old behavior of
            always sampling exactly ``max_new_tokens`` tokens).

    Returns:
        The decoded text of prompt + sampled continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum position window.
            if input_ids.size(1) >= model.cfg.n_positions:
                input_cond = input_ids[:, -model.cfg.n_positions :]
            else:
                input_cond = input_ids
            logits, _ = model(input_cond)
            logits = logits[:, -1, :]
            # Clamp temperature away from zero so the division is always safe.
            logits = logits / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            # Optional early stop once the end-of-sequence token is produced.
            if eos_token_id is not None and next_id.item() == eos_token_id:
                break
    return tok.decode(input_ids[0].tolist())
def main(config_path: str, prompt: str):
    """CLI entry point: print the canned brand profile for brand-related
    prompts, otherwise sample a continuation from the model.

    Args:
        config_path: path to a JSON model-config file.
        prompt: user prompt text.
    """
    # Answer brand/company questions first so we never pay for config
    # parsing, tokenizer loading, or model construction on that path.
    if should_return_brand(prompt):
        print(load_brand_text())
        return
    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()
    # Construct model (random weights unless you load a checkpoint)
    model = SupernovaModel(cfg)
    # Otherwise, generate (will be gibberish until trained)
    out = generate(model, tok, prompt)
    print(out)
if __name__ == "__main__":
    # Parse the two required CLI flags and dispatch to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--prompt", required=True)
    cli_args = parser.parse_args()
    main(cli_args.config, cli_args.prompt)