import argparse
import re
import sys
from pathlib import Path

import sentencepiece as spm
import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from sovyn import SovynConfig, SovynForCausalLM  # noqa: E402  (needs the sys.path entry above)
from sovyn.formatting import format_prompt  # noqa: E402

# NOTE(review): every special-token string literal in this file ("") is empty
# in this copy. They look like chat-template / control-token markers whose text
# was lost. Confirm and fill in the real pieces. Because str.split("") raises
# ValueError, empty separators are skipped below instead of crashing the chat.

# One sentence: the shortest run of characters ending in ".", "!", "?" or the
# CJK full stop "。", followed by whitespace or the end of the string.
_SENTENCE_RE = re.compile(r".+?[.!?。](?=\s|$)")

# Content-free openers. When the first sentence is one of these and more
# follow, the second sentence is kept too so the reply says something.
_GENERIC_OPENERS = frozenset(
    {
        "좋아.",
        "응.",
        "알겠어.",
        "가능해.",
        "그럴 수 있어.",
        "불안할 수 있어.",
        "많이 지쳤겠다.",
    }
)

# (prompt keywords, answer keywords) pairs. Only the FIRST group with a prompt
# keyword present in the lower-cased user text is scored; each answer keyword
# found in the answer adds 2 points.
_TOPIC_GROUPS = (
    (["피곤", "지쳤", "힘들", "기운", "복잡"], ["지쳤", "쉬", "힘들", "숨", "들어줄", "괜찮"]),
    (["누구", "정체", "sovyn"], ["sovyn", "모델", "ai", "학습", "기억"]),
    (["짧", "간단", "핵심", "요약"], ["핵심", "짧", "결론", "간단"]),
    (
        ["1b", "120m", "300m", "모델", "학습", "키울", "저장공간", "저장 공간", "체크포인트", "부족"],
        ["모델", "학습", "120m", "300m", "1b", "키우", "실험", "체크포인트", "저장", "공간"],
    ),
    (["먹", "점심", "저녁", "음식"], ["김밥", "국밥", "덮밥", "샌드위치", "먹"]),
    (["공부", "집중", "미루", "루틴"], ["10분", "타이머", "문제", "시작", "공부"]),
    (["불안", "걱정", "초조"], ["불안", "걱정", "숨", "괜찮", "나눠"]),
)

# Bare one-word/stock replies that are penalised during best-of ranking.
_FILLER_ANSWERS = frozenset({"안녕.", "좋아.", "응.", "좋은 생각이야.", "좋은 아침이야."})


def clean_answer(text: str) -> str:
    """Extract a short reply from the raw decoded model output.

    The text after the last occurrence of the leading marker is kept. A second
    marker is deleted, and the text is cut at the first of the stop markers.
    The result is then trimmed to its first sentence, or to its first two
    sentences when the first is a generic opener from ``_GENERIC_OPENERS``.
    If no sentence terminator is found, the cut text is returned whole.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned, stripped reply (possibly empty).
    """
    prefix_marker = ""  # NOTE(review): empty in this copy, see module note
    answer = text.split(prefix_marker)[-1] if prefix_marker else text
    answer = answer.replace("", "").strip()
    for marker in ["", "", "", "", "", ""]:
        if marker:  # str.split("") raises ValueError; skip unset markers
            answer = answer.split(marker)[0].strip()

    sentences = _SENTENCE_RE.findall(answer)
    if not sentences:
        return answer
    first = sentences[0].strip()
    if first in _GENERIC_OPENERS and len(sentences) > 1:
        return " ".join(item.strip() for item in sentences[:2])
    return first


