#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Fine-tuning pipeline from a saved CPT model for any BERT-style model.
Loads the CPT weights and fine-tunes a classification head:
- Trains on the training CSV (train.csv), which has `Commentaire client` and `Class` (1..9) columns
- Keeps the user-facing classes as 1..9 while the model uses indices 0..8 internally (label mapping saved to config)
Usage:
    python finetune_from_cpt.py
Notes:
- This script uses the Hugging Face Trainer API.
- Adjust the epochs and batch sizes if you have a different GPU memory budget.
"""
import os
import json
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from inspect import signature
# ---------------------------
# Paths & basic config
# ---------------------------
TRAIN_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"
# Path to your saved CPT model
CPT_MODEL_PATH = "" # Change this to your CPT model path
# Output directory for fine-tuned model
FT_OUTPUT_DIR = "./telecom_arabert_large_full_pipeline"
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")
# Label mapping (keep user-facing classes 1..9 but model indices 0..8)
LABEL2ID = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}
NUM_LABELS = len(LABEL2ID)
# ---------------------------
# Finetune classification using CPT weights
# ---------------------------
print("\n=== Finetuning phase: load CPT weights and fine-tune classifier ===\n")
# Check if CPT model exists
if not os.path.exists(CPT_MODEL_PATH):
    raise FileNotFoundError(f"CPT model not found at: {CPT_MODEL_PATH}")
print(f"Loading CPT model from: {CPT_MODEL_PATH}")
# Load train dataset with labels
print(f"Loading training CSV from: {TRAIN_FILE}")
train_ds = load_dataset("csv", data_files=TRAIN_FILE, split="train")
print(f"Train samples: {len(train_ds)} | Columns: {train_ds.column_names}")
# Prepare label mapping on train file (ensure handling of string/int)
def encode_train_labels(example):
    c = example.get("Class")
    if isinstance(c, str):
        try:
            c = int(c)
        except Exception:
            # Attempt to strip and convert
            c = int(c.strip())
    if c not in LABEL2ID:
        raise ValueError(f"Unexpected class value in training data: {c}")
    example["labels"] = LABEL2ID[c]
    return example
train_ds = train_ds.map(encode_train_labels)
# Train/validation split
split = train_ds.train_test_split(test_size=0.1, seed=42)
train_split = split["train"]
eval_split = split["test"]
print("Train split size:", len(train_split), "Eval split size:", len(eval_split))
# Load tokenizer and model from CPT output
print("Loading tokenizer and model from CPT output for finetuning...")
ft_tokenizer = AutoTokenizer.from_pretrained(CPT_MODEL_PATH)
# Load sequence classification model initialized from the CPT weights
print("Loading AutoModelForSequenceClassification from CPT weights")
ft_model = AutoModelForSequenceClassification.from_pretrained(
    CPT_MODEL_PATH,
    num_labels=NUM_LABELS,
    id2label={str(k): str(v) for k, v in ID2LABEL.items()},
    label2id={str(k): v for k, v in LABEL2ID.items()},
)
ft_model = ft_model.to(DEVICE)
# Print model info
total_params = sum(p.numel() for p in ft_model.parameters())
trainable_params = sum(p.numel() for p in ft_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Tokenization function
def preprocess_classification(examples):
    return ft_tokenizer(
        examples["Commentaire client"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
train_split = train_split.map(preprocess_classification, batched=True, num_proc=4)
eval_split = eval_split.map(preprocess_classification, batched=True, num_proc=4)
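# Keep only the tensors the Trainer needs; for BERT-style tokenizers the token_type_ids
# produced above are dropped here and default to zeros inside the model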
train_split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
eval_split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
# Metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    precision_mi, recall_mi, f1_mi, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    metrics = {
        'accuracy': acc,
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'f1_micro': f1_mi,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    }
    # Per-class F1, computed over all NUM_LABELS classes so index positions stay
    # aligned with ID2LABEL even if a class is missing from this evaluation split
    per_class_f1 = f1_score(labels, preds, labels=list(range(NUM_LABELS)), average=None, zero_division=0)
    for idx, user_class in ID2LABEL.items():
        metrics[f'f1_class_{user_class}'] = float(per_class_f1[idx])
    return metrics
# Training arguments for finetuning
# Dynamically inspect the TrainingArguments signature so only kwargs supported
# by the installed transformers version are passed
ta_sig = signature(TrainingArguments.__init__)
ta_params = set(ta_sig.parameters.keys())
ft_base_kwargs = {
    'output_dir': FT_OUTPUT_DIR,
    'num_train_epochs': 100,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 64,
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'logging_steps': 50,
    'save_total_limit': 2,
}
if 'bf16' in ta_params and torch.cuda.is_available() and hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
    ft_base_kwargs['bf16'] = True
elif 'fp16' in ta_params and torch.cuda.is_available():
    ft_base_kwargs['fp16'] = True
# Use epoch-level evaluation/saving if supported (the argument was renamed from
# `evaluation_strategy` to `eval_strategy` in newer transformers releases)
eval_key = 'evaluation_strategy' if 'evaluation_strategy' in ta_params else ('eval_strategy' if 'eval_strategy' in ta_params else None)
if eval_key is not None:
    ft_base_kwargs[eval_key] = 'epoch'
    ft_base_kwargs['save_strategy'] = 'epoch'
    ft_base_kwargs['load_best_model_at_end'] = True
    ft_base_kwargs['metric_for_best_model'] = 'f1_weighted'
# Filter supported args
ft_filtered = {k: v for k, v in ft_base_kwargs.items() if k in ta_params}
ft_training_args = TrainingArguments(**ft_filtered)
# Trainer for finetuning
ft_trainer = Trainer(
    model=ft_model,
    args=ft_training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=ft_tokenizer,
    compute_metrics=compute_metrics,
)
print("Starting finetuning on classification task...")
ft_trainer.train()
print("Finetuning finished. Saving finetuned model to:", FT_OUTPUT_DIR)
ft_trainer.save_model(FT_OUTPUT_DIR)
ft_tokenizer.save_pretrained(FT_OUTPUT_DIR)
# Update config with label mappings (so inference scripts can read cleanly)
config_path = os.path.join(FT_OUTPUT_DIR, 'config.json')
if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
else:
    cfg = {}
cfg['id2label'] = {str(k): str(v) for k, v in ID2LABEL.items()}
cfg['label2id'] = {str(k): v for k, v in LABEL2ID.items()}
cfg['num_labels'] = NUM_LABELS
cfg['problem_type'] = 'single_label_classification'
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(cfg, f, ensure_ascii=False, indent=2)
print("Saved label mappings to finetuned model config")
print('\nAll done. Finetuning completed.')
print('Finetuned classifier saved to:', FT_OUTPUT_DIR)
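# ---------------------------
# Illustrative inference helper (not called by this script)
# ---------------------------
# A minimal sketch of how the fine-tuned checkpoint could be used for prediction,
# mapping the model's internal index (0..8) back to the user-facing class (1..9).
# The helper name `predict_class` is ours; adjust as needed.
def predict_class(text, model_dir=FT_OUTPUT_DIR):
    tok = AutoTokenizer.from_pretrained(model_dir)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_dir).to(DEVICE)
    mdl.eval()
    enc = tok(text, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = mdl(**enc).logits
    pred_idx = int(logits.argmax(dim=-1).item())
    return ID2LABEL[pred_idx]  # convert model index 0..8 back to class 1..9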