"""
Fine-tuning script for Reply Generation Model
This script fine-tunes a language model to generate conversational replies from:
1. Conversation context (user_text + partner_text)
2. Trigger (identified from the conversation)
3. Move (deduced from the trigger)
The training target is the next appropriate reply.
Usage:
python finetune_model.py --data_path new_data_selected.csv --output_dir ./finetuned_reply_model
"""
import argparse
import os
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
EncoderDecoderModel,
TrainingArguments,
Trainer,
BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
def get_active_labels(row, prefix_cols):
"""Get active (value=1) labels from a row"""
active = []
for col in prefix_cols:
if row[col] == 1:
# Remove prefix (e.g., "trigger_rapport_bid" -> "rapport_bid")
label = col.replace("trigger_", "").replace("move_", "")
active.append(label)
return active if active else ["none"]
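# Example (illustrative): for a row with trigger_rapport_bid=1 and trigger_tease=0,
# get_active_labels(row, ["trigger_rapport_bid", "trigger_tease"]) returns
# ["rapport_bid"]; if no column in the prefix set equals 1, it returns ["none"].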
def build_instruction(conversation: str, trigger: str, move: str, persona: str) -> str:
base_lines = [
"Given this conversation between Male and Female, identify the trigger and suggest the appropriate move to continue the conversation naturally.",
"",
f"Conversation: {conversation}",
f"Trigger: {trigger}",
f"Move: {move}",
"",
]
if persona == "wingman":
base_lines.append(
"Persona: You are a confident Vietnamese wingman speaking on behalf of Male. "
"Craft a short (<35 words), playful, and respectful reply from Male's perspective using 'anh' for self and 'em' for partner. "
"Blend charm with the specified move while keeping it natural."
)
else:
base_lines.append(
"Generate the next appropriate response from Male to Female. The reply should be from Male's perspective, responding to Female's message. "
"Male should use \"anh\" (I) and \"em\" (you)."
)
base_lines.append("")
base_lines.append("Reply:")
return "\n".join(base_lines)
def prepare_training_data(df, use_history=True, persona="default"):
"""
Prepare data for fine-tuning.
    If the dataset already has a `male_reply` column (built by build_reply_dataset.py), use
    conversation, trigger, move, male_reply as the clean ground truth for the reply from
    the Male side.
    Otherwise, fall back to the legacy logic based on user_text / partner_text (less ideal).
"""
training_data = []
conversation_history = []
has_clean_reply = {"conversation", "trigger", "move", "male_reply"}.issubset(set(df.columns))
if has_clean_reply:
for _, row in df.iterrows():
conversation = str(row.get("conversation", "") or "")
trigger = str(row.get("trigger", "") or "neutral")
move = str(row.get("move", "") or "neutral")
reply = str(row.get("male_reply", "") or "").strip()
if not conversation or not reply:
continue
prompt = build_instruction(conversation, trigger, move, persona)
training_data.append(
{
"instruction": prompt,
"input": "",
"output": reply,
}
)
return training_data
    # Fallback: use the raw columns (less ideal)
trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
move_cols = [col for col in df.columns if col.startswith("move_")]
for _, row in df.iterrows():
user_text = str(row["user_text"]) if pd.notna(row.get("user_text")) else ""
partner_text = str(row["partner_text"]) if pd.notna(row.get("partner_text")) else ""
if not partner_text or partner_text.strip() == "_":
continue
active_triggers = get_active_labels(row, trigger_cols)
active_moves = get_active_labels(row, move_cols)
trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
move = active_moves[0] if active_moves[0] != "none" else "neutral"
if use_history and conversation_history:
history_str = "\n".join(conversation_history)
if user_text and user_text.strip() != "_":
current_turn = f"Male: {user_text}"
conversation = f"{history_str}\n{current_turn}"
else:
conversation = history_str
        else:
            # partner_text is the training target, so keep it out of the prompt
            if user_text and user_text.strip() != "_":
                conversation = f"Male: {user_text}"
            else:
                conversation = ""
        if conversation:
            prompt = build_instruction(conversation, trigger, move, persona)
            # Legacy target: the partner's (Female's) next reply for this context
            response = partner_text.strip()
            training_data.append(
                {
                    "instruction": prompt,
                    "input": "",
                    "output": response,
                }
            )
if user_text and user_text.strip() != "_":
conversation_history.append(f"Male: {user_text}")
if partner_text and partner_text.strip() != "_":
conversation_history.append(f"Female: {partner_text}")
max_history = 4
if len(conversation_history) > max_history:
conversation_history = conversation_history[-max_history:]
return training_data
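# Example record produced by prepare_training_data (illustrative placeholders):
#   {"instruction": "<prompt ending in 'Reply:'>", "input": "", "output": "<male reply text>"}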
def format_prompt(example, tokenizer):
"""Format the prompt for training"""
instruction = example["instruction"]
output = example["output"]
text = f"{instruction}\n{output}{tokenizer.eos_token}"
return {"text": text}
def tokenize_function(examples, tokenizer):
    """Tokenize the examples and mask padding out of the loss"""
    texts = examples["text"]
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt"
    )
    labels = tokenized["input_ids"].clone()
    # Use -100 on padding positions so the loss ignores pad tokens
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    return tokenized
def main():
parser = argparse.ArgumentParser(description="Fine-tune model for reply generation")
parser.add_argument(
"--data_path",
type=str,
default="new_data_selected.csv",
help="Path to training data CSV file"
)
parser.add_argument(
"--output_dir",
type=str,
default="./finetuned_reply_model",
help="Output directory for fine-tuned model"
)
parser.add_argument(
"--model_name",
type=str,
default="vinai/PhoGPT-4B-Chat",
help="Base model name for fine-tuning"
)
parser.add_argument(
"--num_epochs",
type=int,
default=1, # Reduced default for faster training on Spaces
help="Number of training epochs"
)
parser.add_argument(
"--batch_size",
type=int,
default=2,
help="Training batch size"
)
parser.add_argument(
"--learning_rate",
type=float,
default=2e-4,
help="Learning rate"
)
parser.add_argument(
"--use_history",
action="store_true",
help="Use conversation history in training"
)
parser.add_argument(
"--persona",
type=str,
default="default",
choices=["default", "wingman"],
help="Persona/instruction style for generation"
)
parser.add_argument(
"--model_arch",
type=str,
default="causal",
choices=["causal", "encoder_decoder"],
help="Model architecture type"
)
args = parser.parse_args()
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load dataset
print(f"Loading dataset from {args.data_path}...")
df = pd.read_csv(args.data_path)
print(f"Dataset shape: {df.shape}")
# Prepare training data
print("Preparing training data...")
train_data = prepare_training_data(df, use_history=args.use_history, persona=args.persona)
print(f"Total training examples: {len(train_data)}")
# Convert to HuggingFace Dataset
dataset = Dataset.from_list(train_data)
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
print(f"Train examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")
# Load model and tokenizer
print(f"Loading model: {args.model_name} ({args.model_arch})")
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token or tokenizer.cls_token
# Try to configure quantization, fallback if triton not available
use_quantization = False
quant_config = None
if args.model_arch == "causal":
try:
import bitsandbytes as bnb
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
use_quantization = True
print("4-bit quantization enabled")
    except ImportError as e:  # ModuleNotFoundError is a subclass of ImportError
print(f"Warning: BitsAndBytesConfig not available ({e}), loading model without quantization...")
model = None
last_error = None
def load_base_model(use_quant: bool):
if args.model_arch == "encoder_decoder":
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
args.model_name,
args.model_name,
tie_encoder_decoder=True,
)
            # bos_token_id always exists as an attribute (possibly None), so getattr()
            # would never fall back; check for None explicitly instead
            model.config.decoder_start_token_id = (
                tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
            )
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.encoder.vocab_size
return model
else:
kwargs = dict(
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
if use_quant and quant_config is not None:
kwargs["quantization_config"] = quant_config
return AutoModelForCausalLM.from_pretrained(
args.model_name,
**kwargs,
)
if args.model_arch == "causal":
if use_quantization:
try:
model = load_base_model(use_quant=True)
print("Model loaded with 4-bit quantization")
except Exception as e:
last_error = e
print(f"Failed to load with quantization: {e}")
model = None
if model is None:
try:
model = load_base_model(use_quant=False)
print("Model loaded without quantization (may use more memory)")
except Exception as e:
if last_error:
print(f"Original error: {last_error}")
                raise RuntimeError(f"Failed to load model: {e}") from e
else:
model = load_base_model(use_quant=False)
print("Encoder-decoder model loaded successfully!")
# Format and tokenize datasets
print("Formatting datasets...")
train_dataset_formatted = train_dataset.map(
lambda x: format_prompt(x, tokenizer),
remove_columns=train_dataset.column_names
)
val_dataset_formatted = val_dataset.map(
lambda x: format_prompt(x, tokenizer),
remove_columns=val_dataset.column_names
)
print("Tokenizing datasets...")
train_dataset_tokenized = train_dataset_formatted.map(
lambda x: tokenize_function(x, tokenizer),
batched=True,
remove_columns=train_dataset_formatted.column_names
)
val_dataset_tokenized = val_dataset_formatted.map(
lambda x: tokenize_function(x, tokenizer),
batched=True,
remove_columns=val_dataset_formatted.column_names
)
# Configure LoRA
if args.model_arch == "causal":
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
else:
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
lora_dropout=0.05,
bias="none",
task_type="SEQ_2_SEQ_LM",
)
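    # Note: these target_modules names assume LLaMA-style attention/MLP naming for the
    # causal branch and BERT/BART-style naming for the encoder-decoder branch. Models
    # with fused attention (e.g., MPT-style Wqkv) use different module names, so adjust
    # this list if get_peft_model cannot find the listed modules.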
# Prepare model for training (only if using quantization)
if use_quantization:
try:
model = prepare_model_for_kbit_training(model)
except Exception as e:
print(f"Warning: prepare_model_for_kbit_training failed: {e}, continuing anyway...")
model = get_peft_model(model, lora_config)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")
# Training arguments
training_common_kwargs = dict(
output_dir=args.output_dir,
num_train_epochs=args.num_epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=4,
warmup_steps=100,
learning_rate=args.learning_rate,
fp16=True,
logging_steps=10,
eval_steps=100,
save_steps=500,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="none",
remove_unused_columns=False,
)
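    # Effective train batch size per device = batch_size * gradient_accumulation_steps
    # (with the defaults: 2 * 4 = 8)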
try:
training_args = TrainingArguments(
evaluation_strategy="steps",
**training_common_kwargs,
)
except TypeError:
training_args = TrainingArguments(
eval_strategy="steps",
**training_common_kwargs,
)
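    # "evaluation_strategy" was renamed to "eval_strategy" in newer transformers
    # releases; the try/except above keeps the script compatible with both.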
# Create Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset_tokenized,
eval_dataset=val_dataset_tokenized,
tokenizer=tokenizer,
)
print("Starting training...")
trainer.train()
print("Training completed!")
# Save model
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
print(f"Model saved to {args.output_dir}/")
if __name__ == "__main__":
main()