File size: 4,922 Bytes
681909f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import argparse
import re
import sys
from pathlib import Path
import sentencepiece as spm
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from sovyn import SovynConfig, SovynForCausalLM
from sovyn.formatting import format_prompt
def clean_answer(text: str) -> str:
    """Extract a short, user-facing reply from raw decoded model output.

    Processing steps:
      1. Keep only the text after the last ``<assistant>`` tag.
      2. Remove every ``<eos>`` marker.
      3. Cut everything from the first section/role marker onward. The markers
         are ``<system>``, ``<user>``, ``<state>``, ``<plan>``, ``<memory>`` and
         ``<reflection>``.
      4. Trim the reply to its first sentence. If that first sentence is a stock
         filler phrase (e.g. "좋아.") and another sentence follows, the first
         two sentences are kept instead.

    If no sentence terminator is found, the cleaned text is returned as is.

    Args:
        text: Decoded model output, which may still contain the prompt.

    Returns:
        The cleaned reply (possibly an empty string).
    """
    answer = text.split("<assistant>")[-1].replace("<eos>", "").strip()
    for marker in ("<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>"):
        answer = answer.split(marker)[0].strip()

    # A sentence is the shortest run ending in . ! ? or 。 that is followed by
    # whitespace or the end of the text. DOTALL lets a sentence span a line
    # break. Without it, `.` cannot match "\n", and findall silently skips any
    # leading text that reaches a line break before a terminator.
    sentences = re.findall(r".+?[.!?。](?=\s|$)", answer, flags=re.DOTALL)
    if not sentences:
        return answer

    first = sentences[0].strip()
    generic_first = first in {
        "좋아.",
        "응.",
        "알겠어.",
        "가능해.",
        "그럴 수 있어.",
        "불안할 수 있어.",
        "많이 지쳤겠다.",
    }
    if generic_first and len(sentences) > 1:
        # A filler opener says little on its own; keep the next sentence too.
        return " ".join(item.strip() for item in sentences[:2])
    return first
def score_answer(user: str, answer: str) -> int:
    """Heuristically rate how well *answer* responds to *user*; higher is better.

    Scoring rules:
      - Topic bonus: only the first topic whose trigger words appear in the
        lowercased user message is used. It awards 2 points for each of its
        keywords found in the lowercased answer.
      - Filler penalty: an answer exactly equal to a stock filler reply loses
        3 points.
      - Length rules: answers shorter than 4 characters lose 2 points, and
        answers of 8-80 characters gain 1 point.

    Args:
        user: The user's message.
        answer: A cleaned candidate reply.

    Returns:
        The integer score.
    """
    query = user.lower()
    reply = answer.lower()
    # (trigger words in the user message, keywords rewarded in the answer)
    topics = (
        (("피곤", "지쳤", "힘들", "기운", "복잡"), ("지쳤", "쉬", "힘들", "숨", "들어줄", "괜찮")),
        (("누구", "정체", "sovyn"), ("sovyn", "모델", "ai", "학습", "기억")),
        (("짧", "간단", "핵심", "요약"), ("핵심", "짧", "결론", "간단")),
        (
            ("1b", "120m", "300m", "모델", "학습", "키울", "저장공간", "저장 공간", "체크포인트", "부족"),
            ("모델", "학습", "120m", "300m", "1b", "키우", "실험", "체크포인트", "저장", "공간"),
        ),
        (("먹", "점심", "저녁", "음식"), ("김밥", "국밥", "덮밥", "샌드위치", "먹")),
        (("공부", "집중", "미루", "루틴"), ("10분", "타이머", "문제", "시작", "공부")),
        (("불안", "걱정", "초조"), ("불안", "걱정", "숨", "괜찮", "나눠")),
    )

    total = 0
    for triggers, keywords in topics:
        if not any(trigger in query for trigger in triggers):
            continue
        hits = [keyword for keyword in keywords if keyword.lower() in reply]
        total += 2 * len(hits)
        break  # only the first matching topic contributes

    if answer in {"안녕.", "좋아.", "응.", "좋은 생각이야.", "좋은 아침이야."}:
        total -= 3

    size = len(answer)
    if size < 4:
        total -= 2
    elif 8 <= size <= 80:
        total += 1
    return total
