|
|
import argparse
|
|
|
import json
|
|
|
import os
|
|
|
from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from supernova.config import ModelConfig
|
|
|
from supernova.model import SupernovaModel
|
|
|
from supernova.tokenizer import load_gpt2_tokenizer
|
|
|
|
|
|
# Absolute path to the canned company-profile text, resolved relative to this
# file so the lookup works regardless of the current working directory.
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")
|
|
|
|
|
|
|
|
|
def load_brand_text() -> str:
    """Return the company-profile text stored at ``BRAND_PATH``.

    The file is read as UTF-8 and surrounding whitespace is stripped.
    """
    with open(BRAND_PATH, "r", encoding="utf-8") as handle:
        raw = handle.read()
    return raw.strip()
|
|
|
|
|
|
|
|
|
def should_return_brand(prompt: str) -> bool:
    """Return ``True`` when *prompt* looks like a question about the company.

    Detection is a case-insensitive substring check against a fixed tuple of
    trigger phrases; no tokenization or word-boundary matching is done.
    """
    triggers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
    )
    lowered = prompt.lower()
    for phrase in triggers:
        if phrase in lowered:
            return True
    return False
|
|
|
|
|
|
|
|
|
def generate(
    model: SupernovaModel,
    tok,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 50,
) -> str:
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt``.

    Args:
        model: language model; called as ``model(ids) -> (logits, _)`` and
            exposing ``cfg.n_positions`` as its context window.
        tok: tokenizer exposing ``encode(..., return_tensors="pt")`` and
            ``decode(list[int])``.
        prompt: conditioning text; it is included in the returned string.
        max_new_tokens: number of sampling steps to perform.
        temperature: softmax temperature; clamped below by 1e-6 to avoid
            division by zero.
        top_k: if set and positive, restrict sampling to the ``k`` highest
            logits at each step.

    Returns:
        The decoded token sequence (prompt plus sampled continuation).
    """
    model.eval()
    device = next(model.parameters()).device
    ids = tok.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Condition only on the most recent n_positions tokens. A negative
            # slice already yields the whole sequence when it is shorter than
            # the window, so no explicit length check is needed.
            window = ids[:, -model.cfg.n_positions :]
            logits, _ = model(window)
            step_logits = logits[:, -1, :] / max(1e-6, temperature)

            if top_k is not None and top_k > 0:
                k = min(top_k, step_logits.size(-1))
                # Threshold at the k-th largest logit; everything below it is
                # excluded from the softmax.
                kth = torch.topk(step_logits, k).values[:, [-1]]
                step_logits = step_logits.masked_fill(step_logits < kth, -float("Inf"))

            probs = torch.softmax(step_logits, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, sampled], dim=1)

    return tok.decode(ids[0].tolist())
|
|
|
|
|
|
|
|
|
def main(config_path: str, prompt: str):
    """Answer *prompt*: print the canned brand profile for company questions,
    otherwise sample a continuation from a Supernova model.

    Args:
        config_path: path to a JSON file loadable by
            ``ModelConfig.from_json_file``.
        prompt: user prompt text.
    """
    # Short-circuit brand/company questions FIRST: the canned answer needs
    # neither the config, the tokenizer, nor the model, so don't pay the cost
    # of building them. (Previously the model was constructed before this
    # check and then thrown away on the brand path.)
    if should_return_brand(prompt):
        print(load_brand_text())
        return

    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()
    # NOTE(review): the model is used with its fresh random initialization —
    # no checkpoint/state_dict is loaded anywhere in this file. Confirm
    # whether trained weights should be restored before generation.
    model = SupernovaModel(cfg)

    out = generate(model, tok, prompt)
    print(out)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point; both flags are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--prompt", required=True)
    cli_args = parser.parse_args()
    main(cli_args.config, cli_args.prompt)
|
|
|
|