| import argparse |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import sentencepiece as spm |
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) |
|
|
| from sovyn import SovynConfig, SovynForCausalLM |
| from sovyn.formatting import format_prompt |
|
|
|
|
# Tags that open a non-assistant section; anything from the first one onward is discarded.
_TURN_MARKERS = ("<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>")

# One sentence: the shortest run ending in ., !, ? or 。 that is followed by
# whitespace or the end of the text.
_SENTENCE_RE = re.compile(r".+?[.!?。](?=\s|$)")

# Short acknowledgement openers; when the reply starts with one of these, the
# following sentence is kept as well.
_GENERIC_OPENERS = frozenset(
    {
        "좋아.",
        "응.",
        "알겠어.",
        "가능해.",
        "그럴 수 있어.",
        "불안할 수 있어.",
        "많이 지쳤겠다.",
    }
)


def clean_answer(text: str) -> str:
    """Reduce decoded model output to a short assistant reply.

    Keeps only the text after the last ``<assistant>`` tag, removes ``<eos>``,
    cuts the text at the first following section tag, and then returns the
    first sentence. If that first sentence is one of the generic openers and a
    second sentence exists, the first two sentences are returned joined by a
    single space.

    Args:
        text: Decoded model output, possibly including the prompt and tags.

    Returns:
        The trimmed reply. When no sentence-ending punctuation is found, the
        tag-stripped text is returned unchanged.
    """
    answer = text.split("<assistant>")[-1].replace("<eos>", "").strip()
    for marker in _TURN_MARKERS:
        answer = answer.split(marker)[0].strip()

    # "." in the pattern never matches a newline, so a first line without end
    # punctuation used to be skipped entirely and a later line returned in its
    # place. Collapsing whitespace runs first keeps the start of the reply.
    sentences = _SENTENCE_RE.findall(" ".join(answer.split()))
    if not sentences:
        return answer

    first = sentences[0].strip()
    if first in _GENERIC_OPENERS and len(sentences) > 1:
        return " ".join(item.strip() for item in sentences[:2])
    return first
|
|
|
|
# (trigger words looked up in the user message, keywords rewarded in the answer).
# Groups are checked in order and only the first group that triggers is scored.
# All keywords are lowercase because both texts are lowercased before matching.
_TOPIC_KEYWORDS = (
    (("피곤", "지쳤", "힘들", "기운", "복잡"), ("지쳤", "쉬", "힘들", "숨", "들어줄", "괜찮")),
    (("누구", "정체", "sovyn"), ("sovyn", "모델", "ai", "학습", "기억")),
    (("짧", "간단", "핵심", "요약"), ("핵심", "짧", "결론", "간단")),
    (
        ("1b", "120m", "300m", "모델", "학습", "키울", "저장공간", "저장 공간", "체크포인트", "부족"),
        ("모델", "학습", "120m", "300m", "1b", "키우", "실험", "체크포인트", "저장", "공간"),
    ),
    (("먹", "점심", "저녁", "음식"), ("김밥", "국밥", "덮밥", "샌드위치", "먹")),
    (("공부", "집중", "미루", "루틴"), ("10분", "타이머", "문제", "시작", "공부")),
    (("불안", "걱정", "초조"), ("불안", "걱정", "숨", "괜찮", "나눠")),
)

# Replies that are penalised when they make up the entire answer.
_LOW_EFFORT_ANSWERS = frozenset({"안녕.", "좋아.", "응.", "좋은 생각이야.", "좋은 아침이야."})


def score_answer(user: str, answer: str) -> int:
    """Heuristically score how well a candidate answer fits the user message.

    Scoring rules:
      * For the first keyword group whose trigger appears in the lowercased
        user message, add 2 for every answer keyword of that group found as a
        substring of the lowercased answer.
      * Subtract 3 if the answer is exactly one of the low-effort replies.
      * Subtract 2 if the answer has fewer than 4 characters.
      * Add 1 if the answer has between 8 and 80 characters inclusive.

    Args:
        user: The user's message.
        answer: The candidate reply to score.

    Returns:
        The integer score; higher is better.
    """
    user_text = user.lower()
    # Lowercase the answer once instead of once per keyword.
    answer_text = answer.lower()

    score = 0
    for triggers, rewarded in _TOPIC_KEYWORDS:
        if any(word in user_text for word in triggers):
            score += 2 * sum(word in answer_text for word in rewarded)
            break

    if answer in _LOW_EFFORT_ANSWERS:
        score -= 3
    if len(answer) < 4:
        score -= 2
    if 8 <= len(answer) <= 80:
        score += 1
    return score
|
|
|
|
def _existing_piece_ids(sp, pieces):
    """Return the ids of those pieces that are actually present in the tokenizer vocabulary."""
    found = []
    for piece in pieces:
        idx = sp.piece_to_id(piece)
        # SentencePiece maps an unknown piece to the <unk> id instead of a
        # negative value, so a plain ">= 0" test never filters anything out.
        # Confirm the id maps back to the same piece before trusting it.
        if idx >= 0 and sp.id_to_piece(idx) == piece:
            found.append(idx)
    return found


def main():
    """Run an interactive SOVYN chat session in the terminal.

    Command-line options:
        --checkpoint      Path to the model checkpoint (required).
        --tokenizer       Path to the SentencePiece model file (required).
        --device          Target device; "cuda" falls back to "cpu" when CUDA
                          is unavailable.
        --max-new-tokens  Generation length limit passed to ``model.generate``.
        --temperature     Sampling temperature passed to ``model.generate``.
        --top-k           Top-k value passed to ``model.generate``.
        --best-of         Number of candidates generated per turn; only used
                          when the temperature is above 0.

    Raises:
        ValueError: If the checkpoint's vocab size differs from the tokenizer's.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--best-of", type=int, default=4)
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.", file=sys.stderr)
        device = "cpu"

    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True if the checkpoint allows it).
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model_cfg = ckpt["config"]["model"]
    if model_cfg["vocab_size"] != sp.get_piece_size():
        raise ValueError(
            f"Checkpoint vocab_size={model_cfg['vocab_size']} but tokenizer has "
            f"{sp.get_piece_size()} pieces"
        )
    model = SovynForCausalLM(SovynConfig(**model_cfg))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    # NOTE(review): if "<eos>" is not in the vocabulary this resolves to the
    # <unk> id; confirm the tokenizer defines it.
    eos_id = sp.piece_to_id("<eos>")
    stop_ids = _existing_piece_ids(
        sp, ["<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>"]
    )
    suppress_ids = _existing_piece_ids(sp, ["<pad>", "<unk>", "<bos>"])
    print("SOVYN chat. Type 'exit' to quit.")

    while True:
        try:
            user = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session without a traceback.
            print()
            break
        if user.lower() in {"exit", "quit"}:
            break
        if not user:
            # Nothing to answer; prompt again rather than generating.
            continue

        prompt = format_prompt(user)
        ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long, device=device)

        # With temperature 0 the same arguments are passed on every run, so a
        # single run is used; otherwise draw --best-of candidates (at least one).
        runs = max(1, args.best_of if args.temperature > 0 else 1)
        candidates = []
        # Gradients are never needed for inference.
        with torch.no_grad():
            for _ in range(runs):
                out = model.generate(
                    ids,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    eos_id=eos_id,
                    stop_ids=stop_ids,
                    suppress_ids=suppress_ids,
                )
                candidates.append(clean_answer(sp.decode(out[0].tolist())))

        answer = max(candidates, key=lambda item: score_answer(user, item))
        print(f"sovyn> {answer}")
|
|
|
|
# Start the chat session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|