---
base_model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3
license:
  - llama3.1
  - gemma
language:
  - ja
  - en
pipeline_tag: text-generation
tags:
  - counseling
  - dialogue-system
datasets:
  - UEC-InabaLab/KokoroChat
---

# 🧠 KokoroChat-High (LoRA Adapter for Japanese Counseling Dialogue)

This repository contains the LoRA adapter weights for **KokoroChat-High**, a version of the KokoroChat model fine-tuned on high-feedback counseling dialogues (client feedback scores between 70 and 98) from the KokoroChat dataset.

The base model is `tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3`, and this adapter enhances it for generating high-quality, empathetic Japanese counseling responses.


## 💡 What is "KokoroChat-High"?

- ✅ Trained on 2,601 dialogues
- ✅ All sessions have client feedback scores between 70 and 98 (see the selection sketch below)
- ✅ Represents high-quality, successful counseling interactions
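
For illustration, here is a minimal sketch of how such a subset might be selected with the `datasets` library. It assumes the dataset exposes a per-dialogue client feedback score; the split name and the `client_score` field are hypothetical placeholders, not the dataset's confirmed schema:

```python
from datasets import load_dataset

# Hypothetical sketch: the actual split and column names in
# UEC-InabaLab/KokoroChat may differ from what is shown here.
ds = load_dataset("UEC-InabaLab/KokoroChat", split="train")

# Keep only dialogues whose client feedback score falls in [70, 98].
high_feedback = ds.filter(lambda ex: 70 <= ex["client_score"] <= 98)
print(len(high_feedback))  # should be on the order of 2,601 dialogues
```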

## 🧾 Model Details

| Item | Value |
|---|---|
| Base model | `tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3` |
| Fine-tuning method | LoRA (this repository holds adapter weights only) |
| Languages | Japanese, English |
| Training data | High-feedback subset of `UEC-InabaLab/KokoroChat` (2,601 dialogues, scores 70–98) |
| License | llama3.1, gemma |

## ⚙️ Usage Instructions (LoRA Adapter)

This repository contains only the adapter weights.
You must load the original base model and then apply this adapter on top of it.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# === Base + Adapter Paths ===
base_model_id = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
adapter_id = "your-username/kokorochat-lora"  # replace with this adapter's repo id

# === Load Tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# === Load Base Model (4-bit quantized) ===
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# === Load & Merge LoRA ===
model = PeftModel.from_pretrained(base_model, adapter_id)
model = model.merge_and_unload()
```
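
Note that merging LoRA weights into a 4-bit quantized base model may not be supported depending on your `peft` version. If `merge_and_unload()` raises an error in your environment, you can skip the merge and run inference directly through the `PeftModel` wrapper, or load the base model without quantization before merging.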

## 🧪 Example Inference

```python
messages = [
    # System: "In a psychological counseling conversation, consider the dialogue
    # history and respond appropriately as the counselor."
    {"role": "system", "content": "心理カウンセリングの会話において、対話履歴を考慮し、カウンセラーとして適切に応答してください。"},
    # User: "Lately things haven't been going well with my family, and I've been feeling down."
    {"role": "user", "content": "最近、家族との関係がうまくいかず、気持ちが落ち込んでいます。"}
]

input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(
    input_ids,
    max_new_tokens=512,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id
)

# Decode only the newly generated tokens (the counselor's reply).
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
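
Counseling sessions usually span multiple turns. Here is a minimal sketch of continuing the dialogue by appending the model's reply to the history before the next user turn; the follow-up user message is purely illustrative:

```python
# Continue the session: append the model's reply and the next user turn.
reply = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
messages.append({"role": "assistant", "content": reply})
# User: "Thank you for listening to me." (illustrative follow-up)
messages.append({"role": "user", "content": "話を聞いてくれてありがとうございます。"})

input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(input_ids, max_new_tokens=512, do_sample=False)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```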

## 🔗 Related