Gemma_models / Code /train_marbertv2.py
houssamboukhalfa's picture
Upload folder using huggingface_hub
f8dd4fe verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Fine-tune UBC-NLP/MARBERTv2 for Arabic telecom customer comment classification.
Dataset (CSV), as configured in DATA_FILE below:
/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv
Columns:
Commentaire client: str (text)
Class: int (label - values 1 through 9)
Model:
- MARBERTv2 encoder
- Classification head for multi-class prediction (9 classes)
"""
import os
import numpy as np
import torch
from inspect import signature
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
)
# Slight speed boost on Ampere GPUs.
# The setter is looked up dynamically because not every torch build exposes it.
_set_matmul_precision = getattr(torch, "set_float32_matmul_precision", None)
if _set_matmul_precision is not None:
    _set_matmul_precision("high")
# -------------------------------------------------------------------
# 1. Paths & config
# -------------------------------------------------------------------
# Training CSV inside the local Hugging Face hub cache.
DATA_FILE = (
    "/home/houssam-nojoom/.cache/huggingface/hub/"
    "datasets--houssamboukhalfa--labelds/snapshots/"
    "48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"
)
MODEL_NAME = "UBC-NLP/MARBERTv2"
OUTPUT_DIR = "./telecom_marbertv2_final"
MAX_LENGTH = 256

# Dataset classes are numbered 1..NUM_LABELS; the model works with 0-based
# indices, so both directions of the mapping are derived from NUM_LABELS.
NUM_LABELS = 9
LABEL2ID = {class_id: class_id - 1 for class_id in range(1, NUM_LABELS + 1)}
ID2LABEL = {index: class_id for class_id, index in LABEL2ID.items()}
# -------------------------------------------------------------------
# 2. Dataset loading
# -------------------------------------------------------------------
print("Loading telecom dataset from CSV...")
# Load the CSV at DATA_FILE, requesting the "train" split directly.
dataset = load_dataset("csv", data_files=DATA_FILE, split="train")

# Quick overview of what was loaded and how classes map to model indices.
print("Sample example:", dataset[0])
print(f"Total examples: {len(dataset)}")
print(f"Number of classes: {NUM_LABELS}")
print("Label mapping (class -> model index):", LABEL2ID)
print("Inverse mapping (model index -> class):", ID2LABEL)
def encode_labels(example):
    """Add a 0-based ``labels`` entry derived from the row's ``Class`` value.

    Args:
        example: A single dataset row; its ``"Class"`` field holds the class
            number (1-9) as an int or as a numeric string.

    Returns:
        The same ``example`` with ``example["labels"]`` set to the model index.

    Raises:
        ValueError: If the class value is not a key of ``LABEL2ID``.
    """
    raw_class = example["Class"]
    # The class may arrive as text, so normalise strings to int first.
    class_id = int(raw_class) if isinstance(raw_class, str) else raw_class

    label_index = LABEL2ID.get(class_id)
    if label_index is None:
        raise ValueError(f"Unknown class: {class_id}. Expected 1-9.")

    example["labels"] = label_index
    return example
# Attach the "labels" column, then hold out 10% of the rows for validation.
# The fixed seed keeps the split reproducible between runs.
dataset = dataset.map(encode_labels).train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = dataset["train"], dataset["test"]

print("Train size:", len(train_dataset))
print("Eval size:", len(eval_dataset))
# -------------------------------------------------------------------
# 3. Tokenization
# -------------------------------------------------------------------
# Load the tokenizer published with the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess_function(examples):
    """Tokenize the ``"Commentaire client"`` column of a batch.

    Args:
        examples: Batch of dataset rows containing a ``"Commentaire client"``
            entry with the comment text(s).

    Returns:
        The tokenizer output, with every sequence truncated and padded to
        exactly ``MAX_LENGTH`` tokens.
    """
    comments = examples["Commentaire client"]
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": MAX_LENGTH,
    }
    return tokenizer(comments, **tokenizer_options)
# Tokenize both splits with identical settings (train first, then eval).
train_dataset, eval_dataset = (
    split.map(preprocess_function, batched=True, num_proc=4)
    for split in (train_dataset, eval_dataset)
)

# Expose only the columns the model consumes, as torch tensors.
_model_input_columns = ("input_ids", "attention_mask", "labels")
for _tokenized_split in (train_dataset, eval_dataset):
    _tokenized_split.set_format(type="torch", columns=list(_model_input_columns))
# -------------------------------------------------------------------
# 4. Model - Using AutoModelForSequenceClassification
# -------------------------------------------------------------------
# Classification-head settings: number of outputs plus both label mappings.
_classifier_kwargs = {
    "num_labels": NUM_LABELS,
    "id2label": ID2LABEL,
    "label2id": LABEL2ID,
}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, **_classifier_kwargs
)

print("Model initialized with classification head")
print(f"Number of labels: {NUM_LABELS}")
print(f"Classes: {list(ID2LABEL.values())}")
# -------------------------------------------------------------------
# 5. Metrics
# -------------------------------------------------------------------
def compute_metrics(eval_pred):
    """Compute accuracy, weighted/macro precision-recall-F1 and per-class F1.

    Args:
        eval_pred: ``(logits, labels)`` pair. ``logits`` is arg-maxed over its
            last axis to obtain the predicted label index for each example.

    Returns:
        dict mapping metric names to values: ``accuracy``, ``f1_weighted``,
        ``f1_macro``, ``precision_weighted``, ``recall_weighted``,
        ``precision_macro``, ``recall_macro`` and one ``f1_class_<class>``
        entry for each of the ``NUM_LABELS`` classes.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Overall metrics
    accuracy = accuracy_score(labels, predictions)

    # Weighted average (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted", zero_division=0
    )
    # Macro average (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        labels, predictions, average="macro", zero_division=0
    )

    metrics = {
        "accuracy": accuracy,
        "f1_weighted": f1_w,
        "f1_macro": f1_m,
        "precision_weighted": precision_w,
        "recall_weighted": recall_w,
        "precision_macro": precision_m,
        "recall_macro": recall_m,
    }

    # Per-class F1 scores. The label set is pinned explicitly: without
    # `labels=`, sklearn only scores classes that occur in `labels` or
    # `predictions`, so a class missing from this split would shift every
    # later score onto the wrong class name.
    per_class_f1 = f1_score(
        labels,
        predictions,
        labels=list(range(NUM_LABELS)),
        average=None,
        zero_division=0,
    )
    for idx, score in enumerate(per_class_f1):
        metrics[f"f1_class_{ID2LABEL[idx]}"] = float(score)

    return metrics
