| """ | |
| Fine-tuning script for Reply Generation Model | |
| This script fine-tunes a language model to generate conversational replies based on: | |
| 1. Conversation context (user_text + partner_text) | |
| 2. Trigger (identified from conversation) | |
| 3. Move (deduced from trigger) | |
| 4. Output: Next appropriate response | |
| Usage: | |
| python finetune_model.py --data_path new_data_selected.csv --output_dir ./finetuned_reply_model | |
| """ | |

import argparse
import os

import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EncoderDecoderModel,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def get_active_labels(row, prefix_cols):
    """Get active (value == 1) labels from a row."""
    active = []
    for col in prefix_cols:
        if row[col] == 1:
            # Strip the prefix (e.g., "trigger_rapport_bid" -> "rapport_bid").
            label = col.replace("trigger_", "").replace("move_", "")
            active.append(label)
    return active if active else ["none"]
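
# Example usage (hypothetical row and column names):
#   get_active_labels({"trigger_rapport_bid": 1, "trigger_tease": 0},
#                     ["trigger_rapport_bid", "trigger_tease"])
#   -> ["rapport_bid"]; an all-zero row yields ["none"].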


def build_instruction(conversation: str, trigger: str, move: str, persona: str) -> str:
    base_lines = [
        "Given this conversation between Male and Female, identify the trigger and suggest the appropriate move to continue the conversation naturally.",
        "",
        f"Conversation: {conversation}",
        f"Trigger: {trigger}",
        f"Move: {move}",
        "",
    ]
    if persona == "wingman":
        base_lines.append(
            "Persona: You are a confident Vietnamese wingman speaking on behalf of Male. "
            "Craft a short (<35 words), playful, and respectful reply from Male's perspective using 'anh' for self and 'em' for partner. "
            "Blend charm with the specified move while keeping it natural."
        )
    else:
        base_lines.append(
            "Generate the next appropriate response from Male to Female. The reply should be from Male's perspective, responding to Female's message. "
            "Male should use \"anh\" (I) and \"em\" (you)."
        )
    base_lines.append("")
    base_lines.append("Reply:")
    return "\n".join(base_lines)


def prepare_training_data(df, use_history=True, persona="default"):
    """
    Prepare data for fine-tuning.

    If the dataset already has a `male_reply` column (built by
    build_reply_dataset.py), use `conversation`, `trigger`, `move`, and
    `male_reply` as the clean ground truth for the reply from the Male side.
    Otherwise, fall back to the older logic based on user_text /
    partner_text (less ideal).
    """
    training_data = []
    conversation_history = []
    has_clean_reply = {"conversation", "trigger", "move", "male_reply"}.issubset(set(df.columns))
    if has_clean_reply:
        for _, row in df.iterrows():
            conversation = str(row.get("conversation", "") or "")
            trigger = str(row.get("trigger", "") or "neutral")
            move = str(row.get("move", "") or "neutral")
            reply = str(row.get("male_reply", "") or "").strip()
            if not conversation or not reply:
                continue
            prompt = build_instruction(conversation, trigger, move, persona)
            training_data.append(
                {
                    "instruction": prompt,
                    "input": "",
                    "output": reply,
                }
            )
        return training_data
    # Fallback: use the raw data (less ideal).
    trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
    move_cols = [col for col in df.columns if col.startswith("move_")]
    for _, row in df.iterrows():
        user_text = str(row["user_text"]) if pd.notna(row.get("user_text")) else ""
        partner_text = str(row["partner_text"]) if pd.notna(row.get("partner_text")) else ""
        if not partner_text or partner_text.strip() == "_":
            continue
        active_triggers = get_active_labels(row, trigger_cols)
        active_moves = get_active_labels(row, move_cols)
        trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
        move = active_moves[0] if active_moves[0] != "none" else "neutral"
        if use_history and conversation_history:
            history_str = "\n".join(conversation_history)
            if user_text and user_text.strip() != "_":
                current_turn = f"Male: {user_text}"
                conversation = f"{history_str}\n{current_turn}"
            else:
                conversation = history_str
        else:
            if user_text and user_text.strip() != "_":
                conversation = f"Male: {user_text} ||| Female: {partner_text}"
            else:
                conversation = f"Female: {partner_text}"
        prompt = build_instruction(conversation, trigger, move, persona)
        response = partner_text.strip()
        training_data.append(
            {
                "instruction": prompt,
                "input": "",
                "output": response,
            }
        )
        if user_text and user_text.strip() != "_":
            conversation_history.append(f"Male: {user_text}")
        if partner_text and partner_text.strip() != "_":
            conversation_history.append(f"Female: {partner_text}")
        max_history = 4
        if len(conversation_history) > max_history:
            conversation_history = conversation_history[-max_history:]
    return training_data
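
# Each produced record has the shape (values illustrative):
#   {"instruction": "<prompt ending in 'Reply:'>", "input": "",
#    "output": "<male_reply, or the Female turn in the fallback path>"}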


def format_prompt(example, tokenizer):
    """Format the prompt for training."""
    instruction = example["instruction"]
    output = example["output"]
    text = f"{instruction}\n{output}{tokenizer.eos_token}"
    return {"text": text}


def tokenize_function(examples, tokenizer):
    """Tokenize the examples."""
    texts = examples["text"]
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt",
    )
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized
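

# Note: with labels = input_ids, the loss above is computed over the
# instruction and padding tokens as well as the reply (and pad == eos here,
# so padding is not distinguishable from the terminator). A common refinement,
# sketched below but not wired in, is to mask everything up to the "Reply:"
# marker with -100 so only reply tokens contribute to the loss. Prompt token
# counts are approximate if the tokenizer inserts special tokens.
def tokenize_with_prompt_masking(examples, tokenizer, max_length=512):
    texts = examples["text"]
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    all_labels = []
    for text, input_ids in zip(texts, tokenized["input_ids"]):
        # Token count of the prompt up to and including the "Reply:" marker.
        prompt = text.rsplit("Reply:", 1)[0] + "Reply:"
        prompt_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
        all_labels.append(
            [-100 if i < prompt_len else tok for i, tok in enumerate(input_ids)]
        )
    tokenized["labels"] = all_labels
    return tokenized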


def main():
    parser = argparse.ArgumentParser(description="Fine-tune model for reply generation")
    parser.add_argument(
        "--data_path",
        type=str,
        default="new_data_selected.csv",
        help="Path to training data CSV file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./finetuned_reply_model",
        help="Output directory for fine-tuned model",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vinai/PhoGPT-4B-Chat",
        help="Base model name for fine-tuning",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,  # Reduced default for faster training on Spaces
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Training batch size",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--use_history",
        action="store_true",
        help="Use conversation history in training",
    )
    parser.add_argument(
        "--persona",
        type=str,
        default="default",
        choices=["default", "wingman"],
        help="Persona/instruction style for generation",
    )
    parser.add_argument(
        "--model_arch",
        type=str,
        default="causal",
        choices=["causal", "encoder_decoder"],
        help="Model architecture type",
    )
    args = parser.parse_args()

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    print(f"Loading dataset from {args.data_path}...")
    df = pd.read_csv(args.data_path)
    print(f"Dataset shape: {df.shape}")

    # Prepare training data
    print("Preparing training data...")
    train_data = prepare_training_data(df, use_history=args.use_history, persona=args.persona)
    print(f"Total training examples: {len(train_data)}")

    # Convert to a HuggingFace Dataset
    dataset = Dataset.from_list(train_data)
    split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    val_dataset = split_dataset["test"]
    print(f"Train examples: {len(train_dataset)}")
    print(f"Validation examples: {len(val_dataset)}")

    # Load model and tokenizer
    print(f"Loading model: {args.model_name} ({args.model_arch})")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.cls_token

    # Try to configure quantization; fall back if bitsandbytes (or its
    # triton dependency) is unavailable.
    use_quantization = False
    quant_config = None
    if args.model_arch == "causal":
        try:
            import bitsandbytes as bnb  # noqa: F401 -- availability check only

            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            use_quantization = True
            print("4-bit quantization enabled")
        except (ImportError, ModuleNotFoundError) as e:
            print(f"Warning: BitsAndBytesConfig not available ({e}), loading model without quantization...")

    model = None
    last_error = None

    def load_base_model(use_quant: bool):
        if args.model_arch == "encoder_decoder":
            model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                args.model_name,
                args.model_name,
                tie_encoder_decoder=True,
            )
            # bos_token_id may exist but be None for some tokenizers, so test
            # the value itself rather than attribute presence.
            model.config.decoder_start_token_id = (
                tokenizer.bos_token_id
                if tokenizer.bos_token_id is not None
                else tokenizer.cls_token_id
            )
            model.config.pad_token_id = tokenizer.pad_token_id
            model.config.vocab_size = model.config.encoder.vocab_size
            return model
        else:
            kwargs = dict(
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            if use_quant and quant_config is not None:
                kwargs["quantization_config"] = quant_config
            return AutoModelForCausalLM.from_pretrained(
                args.model_name,
                **kwargs,
            )
| if args.model_arch == "causal": | |
| if use_quantization: | |
| try: | |
| model = load_base_model(use_quant=True) | |
| print("Model loaded with 4-bit quantization") | |
| except Exception as e: | |
| last_error = e | |
| print(f"Failed to load with quantization: {e}") | |
| model = None | |
| if model is None: | |
| try: | |
| model = load_base_model(use_quant=False) | |
| print("Model loaded without quantization (may use more memory)") | |
| except Exception as e: | |
| if last_error: | |
| print(f"Original error: {last_error}") | |
| raise Exception(f"Failed to load model: {e}") | |
| else: | |
| model = load_base_model(use_quant=False) | |
| print("Encoder-decoder model loaded successfully!") | |

    # Format and tokenize datasets
    print("Formatting datasets...")
    train_dataset_formatted = train_dataset.map(
        lambda x: format_prompt(x, tokenizer),
        remove_columns=train_dataset.column_names,
    )
    val_dataset_formatted = val_dataset.map(
        lambda x: format_prompt(x, tokenizer),
        remove_columns=val_dataset.column_names,
    )

    print("Tokenizing datasets...")
    train_dataset_tokenized = train_dataset_formatted.map(
        lambda x: tokenize_function(x, tokenizer),
        batched=True,
        remove_columns=train_dataset_formatted.column_names,
    )
    val_dataset_tokenized = val_dataset_formatted.map(
        lambda x: tokenize_function(x, tokenizer),
        batched=True,
        remove_columns=val_dataset_formatted.column_names,
    )

    # Configure LoRA
    if args.model_arch == "causal":
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="SEQ_2_SEQ_LM",
        )
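
    # Note: LoRA target-module names depend on the base architecture. The lists
    # above use LLaMA-style projection names; MPT-derived checkpoints such as
    # vinai/PhoGPT-4B-Chat typically expose a fused "Wqkv" attention projection
    # instead, in which case PEFT would not find these modules. To inspect the
    # actual module names on the loaded model:
    #   print({name.split(".")[-1] for name, _ in model.named_modules()})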

    # Prepare model for k-bit training (only when quantization is active)
    if use_quantization:
        try:
            model = prepare_model_for_kbit_training(model)
        except Exception as e:
            print(f"Warning: prepare_model_for_kbit_training failed: {e}, continuing anyway...")

    model = get_peft_model(model, lora_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")

    # Training arguments
    training_common_kwargs = dict(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        learning_rate=args.learning_rate,
        fp16=True,
        logging_steps=10,
        eval_steps=100,
        save_steps=500,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        remove_unused_columns=False,
    )
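
    # `evaluation_strategy` was renamed `eval_strategy` in newer transformers
    # releases; try the older keyword first and fall back on TypeError.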
    try:
        training_args = TrainingArguments(
            evaluation_strategy="steps",
            **training_common_kwargs,
        )
    except TypeError:
        training_args = TrainingArguments(
            eval_strategy="steps",
            **training_common_kwargs,
        )

    # Create the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset_tokenized,
        eval_dataset=val_dataset_tokenized,
        tokenizer=tokenizer,
    )

    print("Starting training...")
    trainer.train()
    print("Training completed!")

    # Save model and tokenizer
    trainer.save_model()
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}/")


if __name__ == "__main__":
    main()