---
license: cc-by-nc-nd-4.0
base_model: skt/kogpt2-base-v2
tags:
- gpt2
- lora
- korean
- chatbot
language:
- ko
---
# 모델 이름

jihun-pae/kogpt2-chatbot-lora
## 모델 설명

- LoRA 한국어 챗봇
## 모델 상세

- 교육용 실습 모델입니다.
## LoRA 설정
# LoRA adapter configuration for the skt/kogpt2-base-v2 base model.
# NOTE(review): LoraConfig and TaskType come from the `peft` library —
# presumably imported earlier in the original script; confirm.
lora_config = LoraConfig(
    r=16,                # rank of the low-rank update matrices
    lora_alpha=32,       # scaling factor applied to the LoRA update
    # c_attn / c_proj / c_fc: GPT-2-style attention and MLP projection
    # layer names — assumes the kogpt2 checkpoint keeps GPT-2 naming.
    target_modules=["c_attn", "c_proj", "c_fc"],
    lora_dropout=0.05,   # dropout applied to the LoRA branch during training
    bias="none",         # bias parameters stay frozen (not adapted)
    task_type=TaskType.CAUSAL_LM
)
## 학습 설정
# Trainer hyperparameters for the LoRA fine-tuning run.
# NOTE(review): TrainingArguments comes from `transformers` — presumably
# imported earlier in the original script; confirm.
training_args = TrainingArguments(
    # NOTE(review): "koqpt2" looks like a typo for "kogpt2" — confirm the
    # intended output path before renaming (it is a runtime artifact dir).
    output_dir="./lora_koqpt2_chatbot",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # effective train batch per device: 4 * 4 = 16
    learning_rate=2e-4,
    warmup_steps=100,
    logging_steps=50,
    # Evaluate and checkpoint once per epoch. The original also set
    # eval_steps=100 / save_steps=100, but those are ignored whenever the
    # corresponding strategy is "epoch", so the dead settings are removed.
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    fp16=True,
    report_to="none",       # disable wandb/tensorboard reporting
    weight_decay=0.01,
)
## 학습 결과
## 사용 방법
# Quick smoke test of the chatbot: print a few sample questions.
# (Original Korean text was mojibake-garbled and string literals were split
# across lines, producing syntax errors; reconstructed here.)
test_questions = [
    "안녕하세요?",
    "오늘 날씨가 어때?",
    "배고픈데 뭐 먹을까?",
    "주말에 뭐하지?"
]

print("=== 챗봇 테스트 ===")
for q in test_questions:
    # NOTE(review): the original snippet prints a fixed sample question and
    # answer instead of using `q` — presumably pasted from example output of
    # the model card; confirm whether `{q}` and a model call were intended.
    print(f"\n질문: 강원도 설악산 근처 맛집 좀 추천해주세요.")
    print(f"답변: 여행가봐도 좋은거 같아요.")