# -------------------------------------------------------------------
# 6. TrainingArguments (old/new transformers compatible)
# -------------------------------------------------------------------
# Introspect the installed TrainingArguments so the same script runs across
# transformers releases whose keyword names differ.
ta_sig = signature(TrainingArguments.__init__)
ta_params = set(ta_sig.parameters.keys())

cuda_available = torch.cuda.is_available()
is_bf16_supported = (
    cuda_available
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
)
use_bf16 = bool(is_bf16_supported)
# Only request fp16 when a CUDA device exists: enabling it unconditionally
# makes CPU-only runs fail when TrainingArguments validates the flag.
use_fp16 = cuda_available and not use_bf16
print(f"bf16 supported: {is_bf16_supported} -> using bf16={use_bf16}, fp16={use_fp16}")

base_kwargs = {
    "output_dir": OUTPUT_DIR,
    "num_train_epochs": 10,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 64,
    "learning_rate": 1e-4,
    "weight_decay": 0.02,
    "warmup_ratio": 0.1,
    "logging_steps": 50,
    "save_total_limit": 2,
    "dataloader_num_workers": 4,
}

# Mixed precision flags if supported
if "bf16" in ta_params:
    base_kwargs["bf16"] = use_bf16
if "fp16" in ta_params:
    base_kwargs["fp16"] = use_fp16

# Disable external experiment trackers whenever the option exists.
if "report_to" in ta_params:
    base_kwargs["report_to"] = "none"

# The evaluation schedule keyword was renamed from `evaluation_strategy` to
# `eval_strategy`; prefer the new name and fall back to the legacy one so that
# per-epoch evaluation is not silently dropped on newer releases.
eval_strategy_key = next(
    (name for name in ("eval_strategy", "evaluation_strategy") if name in ta_params),
    None,
)

if eval_strategy_key is not None:
    base_kwargs[eval_strategy_key] = "epoch"
    # Options that only make sense together with periodic evaluation; each one
    # is added only if this TrainingArguments version accepts it.
    epoch_options = {
        "save_strategy": "epoch",
        "logging_strategy": "steps",
        "load_best_model_at_end": True,
        "metric_for_best_model": "f1_weighted",
        "greater_is_better": True,
    }
    base_kwargs.update(
        {key: value for key, value in epoch_options.items() if key in ta_params}
    )
else:
    print("[TrainingArguments] Old transformers version: no evaluation_strategy argument. Using simple setup.")

# Final safety net: drop anything the installed version does not accept.
filtered_kwargs = {}
for k, v in base_kwargs.items():
    if k in ta_params:
        filtered_kwargs[k] = v
    else:
        print(f"[TrainingArguments] Skipping unsupported arg: {k}={v}")

training_args = TrainingArguments(**filtered_kwargs)
# -------------------------------------------------------------------
# 7. Trainer
# -------------------------------------------------------------------
# Newer transformers releases take the tokenizer through `processing_class`
# and deprecate the `tokenizer` keyword; older ones only know `tokenizer`.
# Pick whichever keyword the installed Trainer accepts.
_trainer_params = signature(Trainer.__init__).parameters
_tokenizer_kwarg = (
    "processing_class" if "processing_class" in _trainer_params else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)
# -------------------------------------------------------------------
# 8. Train & eval
# -------------------------------------------------------------------
if __name__ == "__main__":
    # --- Train ---------------------------------------------------------
    print("Starting telecom classification training...")
    trainer.train()

    # --- Evaluate on the held-out split --------------------------------
    print("Evaluating on validation split...")
    metrics = trainer.evaluate()
    print("Validation metrics:", metrics)

    # --- Persist model and tokenizer side by side ----------------------
    print("Saving final model & tokenizer...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Label mappings saved in config:")
    print(f" ID to Label: {ID2LABEL}")
    print(f" Label to ID: {LABEL2ID}")

    # --- Quick sanity-check inference ----------------------------------
    example_texts = [
        "الخدمة ممتازة جدا وسريعة",
        "سيء للغاية ولا يستجيبون",
        "متوسط الجودة",
    ]
    encoded = tokenizer(
        example_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # Move the batch to the device the model lives on.
    inputs = encoded.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits.cpu().numpy()
    predictions = np.argmax(logits, axis=-1)

    print("\nSanity-check predictions:")
    for text, pred_idx in zip(example_texts, predictions):
        # Translate the model index back to the dataset's class number.
        pred_class = ID2LABEL[pred_idx]
        print(f"Text: {text}")
        print(f" -> Predicted Class: {pred_class}")
        print()

    print("Training complete and model saved to:", OUTPUT_DIR)