| | import gradio as gr |
| | import os |
| | from datasets import load_dataset |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling |
| | import torch |
| | import subprocess |
| |
|
def finetune(model_name, hf_token, upload_repo):
    """Fine-tune a causal language model and push the result to the Hub.

    Loads the tokenizer and model named by ``model_name``, tokenizes the
    ``train`` split of ``rinna/llm-japanese-dataset-v1`` (truncated to 512
    tokens), trains for one epoch with ``Trainer`` and uploads the result
    to ``upload_repo``.

    Args:
        model_name: Hub id of the base model to fine-tune.
        hf_token: Hugging Face access token used for downloads and upload.
        upload_repo: Hub repository id that receives the fine-tuned model.

    Returns:
        A user-facing status string: an error message starting with
        ``エラー`` when an input is blank or the dataset has no ``text``
        column, otherwise the completion message with the repository URL.
    """
    # Values arrive straight from UI text boxes: normalise pasted whitespace
    # and reject blanks before any download is attempted.
    model_name = (model_name or "").strip()
    hf_token = (hf_token or "").strip()
    upload_repo = (upload_repo or "").strip()
    if not model_name:
        return "エラー: 元モデル名を入力してください。"
    if not hf_token:
        return "エラー: Hugging Face トークンを入力してください。"
    if not upload_repo:
        return "エラー: アップロード先リポジトリ名を入力してください。"

    # NOTE(review): this is process-wide and stays set after the call returns,
    # so later requests in the same process see this token.
    os.environ["HF_TOKEN"] = hf_token

    # NOTE(review): recent transformers releases deprecate `use_auth_token`
    # in favour of `token` -- confirm against the installed version.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=hf_token)

    # The collator pads every batch, which needs a pad token; fall back to
    # the EOS token when the tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("rinna/llm-japanese-dataset-v1", split="train")
    if "text" not in dataset.column_names:
        # Fail with a readable message instead of a KeyError inside map().
        return "エラー: データセットに text 列がありません。"

    def tokenize_fn(example):
        # Called with batches (batched=True); truncates to 512 tokens.
        return tokenizer(example["text"], truncation=True, max_length=512)

    # Drop the raw columns so only the tokenizer outputs reach the collator.
    tokenized_dataset = dataset.map(
        tokenize_fn,
        batched=True,
        remove_columns=dataset.column_names,
    )

    # mlm=False selects causal (next-token) language modelling.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./finetuned_model",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_total_limit=1,
        logging_steps=10,
        push_to_hub=True,
        hub_model_id=upload_repo,
        hub_token=hf_token,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.push_to_hub()

    return f"ファインチューニング完了!モデルは https://huggingface.co/{upload_repo} にアップロードされました。"
| |
|
| | |
# Gradio UI: three text inputs, a start button and a result box wired to
# finetune().
with gr.Blocks() as demo:
    gr.Markdown("# 日本語チャットモデル 簡易ファインチューニング")

    model_name = gr.Textbox(label="元モデル名(例:rinna/japanese-gpt-neox-3.6b)")
    # type="password" masks the token in the browser.
    hf_token = gr.Textbox(label="Hugging Face トークン", type="password")
    upload_repo = gr.Textbox(label="アップロード先リポジトリ名(例:yourname/finetuned-chat-jp)")

    start_btn = gr.Button("ファインチューニング開始")
    output = gr.Textbox(label="実行結果")

    # The string returned by finetune() is shown in the result box.
    start_btn.click(
        finetune,
        inputs=[model_name, hf_token, upload_repo],
        outputs=output,
    )

# Launch only when run as a script, so importing this module has no side
# effect beyond building `demo`. The queue is enabled because a training run
# far outlasts a normal request.
if __name__ == "__main__":
    demo.queue().launch()