def _resolve_piece(sp, piece: str) -> "int | None":
    """Return the vocabulary id of *piece*, or ``None`` if the tokenizer lacks it.

    SentencePiece's ``piece_to_id`` maps an unknown piece to the unknown-token
    id rather than to a negative value, so a ``>= 0`` test cannot detect a
    missing piece. Checking that the id maps back to the same piece can.
    """
    idx = sp.piece_to_id(piece)
    if 0 <= idx < sp.get_piece_size() and sp.id_to_piece(idx) == piece:
        return idx
    return None


def _load_model(checkpoint: str, tokenizer: str, device: str):
    """Load the tokenizer and a Sovyn model, and return ``(sp, model)``.

    The model is moved to *device* and put in eval mode.

    Raises:
        ValueError: if the checkpoint's ``vocab_size`` differs from the
            tokenizer's piece count.
    """
    sp = spm.SentencePieceProcessor(model_file=tokenizer)
    # Map to CPU first so a checkpoint saved on GPU also loads on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint, map_location="cpu")
    model_cfg = ckpt["config"]["model"]
    if model_cfg["vocab_size"] != sp.get_piece_size():
        raise ValueError(
            f"Checkpoint vocab_size={model_cfg['vocab_size']} but tokenizer has "
            f"{sp.get_piece_size()} pieces"
        )
    model = SovynForCausalLM(SovynConfig(**model_cfg))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()
    return sp, model


def main():
    """Run an interactive chat loop against a Sovyn checkpoint.

    Command-line options:
      --checkpoint, --tokenizer: required paths.
      --device: default "cuda"; falls back to cpu when CUDA is unavailable.
      --max-new-tokens, --temperature, --top-k: sampling controls.
      --best-of: number of samples drawn when temperature > 0.

    Each turn formats the user message with ``format_prompt``, generates one
    or more replies, cleans them with ``clean_answer``, and prints the one
    with the highest ``score_answer``. Typing "exit"/"quit", pressing
    Ctrl-D/Ctrl-C at the prompt, or reaching end of input ends the session.
    Empty lines are ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--best-of", type=int, default=4)
    args = parser.parse_args()

    device = args.device
    # Also covers indexed devices such as "cuda:1".
    if device.startswith("cuda") and not torch.cuda.is_available():
        print(f"warning: {device} requested but CUDA is unavailable; using cpu", file=sys.stderr)
        device = "cpu"

    sp, model = _load_model(args.checkpoint, args.tokenizer, device)

    eos_id = sp.piece_to_id("<eos>")
    if _resolve_piece(sp, "<eos>") is None:
        print(
            "warning: tokenizer has no '<eos>' piece; eos_id falls back to the unknown-token id",
            file=sys.stderr,
        )
    stop_ids = [
        idx
        for idx in (
            _resolve_piece(sp, piece)
            for piece in ("<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>")
        )
        if idx is not None
    ]
    suppress_ids = [
        idx
        for idx in (_resolve_piece(sp, piece) for piece in ("<pad>", "<unk>", "<bos>"))
        if idx is not None
    ]
    # Greedy decoding (temperature 0) produces a single candidate; otherwise draw best-of samples.
    runs = max(1, args.best_of) if args.temperature > 0 else 1

    print("SOVYN chat. Type 'exit' to quit.")
    while True:
        try:
            user = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C / end of piped input: exit cleanly instead of a traceback.
            print()
            break
        if not user:
            continue
        if user.lower() in {"exit", "quit"}:
            break

        prompt = format_prompt(user)
        ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long, device=device)
        candidates = []
        # Pure inference: do not record autograd history while sampling.
        with torch.no_grad():
            for _ in range(runs):
                out = model.generate(
                    ids,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    eos_id=eos_id,
                    stop_ids=stop_ids,
                    suppress_ids=suppress_ids,
                )
                candidates.append(clean_answer(sp.decode(out[0].tolist())))
        # max() returns the earliest candidate among equal scores.
        answer = max(candidates, key=lambda item: score_answer(user, item))
        print(f"sovyn> {answer}")
if __name__ == "__main__":
    # Start the chat loop only when run as a script, never on import.
    main()
|