File size: 2,526 Bytes
8174855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import json
import os
from typing import Optional

import torch

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer

# Absolute path to the canned company-profile text, resolved relative to this file.
BRAND_PATH = os.path.join(os.path.dirname(__file__), "branding", "ALGORHYTHM_TECH_PROFILE.txt")


def load_brand_text() -> str:
    """Read the brand profile file and return its text with outer whitespace removed."""
    with open(BRAND_PATH, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()


def should_return_brand(prompt: str) -> bool:
    """Return True when the prompt looks like a question about the company/brand.

    Matching is a simple case-insensitive substring check against a fixed
    list of trigger phrases.
    """
    triggers = (
        "algorythm tech",
        "algorythm technologies",
        "company profile",
        "vision",
        "who are you",
        "about algorythm",
    )
    lowered = prompt.lower()
    for trigger in triggers:
        if trigger in lowered:
            return True
    return False


def generate(

    model: SupernovaModel,

    tok,

    prompt: str,

    max_new_tokens: int = 200,

    temperature: float = 0.8,

    top_k: Optional[int] = 50,

) -> str:
    model.eval()
    device = next(model.parameters()).device
    input_ids = tok.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if input_ids.size(1) >= model.cfg.n_positions:
                input_cond = input_ids[:, -model.cfg.n_positions :]
            else:
                input_cond = input_ids
            logits, _ = model(input_cond)
            logits = logits[:, -1, :]
            logits = logits / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)
    return tok.decode(input_ids[0].tolist())


def main(config_path: str, prompt: str):
    """Answer branding prompts from the canned profile; otherwise sample from the model."""
    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()

    # Fresh model with random weights — no checkpoint is loaded here.
    model = SupernovaModel(cfg)

    # Branding questions short-circuit and never touch the model.
    if should_return_brand(prompt):
        print(load_brand_text())
        return

    # With untrained weights this produces gibberish; useful as a pipeline check.
    print(generate(model, tok, prompt))


if __name__ == "__main__":
    # CLI: --config path/to/model_config.json --prompt "your text"
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--prompt", required=True)
    parsed = parser.parse_args()
    main(parsed.config, parsed.prompt)