def score_answer(user: str, answer: str) -> int:
    """Heuristically score *answer* as a reply to *user*; higher is better.

    Scoring:

    - +2 for each answer keyword present, using the first matching topic group.
    - -3 for a bare filler reply.
    - -2 for replies shorter than 4 characters.
    - +1 for replies of 8 to 80 characters.

    Args:
        user: The user's prompt text.
        answer: A candidate reply, typically from :func:`clean_answer`.

    Returns:
        The integer score.
    """
    prompt_lower = user.lower()
    answer_lower = answer.lower()  # computed once, not once per keyword
    score = 0
    for user_words, answer_words in _TOPIC_GROUPS:
        if any(word in prompt_lower for word in user_words):
            score += 2 * sum(1 for word in answer_words if word.lower() in answer_lower)
            break
    if answer in _FILLER_ANSWERS:
        score -= 3
    if len(answer) < 4:
        score -= 2
    if 8 <= len(answer) <= 80:
        score += 1
    return score


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the chat loop."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument(
        "--best-of",
        type=int,
        default=4,
        help="candidates generated and ranked per turn (only when temperature > 0)",
    )
    return parser.parse_args()


def _load_model(
    checkpoint: str, sp: spm.SentencePieceProcessor, device: str
) -> SovynForCausalLM:
    """Build the model from *checkpoint* and move it to *device* in eval mode.

    Raises:
        ValueError: if the checkpoint's vocab_size differs from the tokenizer's
            piece count.
    """
    # SECURITY: torch.load unpickles the file. On torch versions where
    # weights_only defaults to False this can execute arbitrary code, so only
    # load checkpoints you trust.
    ckpt = torch.load(checkpoint, map_location="cpu")
    model_cfg = ckpt["config"]["model"]
    if model_cfg["vocab_size"] != sp.get_piece_size():
        raise ValueError(
            f"Checkpoint vocab_size={model_cfg['vocab_size']} but tokenizer has "
            f"{sp.get_piece_size()} pieces"
        )
    model = SovynForCausalLM(SovynConfig(**model_cfg))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()
    # ckpt (a CPU copy of every weight) goes out of scope here instead of
    # living for the whole chat session.
    return model


def _special_token_ids(sp: spm.SentencePieceProcessor):
    """Return ``(eos_id, stop_ids, suppress_ids)`` for generation.

    Pieces for which ``piece_to_id`` returns a negative id are dropped from the
    stop and suppress lists.
    """
    eos_id = sp.piece_to_id("")
    stop_ids = [
        idx
        for idx in (sp.piece_to_id(piece) for piece in ["", "", "", "", "", ""])
        if idx >= 0
    ]
    suppress_ids = [
        idx
        for idx in [sp.piece_to_id(""), sp.piece_to_id(""), sp.piece_to_id("")]
        if idx >= 0
    ]
    return eos_id, stop_ids, suppress_ids


def main():
    """Run an interactive terminal chat against a SOVYN checkpoint.

    Each turn generates one reply (greedy when ``--temperature`` is 0) or
    ``--best-of`` sampled replies. The reply with the highest
    :func:`score_answer` is printed. Input ``exit``/``quit``, end-of-file, or
    Ctrl-C at the prompt ends the session.
    """
    args = _parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available; falling back to CPU.", file=sys.stderr)
        device = "cpu"

    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    model = _load_model(args.checkpoint, sp, device)
    eos_id, stop_ids, suppress_ids = _special_token_ids(sp)

    # Sampling several candidates only makes sense when decoding is stochastic.
    runs = max(1, args.best_of) if args.temperature > 0 else 1

    print("SOVYN chat. Type 'exit' to quit.")
    while True:
        try:
            user = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if user.lower() in {"exit", "quit"}:
            break
        if not user:
            continue

        prompt = format_prompt(user)
        ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long, device=device)

        candidates = []
        with torch.no_grad():
            for _ in range(runs):
                out = model.generate(
                    ids,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    eos_id=eos_id,
                    stop_ids=stop_ids,
                    suppress_ids=suppress_ids,
                )
                candidates.append(clean_answer(sp.decode(out[0].tolist())))

        answer = max(candidates, key=lambda item: score_answer(user, item))
        print(f"sovyn> {answer}")


if __name__ == "__main__":
    main()