File size: 4,922 Bytes
681909f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import re
import sys
from pathlib import Path

import sentencepiece as spm
import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from sovyn import SovynConfig, SovynForCausalLM
from sovyn.formatting import format_prompt


def clean_answer(text: str) -> str:
    """Extract a short, displayable reply from decoded model output.

    Keeps only the text after the last ``<assistant>`` tag, cuts it at the
    first ``<eos>`` or role/section tag, and trims the remainder to its first
    sentence. When that first sentence is one of a few generic openers and a
    second sentence exists, the first two sentences are kept instead.

    Args:
        text: Decoded model output, possibly including the prompt and tags.

    Returns:
        The trimmed reply. If no sentence terminator is found, the remaining
        text is returned as-is (stripped).
    """
    answer = text.split("<assistant>")[-1]
    # Cut at <eos> rather than deleting the tag: anything after the
    # end-of-sequence marker is not part of the reply and would otherwise be
    # glued onto it (and hide the sentence boundary from the regex below).
    for marker in (
        "<eos>",
        "<system>",
        "<user>",
        "<state>",
        "<plan>",
        "<memory>",
        "<reflection>",
    ):
        answer = answer.split(marker)[0].strip()

    # A sentence ends at . ! ? or the CJK full stop, followed by whitespace
    # or the end of the text.
    sentences = re.findall(r".+?[.!?。](?=\s|$)", answer)
    if not sentences:
        return answer

    first = sentences[0].strip()
    generic_openers = {
        "좋아.",
        "응.",
        "알겠어.",
        "가능해.",
        "그럴 수 있어.",
        "불안할 수 있어.",
        "많이 지쳤겠다.",
    }
    # A generic opener alone carries little content, so keep the sentence
    # that follows it as well when there is one.
    if first in generic_openers and len(sentences) > 1:
        return " ".join(item.strip() for item in sentences[:2])
    return first


def score_answer(user: str, answer: str) -> int:
    """Heuristically rate how well ``answer`` fits the ``user`` message.

    The first topic whose trigger words occur in the lowercased user message
    is selected; each of that topic's reward keywords found in the lowercased
    answer adds 2 points. Canned one-line replies lose 3 points, answers
    shorter than 4 characters lose 2, and answers of 8 to 80 characters gain 1.

    Args:
        user: The user's message.
        answer: A candidate reply.

    Returns:
        The integer score; higher is better.
    """
    # (trigger substrings looked up in the user text,
    #  reward substrings looked up in the answer)
    topic_table = (
        (["피곤", "지쳤", "힘들", "기운", "복잡"], ["지쳤", "쉬", "힘들", "숨", "들어줄", "괜찮"]),
        (["누구", "정체", "sovyn"], ["sovyn", "모델", "ai", "학습", "기억"]),
        (["짧", "간단", "핵심", "요약"], ["핵심", "짧", "결론", "간단"]),
        (
            ["1b", "120m", "300m", "모델", "학습", "키울", "저장공간", "저장 공간", "체크포인트", "부족"],
            ["모델", "학습", "120m", "300m", "1b", "키우", "실험", "체크포인트", "저장", "공간"],
        ),
        (["먹", "점심", "저녁", "음식"], ["김밥", "국밥", "덮밥", "샌드위치", "먹"]),
        (["공부", "집중", "미루", "루틴"], ["10분", "타이머", "문제", "시작", "공부"]),
        (["불안", "걱정", "초조"], ["불안", "걱정", "숨", "괜찮", "나눠"]),
    )
    prompt_lc = user.lower()
    reply_lc = answer.lower()

    total = 0
    for triggers, rewards in topic_table:
        if not any(trigger in prompt_lc for trigger in triggers):
            continue
        # Only the first matching topic contributes to the score.
        hits = [keyword for keyword in rewards if keyword.lower() in reply_lc]
        total += 2 * len(hits)
        break

    canned_replies = ("안녕.", "좋아.", "응.", "좋은 생각이야.", "좋은 아침이야.")
    if answer in canned_replies:
        total -= 3

    size = len(answer)
    if size < 4:
        total -= 2
    if 8 <= size <= 80:
        total += 1
    return total


def main():
    """Run an interactive SOVYN chat session on the command line.

    Parses the CLI arguments, loads the SentencePiece tokenizer and the model
    checkpoint, then reads user messages from stdin until ``exit``/``quit``,
    end-of-input, or Ctrl-C. For each message it generates one candidate
    (or ``--best-of`` candidates when ``--temperature`` is above 0), cleans
    them with ``clean_answer`` and prints the one ``score_answer`` ranks
    highest.

    Raises:
        ValueError: If the checkpoint's ``vocab_size`` differs from the
            tokenizer's piece count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--best-of", type=int, default=4)
    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is not available.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model_cfg = ckpt["config"]["model"]
    if model_cfg["vocab_size"] != sp.get_piece_size():
        raise ValueError(
            f"Checkpoint vocab_size={model_cfg['vocab_size']} but tokenizer has "
            f"{sp.get_piece_size()} pieces"
        )
    model = SovynForCausalLM(SovynConfig(**model_cfg))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    def known_piece_id(piece):
        """Return the id of ``piece``, or None when the vocabulary lacks it."""
        idx = sp.piece_to_id(piece)
        # SentencePiece reports out-of-vocabulary pieces as the unknown id
        # instead of a negative value, so map the id back to confirm it.
        if idx >= 0 and sp.id_to_piece(idx) == piece:
            return idx
        return None

    eos_id = sp.piece_to_id("<eos>")
    stop_ids = [
        idx
        for idx in map(
            known_piece_id,
            ["<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>"],
        )
        if idx is not None
    ]
    suppress_ids = [
        idx
        for idx in [sp.piece_to_id("<pad>"), sp.piece_to_id("<unk>"), sp.piece_to_id("<bos>")]
        if idx >= 0
    ]
    print("SOVYN chat. Type 'exit' to quit.")

    while True:
        try:
            user = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C or a closed stdin ends the session without a traceback.
            print()
            break
        if user.lower() in {"exit", "quit"}:
            break
        if not user:
            # Nothing to answer; prompt again instead of generating for a blank line.
            continue
        prompt = format_prompt(user)
        ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long, device=device)
        candidates = []
        # Repeated sampling is only requested when temperature > 0.
        runs = max(1, args.best_of if args.temperature > 0 else 1)
        # Inference only: skip autograd bookkeeping during generation.
        with torch.no_grad():
            for _ in range(runs):
                out = model.generate(
                    ids,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    eos_id=eos_id,
                    stop_ids=stop_ids,
                    suppress_ids=suppress_ids,
                )
                candidates.append(clean_answer(sp.decode(out[0].tolist())))
        answer = max(candidates, key=lambda item: score_answer(user, item))
        print(f"sovyn> {answer}")


# Start the chat loop only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()