# llama-server / model / llama3.py
# (HuggingFace page residue from "bigeco" — commit "Update model/llama3.py", feaeb6d verified)
import yaml
import argparse
from openai import OpenAI
# TODO 1: ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์กฐ์ • ํ•„์š”. (์‹คํ—˜ ํ…Œ์ŠคํŠธ ํ•„์š”)
# TODO 2: ๊ฒฐ๊ณผ๊ฐ€ ()->() ๋ฒ—์–ด๋‚  ๊ฒฝ์šฐ๋ฅผ ๋Œ€๋น„ํ•ด ์ฒ˜๋ฆฌ ์ฝ”๋“œ ์ž‘์„ฑ ํ•„์š”.
# TODO 3: ๊ฒฐ๊ณผ ์—์„œ ์ž๋ชจ ๊ธฐ์ค€์œผ๋กœ ์ž˜๋ชป ๋ฐœ์Œํ•œ ๋ถ€๋ถ„ ์ถ”์ถœ ์ฝ”๋“œ ์ž‘์„ฑ ํ•„์š”.
# TODO 4: ๋‹ค์ค‘ ์ž…๋ ฅ ์ฒ˜๋ฆฌ(batch) ์ฒ˜๋ฆฌ ๊ฐ€๋Šฅํ•˜๋„๋ก. ํŒŒ์ผ์ด๋‚˜ ๋ฆฌ์ŠคํŠธ๋กœ ์—ฌ๋Ÿฌ user_input/correct_input ๋ฐ›์•„ ์ผ๊ด„ ์ฒ˜๋ฆฌ
# TODO 5: ์ƒ์„ฑ ๊ฒฐ๊ณผ ํ‰๊ฐ€ ์ง€ํ‘œ ํ•„์š”.
import yaml
import argparse
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
# TODO 1: ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์กฐ์ • ํ•„์š”. (์‹คํ—˜ ํ…Œ์ŠคํŠธ ํ•„์š”)
# TODO 2: ๊ฒฐ๊ณผ๊ฐ€ () -> () ๋’ท์–ด๋‚  ๊ฒฝ์šฐ๋ฅผ ๋Œ€๋น„ํ•ด ์ฒ˜๋ฆฌ ์ฝ”๋“œ ์ž‘์„ฑ ํ•„์š”.
# TODO 3: ๊ฒฐ๊ณผ ์—์„œ ์ž๋ชจ ๊ธฐ์ค€์œผ๋กœ ์ž ๋ชป ๋ฐœ์Œ๋œ ๋ถ€๋ถ„ ์ถ”์ถœ ์ฝ”๋“œ ์ž‘์„ฑ ํ•„์š”.
# TODO 4: ๋‹ค์ค‘ ์ž…๋ ฅ ์ฒ˜๋ฆฌ(batch) ์ฒ˜๋ฆฌ ๊ฐ€๋Šฅํ•˜๋„๋ก. ํŒŒ์ผ์ด๋‚˜ ๋ฆฌ์ŠคํŠธ๋กœ ์—ฌ๋Ÿฌ user_input/correct_input ๋ฐ›์•„ ์ผ๊ด„
# TODO 5: ์„ฑ์žฅ ๊ฒฐ๊ณผ ํ‰๊ฐ€ ์ง€ํ‘œ ํ•„์š”.
class LLaMA3:
    """Romanize Korean pronunciation pairs.

    Despite the name, this class does not call an LLM: it converts a
    (user_input, correct_input) pair of Korean strings into syllable-by-
    syllable hyphenated romanizations and formats them as
    "(user)->(correct)".
    """

    def __init__(self, config: dict):
        """Initialize from a parsed YAML config mapping.

        Args:
            config: Parsed configuration mapping (the code reads it via
                ``.get``, so it must be a dict-like object, not a string).
        """
        # hangul-romanize transliterator using the academic rule set.
        self.transliter = Transliter(academic)
        # Prompt template (unused here, kept to preserve the existing config schema).
        self.prompt_template = config.get("prompt_template", "")
        # Model id (unused here, kept to preserve the existing config schema).
        self.model = config.get("model", {}).get("id", "")

    def add_hyphens(self, korean_text: str) -> str:
        """Romanize each syllable of *korean_text* and join with hyphens."""
        # Iterating a str yields one syllable (character) at a time.
        return '-'.join(self.transliter.translit(ch) for ch in korean_text)

    def generate(self, user_input: str, correct_input: str) -> str:
        """Return "(user_romanized)->(correct_romanized)" for the given pair.

        Surrounding parentheses on either input, if present, are stripped
        before romanization.
        """
        user_romanized = self.add_hyphens(user_input.strip('()'))
        correct_romanized = self.add_hyphens(correct_input.strip('()'))
        return f"({user_romanized})->({correct_romanized})"
# def parse_args() -> argparse.Namespace:
# parser = argparse.ArgumentParser(description="LLaMA3 pronunciation correction pipeline.")
# parser.add_argument("--config_path", type=str, default="data/config/llama3.yaml", help="๋ชจ๋ธ ์„ค์ • ๋ฐ ํ”„๋กฌํ”„ํŠธ ์ •๋ณด๋ฅผ ๋‹ด์€ YAML ํŒŒ์ผ ๊ฒฝ๋กœ")
# parser.add_argument("--user_input", type=str, default="๋ฐ•๋ผ", help="์ž˜๋ชป ๋ฐœ์Œ๋œ ๋‹จ์–ด")
# parser.add_argument("--correct_input", type=str, default="๋ฐœ๋ผ", help="์ •ํ™•ํ•œ ๋ฐœ์Œ ๋‹จ์–ด")
# return parser.parse_args()
# def main():
# args = parse_args()
# # ์„ค์ • ํŒŒ์ผ ๋กœ๋“œ
# with open(args.config_path, "r") as f:
# config = yaml.safe_load(f)
# # ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜
# llama3 = LLaMA3(config)
# output = llama3.generate(args.user_input, args.correct_input)
# print(output)
# if __name__ == "__main__":
# main()