Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- Code/cpt_mlm_then_finetune.py +315 -0
- Code/finetune_from_cpt.py +221 -0
- Code/finetune_gemma3_classification.py +686 -0
- Code/inference_camelbert.py +268 -0
- Code/inference_dziribert.py +267 -0
- Code/inference_gemma3.py +409 -0
- Code/inference_marbertv2.py +275 -0
- Code/inference_marbertv2_cpt_ft.py +268 -0
- Code/train_arabert.py +293 -0
- Code/train_dziribert.py +313 -0
- Code/train_marbertv2.py +293 -0
- Code/voting.ipynb +119 -0
- requirements.txt +54 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/config.json +48 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/model.safetensors +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/optimizer.pt +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/rng_state.pth +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/scheduler.pt +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/special_tokens_map.json +37 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/tokenizer.json +0 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/tokenizer_config.json +66 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/trainer_state.json +524 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/training_args.bin +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3500/vocab.txt +0 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/config.json +48 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/model.safetensors +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/optimizer.pt +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/rng_state.pth +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/scheduler.pt +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/special_tokens_map.json +37 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/tokenizer.json +0 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/tokenizer_config.json +66 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/trainer_state.json +531 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/training_args.bin +3 -0
- telecom_camelbert_cpt_ft/checkpoint-3570/vocab.txt +0 -0
- telecom_camelbert_cpt_ft/config.json +49 -0
- telecom_camelbert_cpt_ft/model.safetensors +3 -0
- telecom_camelbert_cpt_ft/special_tokens_map.json +37 -0
- telecom_camelbert_cpt_ft/tokenizer.json +0 -0
- telecom_camelbert_cpt_ft/tokenizer_config.json +66 -0
- telecom_camelbert_cpt_ft/training_args.bin +3 -0
- telecom_camelbert_cpt_ft/vocab.txt +0 -0
- telecom_dziribert_final/checkpoint-8000/config.json +48 -0
- telecom_dziribert_final/checkpoint-8000/model.safetensors +3 -0
- telecom_dziribert_final/checkpoint-8000/optimizer.pt +3 -0
- telecom_dziribert_final/checkpoint-8000/rng_state.pth +3 -0
- telecom_dziribert_final/checkpoint-8000/scheduler.pt +3 -0
- telecom_dziribert_final/checkpoint-8000/special_tokens_map.json +7 -0
- telecom_dziribert_final/checkpoint-8000/tokenizer.json +0 -0
- telecom_dziribert_final/checkpoint-8000/tokenizer_config.json +59 -0
Code/cpt_mlm_then_finetune.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CPT (continued pretraining) + finetune pipeline for any BERT-style model.

Phase 1 — continued pretraining (CPT): run masked-language-model training on
the raw text of the test CSV so the LM adapts to the target domain.

Phase 2 — classification finetuning: load the CPT weights and fine-tune a
sequence-classification head on the training CSV (`Commentaire client` text,
`Class` labels 1..9; the model uses indices 0..8 internally and the mapping
is persisted to config.json).

Usage:
    python cpt_mlm_then_finetune.py

Notes:
- Both phases use the Hugging Face Trainer API.
- Tune the epoch counts / batch sizes to your GPU memory budget.
"""

import os
import json
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from inspect import signature

# ---------------------------
# Paths & basic config
# ---------------------------
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/test_file.csv"
TRAIN_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"
BASE_MODEL = "alger-ia/dziribert_sentiment"

# Other candidates: UBC-NLP/MARBERTv2, aubmindlab/bert-large-arabertv2,
# alger-ia/dziribert, CAMeL-Lab/bert-base-arabic-camelbert-msa
CPT_OUTPUT_DIR = "./dziribert"
FT_OUTPUT_DIR = "./telecom_dziribert_cpt_ft"

MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

# User-facing classes are 1..9; the model works with indices 0..8.
LABEL2ID = {label: label - 1 for label in range(1, 10)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
NUM_LABELS = len(LABEL2ID)
|
| 56 |
+
|
| 57 |
+
# ---------------------------
# 1) CPT: Masked LM training on test text
# ---------------------------
print("\n=== CPT (MLM) phase: load test texts and continue pretraining the LM ===\n")

# The test CSV supplies the in-domain text used for continued pretraining.
print(f"Loading test CSV from: {TEST_FILE}")
test_ds = load_dataset("csv", data_files=TEST_FILE, split="train")
print(f"Test samples: {len(test_ds)} | Columns: {test_ds.column_names}")

# Heuristic: pick the first column whose name looks like a text column,
# otherwise fall back to the last column.
_KNOWN_TEXT_COLS = {"commentaire client", "commentaire_client", "comment", "text", "commentaire"}
text_col = next((c for c in test_ds.column_names if c.lower() in _KNOWN_TEXT_COLS), None)
if text_col is None:
    text_col = test_ds.column_names[-1]
print(f"Using text column for MLM: {text_col}")

# Tokenizer must expose a mask token for MLM; BERT-style ones use [MASK].
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({"mask_token": "[MASK]"})
    print("Added [MASK] token to tokenizer")

print("Loading base model for MLM:", BASE_MODEL)
mlm_model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL)
# Embeddings may need to grow if a mask token was just added.
mlm_model.resize_token_embeddings(len(tokenizer))
mlm_model = mlm_model.to(DEVICE)


def tokenize_mlm(examples):
    """Tokenize a batch of raw texts to fixed length for MLM training."""
    return tokenizer(
        examples[text_col],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )


print("Tokenizing test dataset for MLM...")
mlm_tokenized = test_ds.map(tokenize_mlm, batched=True, remove_columns=test_ds.column_names)

# Dynamic masking: 15% of tokens are masked differently on every pass.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Training args for CPT (adjust epochs/batch as needed).
cpt_train_args = TrainingArguments(
    output_dir=CPT_OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=38,
    per_device_train_batch_size=32,
    save_total_limit=2,
    save_strategy="epoch",
    logging_steps=200,
    fp16=torch.cuda.is_available(),
)

cpt_trainer = Trainer(
    model=mlm_model,
    args=cpt_train_args,
    train_dataset=mlm_tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
|
| 128 |
+
|
| 129 |
+
print("Starting CPT (MLM) training...")
cpt_trainer.train()
print("CPT finished. Saving CPT model to:", CPT_OUTPUT_DIR)

# Persist the adapted LM and its tokenizer for the finetuning phase.
cpt_trainer.save_model(CPT_OUTPUT_DIR)
tokenizer.save_pretrained(CPT_OUTPUT_DIR)

# Save an informative config with label mappings too (so downstream scripts
# can read them without re-deriving the 1..9 <-> 0..8 correspondence).
config_path = os.path.join(CPT_OUTPUT_DIR, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
else:
    config_data = {}

# id2label: model index (as string) -> original class name, e.g. "0" -> "1".
config_data["id2label"] = {str(k): str(v) for k, v in ID2LABEL.items()}
# BUG FIX: label2id must map label name -> model index (e.g. "1" -> 0).
# The original comprehension `{str(v): k ...}` produced the inverse
# ("0" -> 1), contradicting its own comment and breaking inference scripts
# that resolve labels through config.label2id.
config_data["label2id"] = {str(k): v for k, v in LABEL2ID.items()}
config_data["num_labels"] = NUM_LABELS
config_data["problem_type"] = "single_label_classification"

with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config_data, f, ensure_ascii=False, indent=2)

print("Saved config with label mappings to CPT output dir")
|
| 155 |
+
|
| 156 |
+
# ---------------------------
# 2) Finetune classification using CPT weights
# ---------------------------
print("\n=== Finetuning phase: load CPT weights and fine-tune classifier ===\n")

# Load train dataset with labels (`Commentaire client` text, `Class` 1..9).
print(f"Loading training CSV from: {TRAIN_FILE}")
train_ds = load_dataset("csv", data_files=TRAIN_FILE, split="train")
print(f"Train samples: {len(train_ds)} | Columns: {train_ds.column_names}")


def encode_train_labels(example):
    """Map the user-facing `Class` value (1..9) to the model index (0..8).

    Accepts ints, floats, and numeric strings (including forms like "3.0"
    that plain int() rejects). Raises ValueError for anything outside 1..9.
    """
    c = example.get("Class")
    if isinstance(c, str):
        try:
            # int() already tolerates surrounding whitespace.
            c = int(c)
        except ValueError:
            # ROBUSTNESS: also accept float-formatted strings such as "3.0",
            # which CSV round-trips sometimes produce.
            c = int(float(c))
    if c not in LABEL2ID:
        raise ValueError(f"Unexpected class value in training data: {c}")
    example["labels"] = LABEL2ID[c]
    return example


train_ds = train_ds.map(encode_train_labels)

# Deterministic 90/10 train/validation split.
split = train_ds.train_test_split(test_size=0.1, seed=42)
train_split = split["train"]
eval_split = split["test"]
print("Train split size:", len(train_split), "Eval split size:", len(eval_split))

# Load tokenizer and model from the CPT output directory.
print("Loading tokenizer and model from CPT output for finetuning...")
ft_tokenizer = AutoTokenizer.from_pretrained(CPT_OUTPUT_DIR)

print("Loading AutoModelForSequenceClassification from CPT weights")
# BUG FIX: label2id must map label name -> model index ("1" -> 0); the
# original comprehension `{str(v): k ...}` produced the inverse ("0" -> 1).
ft_model = AutoModelForSequenceClassification.from_pretrained(
    CPT_OUTPUT_DIR,
    num_labels=NUM_LABELS,
    id2label={str(k): str(v) for k, v in ID2LABEL.items()},
    label2id={str(k): v for k, v in LABEL2ID.items()},
)

ft_model = ft_model.to(DEVICE)
|
| 202 |
+
|
| 203 |
+
# Tokenize the classification inputs with the same settings as traim.py.
def preprocess_classification(examples):
    """Tokenize a batch of `Commentaire client` strings to fixed length."""
    return ft_tokenizer(
        examples["Commentaire client"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )


train_split = train_split.map(preprocess_classification, batched=True, num_proc=4)
eval_split = eval_split.map(preprocess_classification, batched=True, num_proc=4)

# Expose only the tensors the Trainer consumes.
_TORCH_COLUMNS = ["input_ids", "attention_mask", "labels"]
train_split.set_format(type="torch", columns=_TORCH_COLUMNS)
eval_split.set_format(type="torch", columns=_TORCH_COLUMNS)
|
| 217 |
+
|
| 218 |
+
# Metrics same as traim.py.
def compute_metrics(eval_pred):
    """Compute accuracy plus weighted/macro/micro precision, recall and F1.

    Also emits one `f1_class_<label>` entry per class, keyed by the
    user-facing class name (1..9) from ID2LABEL.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    _, _, f1_mi, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    metrics = {
        'accuracy': acc,
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'f1_micro': f1_mi,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    }
    # BUG FIX: without `labels=`, f1_score(average=None) only returns scores
    # for classes actually present in this eval split, so positional indexing
    # could attribute a score to the wrong class. Pin the label order.
    class_ids = sorted(ID2LABEL)
    per_class_f1 = f1_score(labels, preds, average=None, labels=class_ids, zero_division=0)
    for pos, idx in enumerate(class_ids):
        metrics[f'f1_class_{ID2LABEL[idx]}'] = float(per_class_f1[pos])
    return metrics
|
| 242 |
+
|
| 243 |
+
# Training arguments for finetuning.
# TrainingArguments kwargs vary across transformers versions, so inspect the
# signature and pass only what this installation supports.
ta_sig = signature(TrainingArguments.__init__)
ta_params = set(ta_sig.parameters.keys())

ft_base_kwargs = {
    'output_dir': FT_OUTPUT_DIR,
    'num_train_epochs': 70,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 64,
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'logging_steps': 50,
    'save_total_limit': 2,
}

# Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
if 'bf16' in ta_params and torch.cuda.is_available() and hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
    ft_base_kwargs['bf16'] = True
elif 'fp16' in ta_params and torch.cuda.is_available():
    ft_base_kwargs['fp16'] = True

# BUG FIX: transformers >= 4.46 renamed `evaluation_strategy` to
# `eval_strategy`. The original only checked the old name, which silently
# disabled per-epoch evaluation and best-model loading on newer versions.
eval_key = next((k for k in ('evaluation_strategy', 'eval_strategy') if k in ta_params), None)
if eval_key is not None:
    ft_base_kwargs[eval_key] = 'epoch'
    ft_base_kwargs['save_strategy'] = 'epoch'
    ft_base_kwargs['load_best_model_at_end'] = True
    ft_base_kwargs['metric_for_best_model'] = 'f1_weighted'

# Drop anything this transformers version does not accept.
ft_filtered = {k: v for k, v in ft_base_kwargs.items() if k in ta_params}

ft_training_args = TrainingArguments(**ft_filtered)

# Trainer for finetuning.
ft_trainer = Trainer(
    model=ft_model,
    args=ft_training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=ft_tokenizer,
    compute_metrics=compute_metrics,
)
|
| 286 |
+
|
| 287 |
+
print("Starting finetuning on classification task...")
ft_trainer.train()

print("Finetuning finished. Saving finetuned model to:", FT_OUTPUT_DIR)
ft_trainer.save_model(FT_OUTPUT_DIR)
ft_tokenizer.save_pretrained(FT_OUTPUT_DIR)

# Update config with label mappings (so inference scripts can read cleanly).
config_path = os.path.join(FT_OUTPUT_DIR, 'config.json')
if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
else:
    cfg = {}

cfg['id2label'] = {str(k): str(v) for k, v in ID2LABEL.items()}
# BUG FIX: label2id must map label name -> model index ("1" -> 0); the
# original comprehension `{str(v): k ...}` produced the inverse ("0" -> 1).
cfg['label2id'] = {str(k): v for k, v in LABEL2ID.items()}
cfg['num_labels'] = NUM_LABELS
cfg['problem_type'] = 'single_label_classification'

with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(cfg, f, ensure_ascii=False, indent=2)

print("Saved label mappings to finetuned model config")

print('\nAll done. CPT and finetuning completed.')
print('CPT model saved to:', CPT_OUTPUT_DIR)
print('Finetuned classifier saved to:', FT_OUTPUT_DIR)
|
| 315 |
+
|
Code/finetune_from_cpt.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Fine-tuning pipeline from a saved CPT model, for any BERT-style model.

Loads continued-pretraining (CPT) weights and fine-tunes a classification
head on the training CSV (`Commentaire client` text, `Class` labels 1..9;
the model uses indices 0..8 internally and the mapping is saved to config).

Usage:
    python finetune_from_cpt.py

Notes:
- Uses the Hugging Face Trainer API.
- Tune the epoch count / batch sizes to your GPU memory budget.
"""

import os
import json
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from inspect import signature

# ---------------------------
# Paths & basic config
# ---------------------------
TRAIN_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"

# Path to your saved CPT model — change this before running.
CPT_MODEL_PATH = ""

# Output directory for the fine-tuned model.
FT_OUTPUT_DIR = "./telecom_arabert_large_full_pipeline"

MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

# User-facing classes are 1..9; the model works with indices 0..8.
LABEL2ID = {label: label - 1 for label in range(1, 10)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
NUM_LABELS = len(LABEL2ID)

# ---------------------------
# Finetune classification using CPT weights
# ---------------------------
print("\n=== Finetuning phase: load CPT weights and fine-tune classifier ===\n")

# Fail fast if the CPT checkpoint directory was not configured.
if not os.path.exists(CPT_MODEL_PATH):
    raise FileNotFoundError(f"CPT model not found at: {CPT_MODEL_PATH}")

print(f"Loading CPT model from: {CPT_MODEL_PATH}")

# Load train dataset with labels.
print(f"Loading training CSV from: {TRAIN_FILE}")
train_ds = load_dataset("csv", data_files=TRAIN_FILE, split="train")
print(f"Train samples: {len(train_ds)} | Columns: {train_ds.column_names}")
|
| 67 |
+
|
| 68 |
+
# Prepare label mapping on train file (handles string/int/float values).
def encode_train_labels(example):
    """Map the user-facing `Class` value (1..9) to the model index (0..8).

    Accepts ints, floats, and numeric strings (including forms like "3.0"
    that plain int() rejects). Raises ValueError for anything outside 1..9.
    """
    c = example.get("Class")
    if isinstance(c, str):
        try:
            # int() already tolerates surrounding whitespace.
            c = int(c)
        except ValueError:
            # ROBUSTNESS: also accept float-formatted strings such as "3.0".
            c = int(float(c))
    if c not in LABEL2ID:
        raise ValueError(f"Unexpected class value in training data: {c}")
    example["labels"] = LABEL2ID[c]
    return example


train_ds = train_ds.map(encode_train_labels)

# Deterministic 90/10 train/validation split.
split = train_ds.train_test_split(test_size=0.1, seed=42)
train_split = split["train"]
eval_split = split["test"]
print("Train split size:", len(train_split), "Eval split size:", len(eval_split))

# Load tokenizer and model from the CPT checkpoint.
print("Loading tokenizer and model from CPT output for finetuning...")
ft_tokenizer = AutoTokenizer.from_pretrained(CPT_MODEL_PATH)

print("Loading AutoModelForSequenceClassification from CPT weights")
# BUG FIX: label2id must map label name -> model index ("1" -> 0); the
# original comprehension `{str(v): k ...}` produced the inverse ("0" -> 1).
ft_model = AutoModelForSequenceClassification.from_pretrained(
    CPT_MODEL_PATH,
    num_labels=NUM_LABELS,
    id2label={str(k): str(v) for k, v in ID2LABEL.items()},
    label2id={str(k): v for k, v in LABEL2ID.items()},
)

ft_model = ft_model.to(DEVICE)

# Print model info.
total_params = sum(p.numel() for p in ft_model.parameters())
trainable_params = sum(p.numel() for p in ft_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
|
| 110 |
+
|
| 111 |
+
# Tokenize the classification inputs to a fixed length.
def preprocess_classification(examples):
    """Tokenize a batch of `Commentaire client` strings to fixed length."""
    return ft_tokenizer(
        examples["Commentaire client"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )


train_split = train_split.map(preprocess_classification, batched=True, num_proc=4)
eval_split = eval_split.map(preprocess_classification, batched=True, num_proc=4)

# Expose only the tensors the Trainer consumes.
_TORCH_COLUMNS = ["input_ids", "attention_mask", "labels"]
train_split.set_format(type="torch", columns=_TORCH_COLUMNS)
eval_split.set_format(type="torch", columns=_TORCH_COLUMNS)
|
| 125 |
+
|
| 126 |
+
# Metrics.
def compute_metrics(eval_pred):
    """Compute accuracy plus weighted/macro/micro precision, recall and F1.

    Also emits one `f1_class_<label>` entry per class, keyed by the
    user-facing class name (1..9) from ID2LABEL.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    _, _, f1_mi, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    metrics = {
        'accuracy': acc,
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'f1_micro': f1_mi,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    }
    # BUG FIX: without `labels=`, f1_score(average=None) only returns scores
    # for classes actually present in this eval split, so positional indexing
    # could attribute a score to the wrong class. Pin the label order.
    class_ids = sorted(ID2LABEL)
    per_class_f1 = f1_score(labels, preds, average=None, labels=class_ids, zero_division=0)
    for pos, idx in enumerate(class_ids):
        metrics[f'f1_class_{ID2LABEL[idx]}'] = float(per_class_f1[pos])
    return metrics
|
| 150 |
+
|
| 151 |
+
# Training arguments for finetuning.
# TrainingArguments kwargs vary across transformers versions, so inspect the
# signature and pass only what this installation supports.
ta_sig = signature(TrainingArguments.__init__)
ta_params = set(ta_sig.parameters.keys())

ft_base_kwargs = {
    'output_dir': FT_OUTPUT_DIR,
    'num_train_epochs': 100,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 64,
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'logging_steps': 50,
    'save_total_limit': 2,
}

# Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
if 'bf16' in ta_params and torch.cuda.is_available() and hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
    ft_base_kwargs['bf16'] = True
elif 'fp16' in ta_params and torch.cuda.is_available():
    ft_base_kwargs['fp16'] = True

# BUG FIX: transformers >= 4.46 renamed `evaluation_strategy` to
# `eval_strategy`. The original only checked the old name, which silently
# disabled per-epoch evaluation and best-model loading on newer versions.
eval_key = next((k for k in ('evaluation_strategy', 'eval_strategy') if k in ta_params), None)
if eval_key is not None:
    ft_base_kwargs[eval_key] = 'epoch'
    ft_base_kwargs['save_strategy'] = 'epoch'
    ft_base_kwargs['load_best_model_at_end'] = True
    ft_base_kwargs['metric_for_best_model'] = 'f1_weighted'

# Drop anything this transformers version does not accept.
ft_filtered = {k: v for k, v in ft_base_kwargs.items() if k in ta_params}

ft_training_args = TrainingArguments(**ft_filtered)

# Trainer for finetuning.
ft_trainer = Trainer(
    model=ft_model,
    args=ft_training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=ft_tokenizer,
    compute_metrics=compute_metrics,
)
|
| 194 |
+
|
| 195 |
+
print("Starting finetuning on classification task...")
ft_trainer.train()

print("Finetuning finished. Saving finetuned model to:", FT_OUTPUT_DIR)
ft_trainer.save_model(FT_OUTPUT_DIR)
ft_tokenizer.save_pretrained(FT_OUTPUT_DIR)

# Update config with label mappings (so inference scripts can read cleanly).
config_path = os.path.join(FT_OUTPUT_DIR, 'config.json')
if os.path.exists(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
else:
    cfg = {}

cfg['id2label'] = {str(k): str(v) for k, v in ID2LABEL.items()}
# BUG FIX: label2id must map label name -> model index ("1" -> 0); the
# original comprehension `{str(v): k ...}` produced the inverse ("0" -> 1).
cfg['label2id'] = {str(k): v for k, v in LABEL2ID.items()}
cfg['num_labels'] = NUM_LABELS
cfg['problem_type'] = 'single_label_classification'

with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(cfg, f, ensure_ascii=False, indent=2)

print("Saved label mappings to finetuned model config")

print('\nAll done. Finetuning completed.')
print('Finetuned classifier saved to:', FT_OUTPUT_DIR)
|
Code/finetune_gemma3_classification.py
ADDED
|
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Gemma 3 4B - Instruction Fine-Tuning for Classification
|
| 5 |
+
|
| 6 |
+
Fine-tuning Gemma 3 4B with instruction format (QA style) for 9-class classification.
|
| 7 |
+
Uses RS-LoRA (Rank-Stabilized LoRA) to avoid overfitting.
|
| 8 |
+
|
| 9 |
+
Features:
|
| 10 |
+
- Text preprocessing (remove names, tatweel, emojis)
|
| 11 |
+
- Instruction tuning format with few-shot examples
|
| 12 |
+
- RS-LoRA for efficient training
|
| 13 |
+
- BF16 training on A100
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python finetune_gemma3_classification.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Suppress tokenizer fork warning
|
| 21 |
+
|
| 22 |
+
import re
|
| 23 |
+
import json
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from datasets import load_dataset, Dataset
|
| 27 |
+
from transformers import (
|
| 28 |
+
AutoTokenizer,
|
| 29 |
+
AutoModelForCausalLM,
|
| 30 |
+
TrainingArguments,
|
| 31 |
+
Trainer,
|
| 32 |
+
DataCollatorForSeq2Seq,
|
| 33 |
+
BitsAndBytesConfig,
|
| 34 |
+
)
|
| 35 |
+
from peft import (
|
| 36 |
+
LoraConfig,
|
| 37 |
+
get_peft_model,
|
| 38 |
+
TaskType,
|
| 39 |
+
prepare_model_for_kbit_training,
|
| 40 |
+
)
|
| 41 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
| 42 |
+
import warnings
|
| 43 |
+
warnings.filterwarnings("ignore")
|
| 44 |
+
|
| 45 |
+
# ---------------------------
|
| 46 |
+
# Paths & Config
|
| 47 |
+
# ---------------------------
|
| 48 |
+
TRAIN_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/train.csv"
|
| 49 |
+
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
|
| 50 |
+
BASE_MODEL = "google/gemma-3-4b-it"
|
| 51 |
+
|
| 52 |
+
FT_OUTPUT_DIR = "./gemma3_classification_ft"
|
| 53 |
+
|
| 54 |
+
MAX_LENGTH = 2048 # Longer for instruction format with few-shot
|
| 55 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 56 |
+
print(f"Device: {DEVICE}")
|
| 57 |
+
|
| 58 |
+
# Enable TF32 for A100
|
| 59 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 60 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 61 |
+
|
| 62 |
+
# Label mapping
|
| 63 |
+
LABEL2ID = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}
|
| 64 |
+
ID2LABEL = {v: k for k, v in LABEL2ID.items()}
|
| 65 |
+
NUM_LABELS = len(LABEL2ID)
|
| 66 |
+
|
| 67 |
+
text_col = "Commentaire client"
|
| 68 |
+
|
| 69 |
+
# ===========================================================================
|
| 70 |
+
# System Prompt and Few-Shot Examples
|
| 71 |
+
# ===========================================================================
|
| 72 |
+
SYSTEM_PROMPT = """You are an expert Algerian linguist and data labeler. Your task is to classify customer comments from Algérie Télécom's social media into one of 9 specific categories.
|
| 73 |
+
|
| 74 |
+
## CLASSES (DETAILED DESCRIPTIONS)
|
| 75 |
+
- **Class 1 (Wish/Positive Anticipation):** Comments expressing a wish, a hopeful anticipation, or general positive feedback/appreciation for future services or offers.
|
| 76 |
+
- **Class 2 (Complaint: Equipment/Supply):** Comments complaining about the lack, unavailability, or delay in the supply of necessary equipment (e.g., modems, fiber optics devices).
|
| 77 |
+
- **Class 3 (Complaint: Marketing/Advertising):** Comments criticizing advertisements, marketing campaigns, or their lack of realism/meaning.
|
| 78 |
+
- **Class 4 (Complaint: Installation/Deployment):** Comments about delays, stoppages, or failure in service installation, network expansion, or fiber optics deployment (e.g., digging issues).
|
| 79 |
+
- **Class 5 (Inquiry/Request for Information):** Comments asking for eligibility, connection dates, service status, coverage details, or specific contact information.
|
| 80 |
+
- **Class 6 (Complaint: Technical Support/Intervention):** Comments regarding delays in repair interventions, issues with technical staff competence, or unsatisfactory customer service agency visits.
|
| 81 |
+
- **Class 7 (Pricing/Service Enhancement):** Comments focused on pricing, requests for cost reduction, or suggestions for general service/app functionality enhancements.
|
| 82 |
+
- **Class 8 (Complaint: Total Service Outage/Disconnection):** Comments indicating a complete, sustained loss of service (e.g., no phone, no internet, total disconnection).
|
| 83 |
+
- **Class 9 (Complaint: Service Performance/Quality):** Comments about technical issues impacting performance (e.g., slow speed, high latency, broken website/portal, coverage claims).
|
| 84 |
+
|
| 85 |
+
Respond with ONLY the class number (1-9). Do not include any explanation."""
|
| 86 |
+
|
| 87 |
+
# Few-shot examples (2-3 per class for diversity)
|
| 88 |
+
FEW_SHOT_EXAMPLES = [
|
| 89 |
+
# Class 1
|
| 90 |
+
{"comment": "إن شاء الله يكون عرض صحاب 300 و 500 ميجا فيبر ياربي", "class": "1"},
|
| 91 |
+
{"comment": "اتمنى لكم مزيد من التألق", "class": "1"},
|
| 92 |
+
# Class 2
|
| 93 |
+
{"comment": "زعما جابو المودام ؟", "class": "2"},
|
| 94 |
+
{"comment": "وفرو أجهزة مودام الباقي ساهل !", "class": "2"},
|
| 95 |
+
# Class 3
|
| 96 |
+
{"comment": "إشهار بدون معنه", "class": "3"},
|
| 97 |
+
# Class 4
|
| 98 |
+
{"comment": "المشروع متوقف منذ اشهر", "class": "4"},
|
| 99 |
+
{"comment": "نتمنى تكملو في ايسطو وهران في اقرب وقت رانا نعانو مع ADSL", "class": "4"},
|
| 100 |
+
# Class 5
|
| 101 |
+
{"comment": "modem", "class": "5"},
|
| 102 |
+
{"comment": "يعني كي نطلعها ثلاثون ميغا كارطة تاع مائة الف قداه تحكملي؟", "class": "5"},
|
| 103 |
+
# Class 6
|
| 104 |
+
{"comment": "عرض 20 ميجا نحيوه مدام مش قادرين تعطيونا حقنا", "class": "6"},
|
| 105 |
+
# Class 7
|
| 106 |
+
{"comment": "نقصوا الاسعار بزااااف غالية", "class": "7"},
|
| 107 |
+
{"comment": "علاه ماديروش في التطبيق خاصية التوقيف المؤقت للانترانات", "class": "7"},
|
| 108 |
+
# Class 8
|
| 109 |
+
{"comment": "رانا بلا تلفون ولا انترنت", "class": "8"},
|
| 110 |
+
{"comment": "ثلاثة اشهر بلا انترنت", "class": "8"},
|
| 111 |
+
# Class 9
|
| 112 |
+
{"comment": "فضاء الزبون علاه منقدروش نسجلو فيه", "class": "9"},
|
| 113 |
+
{"comment": "هل موقع فضاء الزبون متوقف", "class": "9"},
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
# ===========================================================================
|
| 117 |
+
# Text Preprocessing
|
| 118 |
+
# ===========================================================================
|
| 119 |
+
# Compiled once at import time: recompiling this large character-class
# pattern on every call was pure overhead — preprocess_text runs inside
# dataset .map() over every train/test row.
_EMOJI_RE = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u"\U00010000-\U0010ffff"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u200d"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\ufe0f"
    u"\u3030"
    "]+", flags=re.UNICODE)


def preprocess_text(text):
    """Clean a raw social-media comment for classification.

    Steps, in order: strip URLs, e-mail addresses, phone numbers,
    @mentions, Arabic tatweel (kashida), emojis/pictographic symbols,
    the operator's own brand names, collapse any character repeated
    more than 3 times down to 3, then normalize whitespace.

    Args:
        text: Raw comment. Non-string values (None, NaN from pandas)
            yield "".

    Returns:
        The cleaned comment as a stripped, single-spaced string.
    """
    if not isinstance(text, str):
        return ""

    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove email addresses (before @mention removal below)
    text = re.sub(r'\S+@\S+', '', text)

    # Remove phone numbers: generic international, then Algerian formats
    text = re.sub(r'[\+]?[(]?[0-9]{1,4}[)]?[-\s\./0-9]{6,}', '', text)
    text = re.sub(r'\b0[567]\d{8}\b', '', text)   # Algerian mobile
    text = re.sub(r'\b0[23]\d{7,8}\b', '', text)  # Algerian landline

    # Remove mentions (@username)
    text = re.sub(r'@\w+', '', text)

    # Remove Arabic tatweel (kashida)
    text = re.sub(r'ـ+', '', text)

    # Remove emojis and other symbols (precompiled module-level pattern)
    text = _EMOJI_RE.sub('', text)

    # Remove the operator's brand names; longest variant first so the
    # combined form is stripped as one unit
    text = re.sub(r'Algérie Télécom - إتصالات الجزائر', '', text, flags=re.IGNORECASE)
    text = re.sub(r'Algérie Télécom', '', text, flags=re.IGNORECASE)
    text = re.sub(r'إتصالات الجزائر', '', text)

    # Collapse characters repeated more than 3 times to exactly 3
    text = re.sub(r'(.)\1{3,}', r'\1\1\1', text)

    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text
|
| 182 |
+
|
| 183 |
+
# ===========================================================================
|
| 184 |
+
# Format Data for Instruction Tuning
|
| 185 |
+
# ===========================================================================
|
| 186 |
+
# Hardcoded few-shot examples string (not from data)
|
| 187 |
+
FEW_SHOT_STRING = """
|
| 188 |
+
Comment: إن شاء الله يكون عرض صحاب 300 و 500 ميجا فيبر ياربي
|
| 189 |
+
Class: 1
|
| 190 |
+
|
| 191 |
+
Comment: الف مبروووك..
|
| 192 |
+
Class: 1
|
| 193 |
+
|
| 194 |
+
Comment: - إتصالات الجزائر شكرا اتمنى لكم دوام الصحة والعافية
|
| 195 |
+
Class: 1
|
| 196 |
+
|
| 197 |
+
Comment: C une fierté de faire partie de cette grande entreprise Algérienne de haute technologie et haute qualité
|
| 198 |
+
Class: 1
|
| 199 |
+
|
| 200 |
+
Comment: اتمنى لكم مزيد من التألق
|
| 201 |
+
Class: 1
|
| 202 |
+
|
| 203 |
+
Comment: زعما جابو المودام ؟
|
| 204 |
+
Class: 2
|
| 205 |
+
|
| 206 |
+
Comment: وفرو أجهزة مودام الباقي ساهل !
|
| 207 |
+
Class: 2
|
| 208 |
+
|
| 209 |
+
Comment: واش الفايدة تع العرض هذا هو اصلا لي مودام مهوش متوفر رنا قريب عام وحنا ستناو في جد موام هذا
|
| 210 |
+
Class: 2
|
| 211 |
+
|
| 212 |
+
Comment: Depuis un an et demi qu'on a installé w ma kan walou
|
| 213 |
+
Class: 2
|
| 214 |
+
|
| 215 |
+
Comment: قتلتونا بلكذب المودام غير متوفر عندي 4 أشهر ملي حطيت الطلب في ولاية خنشلة و مزال ماجابوش المودام
|
| 216 |
+
Class: 2
|
| 217 |
+
|
| 218 |
+
Comment: عندكم احساس و لا شريوه كما قالو خوتنا لمصريين
|
| 219 |
+
Class: 3
|
| 220 |
+
|
| 221 |
+
Comment: Kamel Dahmane الفايبر؟ مستحيل كامل عاجبتهم
|
| 222 |
+
Class: 3
|
| 223 |
+
|
| 224 |
+
Comment: ههههه نخلص مليون عادي كون يركبونا الفيبر 😂😂😂😂😂 كرهنا من 144p
|
| 225 |
+
Class: 3
|
| 226 |
+
|
| 227 |
+
Comment: إشهار بدون معنه
|
| 228 |
+
Class: 3
|
| 229 |
+
|
| 230 |
+
Comment: المشروع متوقف منذ اشهر
|
| 231 |
+
Class: 4
|
| 232 |
+
|
| 233 |
+
Comment: نتمنى تكملو في ايسطو وهران في اقرب وقت رانا نعانو مع ADSL
|
| 234 |
+
Class: 4
|
| 235 |
+
|
| 236 |
+
Comment: Fibre كاش واحد وصلوله الفيبر؟
|
| 237 |
+
Class: 4
|
| 238 |
+
|
| 239 |
+
Comment: ما هو الجديد وانا مزال ماعنديش الفيبر رغم الطلب ولالحاح
|
| 240 |
+
Class: 4
|
| 241 |
+
|
| 242 |
+
Comment: علبة الفيبر راكبة في الحي و لكن لا يوجد توصيل للمنزل للان
|
| 243 |
+
Class: 4
|
| 244 |
+
|
| 245 |
+
Comment: modem
|
| 246 |
+
Class: 5
|
| 247 |
+
|
| 248 |
+
Comment: يعني كي نطلعها ثلاثون ميغا كارطة تاع مائة الف قداه تحكملي؟
|
| 249 |
+
Class: 5
|
| 250 |
+
|
| 251 |
+
Comment: سآل الأماكن لي ما فيهاش الألياف البصرية إذا جابولنا الألياف السرعة تكون محدودة كيما ف ADSL؟
|
| 252 |
+
Class: 5
|
| 253 |
+
|
| 254 |
+
Comment: ماعرف كاش خبر على ايدوم 4G ماعرف تبقى قرد العش
|
| 255 |
+
Class: 5
|
| 256 |
+
|
| 257 |
+
Comment: هل متوفرة في حي عدل 1046 مسكن دويرة
|
| 258 |
+
Class: 5
|
| 259 |
+
|
| 260 |
+
Comment: عرض 20 ميجا نحيوه مدام مش قادرين تعطيونا حقنا
|
| 261 |
+
Class: 6
|
| 262 |
+
|
| 263 |
+
Comment: 4 سنوات وحنا نخلصو فالدار ماشفنا حتى bonus
|
| 264 |
+
Class: 6
|
| 265 |
+
|
| 266 |
+
Comment: لماذا التغيير في الرقم بدون تغيير سرعة التدفق هل من أجل الإشهار وفقط انا غير من 50 ميغا إلا 200 ميغا نظريا تغيرت وفي الواقع بقت قياس أقل من 50 ميغا
|
| 267 |
+
Class: 6
|
| 268 |
+
|
| 269 |
+
Comment: انا طلعت تدفق انترنات من 15 الى 20 عبر تطبيق my idoom لاكن سرعة لم تتغير
|
| 270 |
+
Class: 6
|
| 271 |
+
|
| 272 |
+
Comment: نقصوا الاسعار بزااااف غالية
|
| 273 |
+
Class: 7
|
| 274 |
+
|
| 275 |
+
Comment: علاه ماديروش في التطبيق خاصية التوقيف المؤقت للانترانات
|
| 276 |
+
Class: 7
|
| 277 |
+
|
| 278 |
+
Comment: وفرونا من بعد اي ساهلة
|
| 279 |
+
Class: 7
|
| 280 |
+
|
| 281 |
+
Comment: لازم ترجعو اتصال بتطبيقات الدفع بلا انترنت و مجاني ريقلوها يا اتصالات الجزائر
|
| 282 |
+
Class: 7
|
| 283 |
+
|
| 284 |
+
Comment: Promotion fin d'année ADSL idoom
|
| 285 |
+
Class: 7
|
| 286 |
+
|
| 287 |
+
Comment: رانا بلا تلفون ولا انترنت
|
| 288 |
+
Class: 8
|
| 289 |
+
|
| 290 |
+
Comment: ثلاثة اشهر بلا انترنت
|
| 291 |
+
Class: 8
|
| 292 |
+
|
| 293 |
+
Comment: votre site espace client ne fonctionne pas pourquoi?
|
| 294 |
+
Class: 8
|
| 295 |
+
|
| 296 |
+
Comment: ما عندنا الانترنيت ما نخلصوها من الدار
|
| 297 |
+
Class: 8
|
| 298 |
+
|
| 299 |
+
Comment: مشكل في 1.200جيق فيبر مدام نوكيا مخرج الانترنت 1جيق فقط كفاش راح تحلو هذا مشكل ومشكل ثاني فضاء الزبون ميمشيش مندو شهر
|
| 300 |
+
Class: 8
|
| 301 |
+
|
| 302 |
+
Comment: فضاء الزبون علاه منقدروش نسجلو فيه
|
| 303 |
+
Class: 9
|
| 304 |
+
|
| 305 |
+
Comment: هل موقع فضاء الزبون متوقف
|
| 306 |
+
Class: 9
|
| 307 |
+
|
| 308 |
+
Comment: ماراهيش توصل الفاتورة لا عن طريق الإيميل ولا عن طريق فضاء الزبون
|
| 309 |
+
Class: 9
|
| 310 |
+
|
| 311 |
+
Comment: فضاء الزبون قرابة 20 يوم متوقف!!!!!!؟؟؟؟؟
|
| 312 |
+
Class: 9
|
| 313 |
+
|
| 314 |
+
Comment: برج الكيفان اظنها من العاصمة خارج تغطيتكم....احشموا بركاو بلا كذب....طلعنا الصواريخ للفضاء....بصح بالكذب....
|
| 315 |
+
Class: 9"""
|
| 316 |
+
|
| 317 |
+
def format_few_shot_prompt(comment):
    """Build the classification prompt for one comment.

    Prepends the hardcoded few-shot examples block, then presents the
    target comment and an open "Class:" cue for the model to complete.
    """
    return (
        "Here are some examples of how to classify comments:\n"
        "\n"
        f"{FEW_SHOT_STRING}\n"
        "\n"
        "Now classify this comment:\n"
        f"Comment: {comment}\n"
        "Class:"
    )
|
| 329 |
+
|
| 330 |
+
def create_instruction_format(example, tokenizer, is_train=True):
    """Render one example as a Gemma chat-template string.

    For training (is_train=True) the assistant's answer — the class
    number — is appended and no generation prompt is added. For
    inference the conversation ends with a generation prompt so the
    model produces the class number itself.
    """
    cleaned = preprocess_text(example.get(text_col, ""))

    # System instructions + few-shot examples all go into the user turn
    conversation = [
        {"role": "user", "content": SYSTEM_PROMPT + "\n\n" + format_few_shot_prompt(cleaned)}
    ]

    if is_train:
        raw_label = example.get("Class", example.get("labels", 1))
        if isinstance(raw_label, str):
            raw_label = int(raw_label.strip())
        # Assistant turn is just the class number
        conversation.append({"role": "assistant", "content": str(raw_label)})

    # Training text carries the answer, so no generation prompt;
    # inference text ends with one.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=not is_train,
    )
|
| 368 |
+
|
| 369 |
+
def prepare_train_dataset(dataset, tokenizer):
    """Tokenize a dataset into instruction-tuning features.

    Loss is computed only on the assistant's response (the class number):
    label positions covering the prompt are set to -100 so the
    cross-entropy ignores them.

    NOTE(review): this assumes the tokenized prompt is an exact
    token-level prefix of the tokenized full conversation. The prompt
    variant is rendered with a generation prompt (is_train=False), which
    can append tokens the full text does not contain, making prompt_len
    overshoot and mask part of the answer — confirm with this
    tokenizer's chat template.
    """

    def process_example(example):
        # Full conversation text, answer included (the loss target lives here)
        full_text = create_instruction_format(example, tokenizer, is_train=True)

        # Prompt-only text, used solely to measure where the response starts
        prompt_text = create_instruction_format(example, tokenizer, is_train=False)

        # Tokenize both; padding is deferred to the data collator
        full_tokenized = tokenizer(
            full_text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
        )

        prompt_tokenized = tokenizer(
            prompt_text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
        )

        # Labels: -100 on prompt positions (ignored by the loss), real
        # token ids from there on
        prompt_len = len(prompt_tokenized["input_ids"])
        labels = [-100] * prompt_len + full_tokenized["input_ids"][prompt_len:]

        # Defensive re-alignment: labels must match input_ids length
        # exactly (prompt_len can exceed len(input_ids) under truncation)
        if len(labels) < len(full_tokenized["input_ids"]):
            labels = labels + full_tokenized["input_ids"][len(labels):]
        elif len(labels) > len(full_tokenized["input_ids"]):
            labels = labels[:len(full_tokenized["input_ids"])]

        full_tokenized["labels"] = labels

        return full_tokenized

    # Drop the raw columns so the collator only sees tokenized features
    return dataset.map(process_example, remove_columns=dataset.column_names)
|
| 413 |
+
|
| 414 |
+
# ===========================================================================
|
| 415 |
+
# RS-LoRA Configuration (Rank-Stabilized LoRA)
|
| 416 |
+
# ===========================================================================
|
| 417 |
+
# Adapter hyperparameters for rank-stabilized LoRA (fed to LoraConfig below).
RS_LORA_CONFIG = {
    "r": 64,               # LoRA rank
    "lora_alpha": 64,      # For RS-LoRA, alpha = r (rank-stabilized)
    "lora_dropout": 0.05,  # Dropout for regularization
    "target_modules": [    # Gemma attention/MLP projection modules
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    "use_rslora": True,    # Enable RS-LoRA scaling (alpha / sqrt(r))
}

# Fine-Tuning Config (consumed by TrainingArguments below)
FT_CONFIG = {
    "num_epochs": 3,
    "batch_size": 4,                    # per-device batch size
    "gradient_accumulation_steps": 8,   # Effective batch = 4 * 8 = 32
    "learning_rate": 2e-4,
    "weight_decay": 0.01,
    "warmup_ratio": 0.1,                # fraction of steps used for LR warmup
    "max_grad_norm": 1.0,               # gradient clipping threshold
}
|
| 438 |
+
|
| 439 |
+
# ===========================================================================
|
| 440 |
+
# Main Training
|
| 441 |
+
# ===========================================================================
|
| 442 |
+
print("\n" + "="*70)
|
| 443 |
+
print("Gemma 3 4B - Instruction Fine-Tuning for Classification")
|
| 444 |
+
print("="*70 + "\n")
|
| 445 |
+
|
| 446 |
+
# Load tokenizer
|
| 447 |
+
print(f"Loading tokenizer from: {BASE_MODEL}")
|
| 448 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
|
| 449 |
+
|
| 450 |
+
# Set padding token
|
| 451 |
+
if tokenizer.pad_token is None:
|
| 452 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 453 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 454 |
+
|
| 455 |
+
# Set padding side for causal LM
|
| 456 |
+
tokenizer.padding_side = "right"
|
| 457 |
+
|
| 458 |
+
# Load model
|
| 459 |
+
print(f"Loading model from: {BASE_MODEL}")
|
| 460 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 461 |
+
BASE_MODEL,
|
| 462 |
+
torch_dtype=torch.bfloat16,
|
| 463 |
+
trust_remote_code=True,
|
| 464 |
+
device_map="auto",
|
| 465 |
+
attn_implementation="eager", # Use eager attention for compatibility
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# Apply RS-LoRA
|
| 469 |
+
print("\nApplying RS-LoRA configuration...")
|
| 470 |
+
lora_config = LoraConfig(
|
| 471 |
+
task_type=TaskType.CAUSAL_LM,
|
| 472 |
+
r=RS_LORA_CONFIG["r"],
|
| 473 |
+
lora_alpha=RS_LORA_CONFIG["lora_alpha"],
|
| 474 |
+
lora_dropout=RS_LORA_CONFIG["lora_dropout"],
|
| 475 |
+
target_modules=RS_LORA_CONFIG["target_modules"],
|
| 476 |
+
bias="none",
|
| 477 |
+
use_rslora=RS_LORA_CONFIG["use_rslora"], # RS-LoRA for better stability
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
model = get_peft_model(model, lora_config)
|
| 481 |
+
model.print_trainable_parameters()
|
| 482 |
+
|
| 483 |
+
# Load training data
|
| 484 |
+
print(f"\nLoading training data from: {TRAIN_FILE}")
|
| 485 |
+
train_ds = load_dataset("csv", data_files=TRAIN_FILE, split="train")
|
| 486 |
+
print(f"Total training samples: {len(train_ds)}")
|
| 487 |
+
|
| 488 |
+
# Preprocess and check data
|
| 489 |
+
print("\nPreprocessing text data...")
|
| 490 |
+
def preprocess_dataset(example):
    """Attach a cleaned copy of the comment under the 'clean_text' key."""
    raw_comment = example.get(text_col, "")
    example["clean_text"] = preprocess_text(raw_comment)
    return example
|
| 493 |
+
|
| 494 |
+
train_ds = train_ds.map(preprocess_dataset)
|
| 495 |
+
|
| 496 |
+
# Show preprocessing examples
|
| 497 |
+
print("\nPreprocessing examples:")
|
| 498 |
+
for i in range(min(3, len(train_ds))):
|
| 499 |
+
original = train_ds[i].get(text_col, "")[:80]
|
| 500 |
+
cleaned = train_ds[i].get("clean_text", "")[:80]
|
| 501 |
+
print(f" Original: {original}...")
|
| 502 |
+
print(f" Cleaned: {cleaned}...")
|
| 503 |
+
print()
|
| 504 |
+
|
| 505 |
+
# Train/val split
|
| 506 |
+
split = train_ds.train_test_split(test_size=0.01, seed=42)
|
| 507 |
+
train_split = split["train"]
|
| 508 |
+
eval_split = split["test"]
|
| 509 |
+
print(f"Train split: {len(train_split)} | Eval split: {len(eval_split)}")
|
| 510 |
+
|
| 511 |
+
# Prepare datasets
|
| 512 |
+
print("\nPreparing instruction-formatted datasets...")
|
| 513 |
+
train_dataset = prepare_train_dataset(train_split, tokenizer)
|
| 514 |
+
eval_dataset = prepare_train_dataset(eval_split, tokenizer)
|
| 515 |
+
|
| 516 |
+
# Show example formatted input
|
| 517 |
+
print("\nExample formatted input (truncated):")
|
| 518 |
+
example_text = create_instruction_format(train_split[0], tokenizer, is_train=True)
|
| 519 |
+
print(example_text[:500] + "..." if len(example_text) > 500 else example_text)
|
| 520 |
+
|
| 521 |
+
# Data collator
|
| 522 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 523 |
+
tokenizer=tokenizer,
|
| 524 |
+
padding=True,
|
| 525 |
+
return_tensors="pt",
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# Training arguments
|
| 529 |
+
print("\n--- Fine-Tuning Hyperparameters ---")
|
| 530 |
+
for k, v in FT_CONFIG.items():
|
| 531 |
+
print(f" {k}: {v}")
|
| 532 |
+
print(f"\n--- RS-LoRA Configuration ---")
|
| 533 |
+
print(f" rank: {RS_LORA_CONFIG['r']}")
|
| 534 |
+
print(f" alpha: {RS_LORA_CONFIG['lora_alpha']}")
|
| 535 |
+
print(f" dropout: {RS_LORA_CONFIG['lora_dropout']}")
|
| 536 |
+
print(f" use_rslora: {RS_LORA_CONFIG['use_rslora']}")
|
| 537 |
+
|
| 538 |
+
training_args = TrainingArguments(
|
| 539 |
+
output_dir=FT_OUTPUT_DIR,
|
| 540 |
+
num_train_epochs=FT_CONFIG["num_epochs"],
|
| 541 |
+
per_device_train_batch_size=FT_CONFIG["batch_size"],
|
| 542 |
+
per_device_eval_batch_size=FT_CONFIG["batch_size"],
|
| 543 |
+
gradient_accumulation_steps=FT_CONFIG["gradient_accumulation_steps"],
|
| 544 |
+
learning_rate=FT_CONFIG["learning_rate"],
|
| 545 |
+
weight_decay=FT_CONFIG["weight_decay"],
|
| 546 |
+
warmup_ratio=FT_CONFIG["warmup_ratio"],
|
| 547 |
+
max_grad_norm=FT_CONFIG["max_grad_norm"],
|
| 548 |
+
bf16=True,
|
| 549 |
+
logging_steps=10,
|
| 550 |
+
eval_strategy="epoch",
|
| 551 |
+
save_strategy="epoch",
|
| 552 |
+
save_total_limit=2,
|
| 553 |
+
load_best_model_at_end=True,
|
| 554 |
+
metric_for_best_model="eval_loss",
|
| 555 |
+
greater_is_better=False,
|
| 556 |
+
dataloader_num_workers=4,
|
| 557 |
+
report_to="none",
|
| 558 |
+
gradient_checkpointing=True,
|
| 559 |
+
gradient_checkpointing_kwargs={"use_reentrant": False},
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
# Trainer
|
| 563 |
+
trainer = Trainer(
|
| 564 |
+
model=model,
|
| 565 |
+
args=training_args,
|
| 566 |
+
train_dataset=train_dataset,
|
| 567 |
+
eval_dataset=eval_dataset,
|
| 568 |
+
tokenizer=tokenizer,
|
| 569 |
+
data_collator=data_collator,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
print("\nStarting fine-tuning...")
|
| 573 |
+
trainer.train()
|
| 574 |
+
|
| 575 |
+
print(f"\nSaving model to: {FT_OUTPUT_DIR}")
|
| 576 |
+
trainer.save_model(FT_OUTPUT_DIR)
|
| 577 |
+
tokenizer.save_pretrained(FT_OUTPUT_DIR)
|
| 578 |
+
|
| 579 |
+
# Save config
|
| 580 |
+
config = {
|
| 581 |
+
"base_model": BASE_MODEL,
|
| 582 |
+
"num_labels": NUM_LABELS,
|
| 583 |
+
"id2label": ID2LABEL,
|
| 584 |
+
"label2id": LABEL2ID,
|
| 585 |
+
"rs_lora_config": RS_LORA_CONFIG,
|
| 586 |
+
"ft_config": FT_CONFIG,
|
| 587 |
+
}
|
| 588 |
+
with open(os.path.join(FT_OUTPUT_DIR, "training_config.json"), "w") as f:
|
| 589 |
+
json.dump(config, f, indent=2)
|
| 590 |
+
|
| 591 |
+
# ===========================================================================
|
| 592 |
+
# Inference on Test Set
|
| 593 |
+
# ===========================================================================
|
| 594 |
+
print("\n" + "="*70)
|
| 595 |
+
print("Inference on Test Set")
|
| 596 |
+
print("="*70 + "\n")
|
| 597 |
+
|
| 598 |
+
# Load test data
|
| 599 |
+
test_ds = load_dataset("csv", data_files=TEST_FILE, split="train")
|
| 600 |
+
print(f"Test samples: {len(test_ds)}")
|
| 601 |
+
|
| 602 |
+
# Preprocess test data
|
| 603 |
+
test_ds = test_ds.map(preprocess_dataset)
|
| 604 |
+
|
| 605 |
+
# Run inference
|
| 606 |
+
print("Running inference...")
|
| 607 |
+
model.eval()
|
| 608 |
+
|
| 609 |
+
# Collected integer class predictions (1-9), one per test row
all_preds = []
batch_size = 1  # NOTE(review): unused — the loop below generates one example per step

from tqdm import tqdm

for i in tqdm(range(len(test_ds)), desc="Predicting"):
    example = test_ds[i]

    # Build the prompt without the answer (ends with a generation prompt)
    prompt = create_instruction_format(example, tokenizer, is_train=False)

    # Tokenize and move tensors onto the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Greedy decoding; 5 new tokens is plenty for a single class digit
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt)
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Parse the predicted class: first standalone digit 1-9, else default 1
    # NOTE(review): the bare except silently maps any parse failure to
    # class 1 — consider logging these cases
    try:
        # Try to extract first number found
        match = re.search(r'\b([1-9])\b', generated_text)
        if match:
            pred_class = int(match.group(1))
        else:
            pred_class = 1  # Default
    except:
        pred_class = 1  # Default

    all_preds.append(pred_class)
|
| 650 |
+
|
| 651 |
+
# Save predictions
|
| 652 |
+
import pandas as pd
|
| 653 |
+
test_df = pd.read_csv(TEST_FILE)
|
| 654 |
+
test_df["Predicted_Class"] = all_preds
|
| 655 |
+
|
| 656 |
+
output_file = "test_predictions_gemma3.csv"
|
| 657 |
+
test_df.to_csv(output_file, index=False)
|
| 658 |
+
print(f"\nPredictions saved to: {output_file}")
|
| 659 |
+
|
| 660 |
+
# Show sample predictions
|
| 661 |
+
print("\nSample predictions:")
|
| 662 |
+
for i in range(min(10, len(test_df))):
|
| 663 |
+
text = str(test_df.iloc[i][text_col])[:60] + "..." if len(str(test_df.iloc[i][text_col])) > 60 else str(test_df.iloc[i][text_col])
|
| 664 |
+
pred = test_df.iloc[i]["Predicted_Class"]
|
| 665 |
+
print(f" [{i+1}] Class {pred}: {text}")
|
| 666 |
+
|
| 667 |
+
# Class distribution
|
| 668 |
+
print("\nPrediction distribution:")
|
| 669 |
+
pred_counts = test_df["Predicted_Class"].value_counts().sort_index()
|
| 670 |
+
for class_label, count in pred_counts.items():
|
| 671 |
+
print(f" Class {class_label}: {count} samples ({count/len(test_df)*100:.1f}%)")
|
| 672 |
+
|
| 673 |
+
# ===========================================================================
|
| 674 |
+
# Summary
|
| 675 |
+
# ===========================================================================
|
| 676 |
+
print("\n" + "="*70)
|
| 677 |
+
print("TRAINING COMPLETE!")
|
| 678 |
+
print("="*70)
|
| 679 |
+
print(f"\nBase Model: {BASE_MODEL}")
|
| 680 |
+
print(f"Fine-tuned model saved to: {FT_OUTPUT_DIR}")
|
| 681 |
+
print(f"Predictions saved to: {output_file}")
|
| 682 |
+
print(f"\nTraining samples: {len(train_split)}")
|
| 683 |
+
print(f"Validation samples: {len(eval_split)}")
|
| 684 |
+
print(f"Test samples: {len(test_df)}")
|
| 685 |
+
print(f"RS-LoRA rank: {RS_LORA_CONFIG['r']}")
|
| 686 |
+
print(f"Use RS-LoRA: {RS_LORA_CONFIG['use_rslora']}")
|
Code/inference_camelbert.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Inference script for CPT+finetuned MARBERTv2 telecom classification model.
|
| 6 |
+
|
| 7 |
+
Loads the model from ./telecom_marbertv2_cpt_ft and runs predictions on test.csv
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForSequenceClassification,
|
| 18 |
+
AutoConfig,
|
| 19 |
+
)
|
| 20 |
+
from sklearn.metrics import (
|
| 21 |
+
accuracy_score,
|
| 22 |
+
f1_score,
|
| 23 |
+
precision_recall_fscore_support,
|
| 24 |
+
classification_report,
|
| 25 |
+
confusion_matrix,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# -------------------------------------------------------------------
|
| 29 |
+
# 1. Paths & config
|
| 30 |
+
# -------------------------------------------------------------------
|
| 31 |
+
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
|
| 32 |
+
MODEL_DIR = "./telecom_camelbert_cpt_ft"
|
| 33 |
+
OUTPUT_FILE = "./test_predictions_camelbert_cpt_ft.csv"
|
| 34 |
+
|
| 35 |
+
MAX_LENGTH = 256
|
| 36 |
+
BATCH_SIZE = 64
|
| 37 |
+
|
| 38 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
print(f"Using device: {device}")
|
| 40 |
+
|
| 41 |
+
# -------------------------------------------------------------------
|
| 42 |
+
# 2. Load test data
|
| 43 |
+
# -------------------------------------------------------------------
|
| 44 |
+
print(f"Loading test data from: {TEST_FILE}")
|
| 45 |
+
test_df = pd.read_csv(TEST_FILE)
|
| 46 |
+
print(f"Test samples: {len(test_df)}")
|
| 47 |
+
print(f"Columns: {test_df.columns.tolist()}")
|
| 48 |
+
|
| 49 |
+
# Check if test data has labels
|
| 50 |
+
has_labels = "Class" in test_df.columns
|
| 51 |
+
if has_labels:
|
| 52 |
+
print("Test data contains labels - will compute metrics")
|
| 53 |
+
else:
|
| 54 |
+
print("Test data has no labels - will only generate predictions")
|
| 55 |
+
|
| 56 |
+
# -------------------------------------------------------------------
|
| 57 |
+
# 3. Load model and tokenizer
|
| 58 |
+
# -------------------------------------------------------------------
|
| 59 |
+
print(f"\nLoading model from: {MODEL_DIR}")
|
| 60 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
| 61 |
+
|
| 62 |
+
# Load config to get label mappings
|
| 63 |
+
config_path = os.path.join(MODEL_DIR, "config.json")
|
| 64 |
+
if os.path.exists(config_path):
|
| 65 |
+
import json
|
| 66 |
+
with open(config_path, 'r') as f:
|
| 67 |
+
config_data = json.load(f)
|
| 68 |
+
|
| 69 |
+
if 'id2label' in config_data:
|
| 70 |
+
id2label = {int(k): int(v) for k, v in config_data['id2label'].items()}
|
| 71 |
+
# Create label2id with both string and int keys for robustness
|
| 72 |
+
label2id = {}
|
| 73 |
+
for k, v in id2label.items():
|
| 74 |
+
label2id[v] = k # int key -> int value
|
| 75 |
+
label2id[str(v)] = k # string key -> int value
|
| 76 |
+
num_labels = len(id2label)
|
| 77 |
+
else:
|
| 78 |
+
# Fallback: infer from test data if available
|
| 79 |
+
if has_labels:
|
| 80 |
+
unique_classes = sorted(test_df["Class"].unique())
|
| 81 |
+
label2id = {label: idx for idx, label in enumerate(unique_classes)}
|
| 82 |
+
id2label = {idx: label for label, idx in label2id.items()}
|
| 83 |
+
num_labels = len(unique_classes)
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError("Cannot determine number of labels without config or test labels")
|
| 86 |
+
else:
|
| 87 |
+
# Fallback: infer from test data if available
|
| 88 |
+
if has_labels:
|
| 89 |
+
unique_classes = sorted(test_df["Class"].unique())
|
| 90 |
+
label2id = {label: idx for idx, label in enumerate(unique_classes)}
|
| 91 |
+
id2label = {idx: label for label, idx in label2id.items()}
|
| 92 |
+
num_labels = len(unique_classes)
|
| 93 |
+
else:
|
| 94 |
+
raise ValueError("Cannot find config.json and test data has no labels")
|
| 95 |
+
|
| 96 |
+
print(f"Number of classes: {num_labels}")
|
| 97 |
+
print(f"Label mapping: {id2label}")
|
| 98 |
+
|
| 99 |
+
# Load model
|
| 100 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
| 101 |
+
model = model.to(device)
|
| 102 |
+
model.eval()
|
| 103 |
+
print("Model loaded successfully!")
|
| 104 |
+
|
| 105 |
+
# -------------------------------------------------------------------
|
| 106 |
+
# 4. Run inference
|
| 107 |
+
# -------------------------------------------------------------------
|
| 108 |
+
print("\nRunning inference...")
|
| 109 |
+
|
| 110 |
+
all_predictions = []
|
| 111 |
+
all_probabilities = []
|
| 112 |
+
|
| 113 |
+
# Process in batches for efficiency
|
| 114 |
+
for i in range(0, len(test_df), BATCH_SIZE):
|
| 115 |
+
batch_texts = test_df["Commentaire client"].iloc[i:i+BATCH_SIZE].tolist()
|
| 116 |
+
|
| 117 |
+
# Tokenize
|
| 118 |
+
inputs = tokenizer(
|
| 119 |
+
batch_texts,
|
| 120 |
+
padding=True,
|
| 121 |
+
truncation=True,
|
| 122 |
+
max_length=MAX_LENGTH,
|
| 123 |
+
return_tensors="pt",
|
| 124 |
+
).to(device)
|
| 125 |
+
|
| 126 |
+
# Predict
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
outputs = model(**inputs)
|
| 129 |
+
logits = outputs.logits
|
| 130 |
+
probs = torch.softmax(logits, dim=-1)
|
| 131 |
+
predictions = torch.argmax(logits, dim=-1)
|
| 132 |
+
|
| 133 |
+
all_predictions.extend(predictions.cpu().numpy())
|
| 134 |
+
all_probabilities.extend(probs.cpu().numpy())
|
| 135 |
+
|
| 136 |
+
if (i // BATCH_SIZE + 1) % 10 == 0:
|
| 137 |
+
print(f"Processed {i + len(batch_texts)}/{len(test_df)} samples...")
|
| 138 |
+
|
| 139 |
+
print(f"Inference complete! Processed {len(all_predictions)} samples")
|
| 140 |
+
|
| 141 |
+
# -------------------------------------------------------------------
|
| 142 |
+
# 5. Save predictions
|
| 143 |
+
# -------------------------------------------------------------------
|
| 144 |
+
# Convert predictions to class names (1-9)
|
| 145 |
+
predicted_classes = [id2label[pred] for pred in all_predictions]
|
| 146 |
+
|
| 147 |
+
# Add predictions to dataframe
|
| 148 |
+
test_df["Predicted_Class"] = predicted_classes
|
| 149 |
+
test_df["Predicted_Label_ID"] = all_predictions
|
| 150 |
+
|
| 151 |
+
# Add probability for each class
|
| 152 |
+
for idx, class_name in id2label.items():
|
| 153 |
+
test_df[f"Prob_{class_name}"] = [probs[idx] for probs in all_probabilities]
|
| 154 |
+
|
| 155 |
+
# Add confidence (max probability)
|
| 156 |
+
test_df["Confidence"] = [max(probs) for probs in all_probabilities]
|
| 157 |
+
|
| 158 |
+
# Save results
|
| 159 |
+
test_df.to_csv(OUTPUT_FILE, index=False)
|
| 160 |
+
print(f"\nPredictions saved to: {OUTPUT_FILE}")
|
| 161 |
+
|
| 162 |
+
# -------------------------------------------------------------------
|
| 163 |
+
# 6. Compute metrics (if labels available)
|
| 164 |
+
# -------------------------------------------------------------------
|
| 165 |
+
if has_labels:
|
| 166 |
+
print("\n" + "="*80)
|
| 167 |
+
print("EVALUATION METRICS")
|
| 168 |
+
print("="*80)
|
| 169 |
+
|
| 170 |
+
# Convert true labels to indices
|
| 171 |
+
true_labels = test_df["Class"].map(label2id).values
|
| 172 |
+
pred_labels = np.array(all_predictions)
|
| 173 |
+
|
| 174 |
+
# Overall metrics
|
| 175 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
| 176 |
+
print(f"\nAccuracy: {accuracy:.4f}")
|
| 177 |
+
|
| 178 |
+
# Weighted metrics (accounts for class imbalance)
|
| 179 |
+
precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
|
| 180 |
+
true_labels, pred_labels, average='weighted', zero_division=0
|
| 181 |
+
)
|
| 182 |
+
print(f"\nWeighted Metrics:")
|
| 183 |
+
print(f" Precision: {precision_w:.4f}")
|
| 184 |
+
print(f" Recall: {recall_w:.4f}")
|
| 185 |
+
print(f" F1 Score: {f1_w:.4f}")
|
| 186 |
+
|
| 187 |
+
# Macro metrics (treats all classes equally)
|
| 188 |
+
precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
|
| 189 |
+
true_labels, pred_labels, average='macro', zero_division=0
|
| 190 |
+
)
|
| 191 |
+
print(f"\nMacro Metrics:")
|
| 192 |
+
print(f" Precision: {precision_m:.4f}")
|
| 193 |
+
print(f" Recall: {recall_m:.4f}")
|
| 194 |
+
print(f" F1 Score: {f1_m:.4f}")
|
| 195 |
+
|
| 196 |
+
# Per-class metrics
|
| 197 |
+
print(f"\nPer-Class Metrics:")
|
| 198 |
+
per_class_f1 = f1_score(true_labels, pred_labels, average=None, zero_division=0)
|
| 199 |
+
per_class_precision, per_class_recall, _, support = precision_recall_fscore_support(
|
| 200 |
+
true_labels, pred_labels, average=None, zero_division=0
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
for idx in range(num_labels):
|
| 204 |
+
class_name = id2label[idx]
|
| 205 |
+
print(f"\n Class {class_name}:")
|
| 206 |
+
print(f" Precision: {per_class_precision[idx]:.4f}")
|
| 207 |
+
print(f" Recall: {per_class_recall[idx]:.4f}")
|
| 208 |
+
print(f" F1 Score: {per_class_f1[idx]:.4f}")
|
| 209 |
+
print(f" Support: {int(support[idx])}")
|
| 210 |
+
|
| 211 |
+
# Classification report
|
| 212 |
+
print("\n" + "="*80)
|
| 213 |
+
print("DETAILED CLASSIFICATION REPORT")
|
| 214 |
+
print("="*80)
|
| 215 |
+
target_names = [str(id2label[i]) for i in range(num_labels)]
|
| 216 |
+
print(classification_report(true_labels, pred_labels, target_names=target_names, zero_division=0))
|
| 217 |
+
|
| 218 |
+
# Confusion matrix
|
| 219 |
+
print("\n" + "="*80)
|
| 220 |
+
print("CONFUSION MATRIX")
|
| 221 |
+
print("="*80)
|
| 222 |
+
cm = confusion_matrix(true_labels, pred_labels)
|
| 223 |
+
|
| 224 |
+
# Print confusion matrix with labels
|
| 225 |
+
print("\nTrue \\ Predicted", end="")
|
| 226 |
+
for i in range(num_labels):
|
| 227 |
+
print(f"\t{id2label[i]}", end="")
|
| 228 |
+
print()
|
| 229 |
+
|
| 230 |
+
for i in range(num_labels):
|
| 231 |
+
print(f"{id2label[i]:<15}", end="")
|
| 232 |
+
for j in range(num_labels):
|
| 233 |
+
print(f"\t{cm[i][j]}", end="")
|
| 234 |
+
print()
|
| 235 |
+
|
| 236 |
+
# Save confusion matrix to CSV
|
| 237 |
+
cm_df = pd.DataFrame(
|
| 238 |
+
cm,
|
| 239 |
+
index=[str(id2label[i]) for i in range(num_labels)],
|
| 240 |
+
columns=[str(id2label[i]) for i in range(num_labels)]
|
| 241 |
+
)
|
| 242 |
+
cm_df.to_csv("./confusion_matrix_marbertv2_cpt_ft.csv")
|
| 243 |
+
print("\nConfusion matrix saved to: ./confusion_matrix_marbertv2_cpt_ft.csv")
|
| 244 |
+
|
| 245 |
+
# -------------------------------------------------------------------
|
| 246 |
+
# 7. Show sample predictions
|
| 247 |
+
# -------------------------------------------------------------------
|
| 248 |
+
print("\n" + "="*80)
|
| 249 |
+
print("SAMPLE PREDICTIONS (CPT+Finetuned MARBERTv2)")
|
| 250 |
+
print("="*80)
|
| 251 |
+
|
| 252 |
+
# Show first 5 predictions
|
| 253 |
+
num_samples = min(5, len(test_df))
|
| 254 |
+
for i in range(num_samples):
|
| 255 |
+
print(f"\nSample {i+1}:")
|
| 256 |
+
print(f"Text: {test_df['Commentaire client'].iloc[i]}")
|
| 257 |
+
if has_labels:
|
| 258 |
+
print(f"True Class: {test_df['Class'].iloc[i]}")
|
| 259 |
+
print(f"Predicted Class: {predicted_classes[i]}")
|
| 260 |
+
print(f"Confidence: {test_df['Confidence'].iloc[i]:.4f}")
|
| 261 |
+
print(f"Probabilities:")
|
| 262 |
+
for idx, class_name in id2label.items():
|
| 263 |
+
print(f" Class {class_name}: {all_probabilities[i][idx]:.4f}")
|
| 264 |
+
|
| 265 |
+
print("\n" + "="*80)
|
| 266 |
+
print("Inference completed successfully!")
|
| 267 |
+
print(f"Model used: CPT+Finetuned MARBERTv2 from {MODEL_DIR}")
|
| 268 |
+
print("="*80)
|
Code/inference_dziribert.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Inference script for trained DziriBERT telecom classification model.
|
| 6 |
+
|
| 7 |
+
Loads the trained model and runs predictions on test.csv
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForSequenceClassification,
|
| 18 |
+
AutoConfig,
|
| 19 |
+
)
|
| 20 |
+
from sklearn.metrics import (
|
| 21 |
+
accuracy_score,
|
| 22 |
+
f1_score,
|
| 23 |
+
precision_recall_fscore_support,
|
| 24 |
+
classification_report,
|
| 25 |
+
confusion_matrix,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# -------------------------------------------------------------------
|
| 29 |
+
# 1. Paths & config
|
| 30 |
+
# -------------------------------------------------------------------
|
| 31 |
+
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
|
| 32 |
+
MODEL_DIR = "./telecom_dziribert_final"
|
| 33 |
+
OUTPUT_FILE = "./test_predictions_dziribert.csv"
|
| 34 |
+
|
| 35 |
+
MAX_LENGTH = 512
|
| 36 |
+
BATCH_SIZE = 64
|
| 37 |
+
|
| 38 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
print(f"Using device: {device}")
|
| 40 |
+
|
| 41 |
+
# -------------------------------------------------------------------
|
| 42 |
+
# 2. Load test data
|
| 43 |
+
# -------------------------------------------------------------------
|
| 44 |
+
print(f"Loading test data from: {TEST_FILE}")
|
| 45 |
+
test_df = pd.read_csv(TEST_FILE)
|
| 46 |
+
print(f"Test samples: {len(test_df)}")
|
| 47 |
+
print(f"Columns: {test_df.columns.tolist()}")
|
| 48 |
+
|
| 49 |
+
# Check if test data has labels
|
| 50 |
+
has_labels = "Class" in test_df.columns
|
| 51 |
+
if has_labels:
|
| 52 |
+
print("Test data contains labels - will compute metrics")
|
| 53 |
+
else:
|
| 54 |
+
print("Test data has no labels - will only generate predictions")
|
| 55 |
+
|
| 56 |
+
# -------------------------------------------------------------------
|
| 57 |
+
# 3. Load model and tokenizer
|
| 58 |
+
# -------------------------------------------------------------------
|
| 59 |
+
print(f"\nLoading model from: {MODEL_DIR}")
|
| 60 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
| 61 |
+
|
| 62 |
+
# Load config to get label mappings
|
| 63 |
+
config_path = os.path.join(MODEL_DIR, "config.json")
|
| 64 |
+
if os.path.exists(config_path):
|
| 65 |
+
import json
|
| 66 |
+
with open(config_path, 'r') as f:
|
| 67 |
+
config_data = json.load(f)
|
| 68 |
+
|
| 69 |
+
if 'id2label' in config_data:
|
| 70 |
+
id2label = {int(k): int(v) for k, v in config_data['id2label'].items()}
|
| 71 |
+
# Create label2id with both string and int keys for robustness
|
| 72 |
+
label2id = {}
|
| 73 |
+
for k, v in id2label.items():
|
| 74 |
+
label2id[v] = k # int key -> int value
|
| 75 |
+
label2id[str(v)] = k # string key -> int value
|
| 76 |
+
num_labels = len(id2label)
|
| 77 |
+
else:
|
| 78 |
+
# Fallback: infer from test data if available
|
| 79 |
+
if has_labels:
|
| 80 |
+
unique_classes = sorted(test_df["Class"].unique())
|
| 81 |
+
label2id = {label: idx for idx, label in enumerate(unique_classes)}
|
| 82 |
+
id2label = {idx: label for label, idx in label2id.items()}
|
| 83 |
+
num_labels = len(unique_classes)
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError("Cannot determine number of labels without config or test labels")
|
| 86 |
+
else:
|
| 87 |
+
# Fallback: infer from test data if available
|
| 88 |
+
if has_labels:
|
| 89 |
+
unique_classes = sorted(test_df["Class"].unique())
|
| 90 |
+
label2id = {label: idx for idx, label in enumerate(unique_classes)}
|
| 91 |
+
id2label = {idx: label for label, idx in label2id.items()}
|
| 92 |
+
num_labels = len(unique_classes)
|
| 93 |
+
else:
|
| 94 |
+
raise ValueError("Cannot find config.json and test data has no labels")
|
| 95 |
+
|
| 96 |
+
print(f"Number of classes: {num_labels}")
|
| 97 |
+
print(f"Label mapping: {id2label}")
|
| 98 |
+
|
| 99 |
+
# Load model
|
| 100 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
| 101 |
+
model = model.to(device)
|
| 102 |
+
model.eval()
|
| 103 |
+
print("Model loaded successfully!")
|
| 104 |
+
|
| 105 |
+
# -------------------------------------------------------------------
|
| 106 |
+
# 4. Run inference
|
| 107 |
+
# -------------------------------------------------------------------
|
| 108 |
+
print("\nRunning inference...")
|
| 109 |
+
|
| 110 |
+
all_predictions = []
|
| 111 |
+
all_probabilities = []
|
| 112 |
+
|
| 113 |
+
# Process in batches for efficiency
|
| 114 |
+
for i in range(0, len(test_df), BATCH_SIZE):
|
| 115 |
+
batch_texts = test_df["Commentaire client"].iloc[i:i+BATCH_SIZE].tolist()
|
| 116 |
+
|
| 117 |
+
# Tokenize
|
| 118 |
+
inputs = tokenizer(
|
| 119 |
+
batch_texts,
|
| 120 |
+
padding=True,
|
| 121 |
+
truncation=True,
|
| 122 |
+
max_length=MAX_LENGTH,
|
| 123 |
+
return_tensors="pt",
|
| 124 |
+
).to(device)
|
| 125 |
+
|
| 126 |
+
# Predict
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
outputs = model(**inputs)
|
| 129 |
+
logits = outputs.logits
|
| 130 |
+
probs = torch.softmax(logits, dim=-1)
|
| 131 |
+
predictions = torch.argmax(logits, dim=-1)
|
| 132 |
+
|
| 133 |
+
all_predictions.extend(predictions.cpu().numpy())
|
| 134 |
+
all_probabilities.extend(probs.cpu().numpy())
|
| 135 |
+
|
| 136 |
+
if (i // BATCH_SIZE + 1) % 10 == 0:
|
| 137 |
+
print(f"Processed {i + len(batch_texts)}/{len(test_df)} samples...")
|
| 138 |
+
|
| 139 |
+
print(f"Inference complete! Processed {len(all_predictions)} samples")
|
| 140 |
+
|
| 141 |
+
# -------------------------------------------------------------------
|
| 142 |
+
# 5. Save predictions
|
| 143 |
+
# -------------------------------------------------------------------
|
| 144 |
+
# Convert predictions to class names (1-9)
|
| 145 |
+
predicted_classes = [id2label[pred] for pred in all_predictions]
|
| 146 |
+
|
| 147 |
+
# Add predictions to dataframe
|
| 148 |
+
test_df["Predicted_Class"] = predicted_classes
|
| 149 |
+
test_df["Predicted_Label_ID"] = all_predictions
|
| 150 |
+
|
| 151 |
+
# Add probability for each class
|
| 152 |
+
for idx, class_name in id2label.items():
|
| 153 |
+
test_df[f"Prob_{class_name}"] = [probs[idx] for probs in all_probabilities]
|
| 154 |
+
|
| 155 |
+
# Add confidence (max probability)
|
| 156 |
+
test_df["Confidence"] = [max(probs) for probs in all_probabilities]
|
| 157 |
+
|
| 158 |
+
# Save results
|
| 159 |
+
test_df.to_csv(OUTPUT_FILE, index=False)
|
| 160 |
+
print(f"\nPredictions saved to: {OUTPUT_FILE}")
|
| 161 |
+
|
| 162 |
+
# -------------------------------------------------------------------
|
| 163 |
+
# 6. Compute metrics (if labels available)
|
| 164 |
+
# -------------------------------------------------------------------
|
| 165 |
+
if has_labels:
|
| 166 |
+
print("\n" + "="*80)
|
| 167 |
+
print("EVALUATION METRICS")
|
| 168 |
+
print("="*80)
|
| 169 |
+
|
| 170 |
+
# Convert true labels to indices
|
| 171 |
+
true_labels = test_df["Class"].map(label2id).values
|
| 172 |
+
pred_labels = np.array(all_predictions)
|
| 173 |
+
|
| 174 |
+
# Overall metrics
|
| 175 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
| 176 |
+
print(f"\nAccuracy: {accuracy:.4f}")
|
| 177 |
+
|
| 178 |
+
# Weighted metrics (accounts for class imbalance)
|
| 179 |
+
precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
|
| 180 |
+
true_labels, pred_labels, average='weighted', zero_division=0
|
| 181 |
+
)
|
| 182 |
+
print(f"\nWeighted Metrics:")
|
| 183 |
+
print(f" Precision: {precision_w:.4f}")
|
| 184 |
+
print(f" Recall: {recall_w:.4f}")
|
| 185 |
+
print(f" F1 Score: {f1_w:.4f}")
|
| 186 |
+
|
| 187 |
+
# Macro metrics (treats all classes equally)
|
| 188 |
+
precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
|
| 189 |
+
true_labels, pred_labels, average='macro', zero_division=0
|
| 190 |
+
)
|
| 191 |
+
print(f"\nMacro Metrics:")
|
| 192 |
+
print(f" Precision: {precision_m:.4f}")
|
| 193 |
+
print(f" Recall: {recall_m:.4f}")
|
| 194 |
+
print(f" F1 Score: {f1_m:.4f}")
|
| 195 |
+
|
| 196 |
+
# Per-class metrics
|
| 197 |
+
print(f"\nPer-Class Metrics:")
|
| 198 |
+
per_class_f1 = f1_score(true_labels, pred_labels, average=None, zero_division=0)
|
| 199 |
+
per_class_precision, per_class_recall, _, support = precision_recall_fscore_support(
|
| 200 |
+
true_labels, pred_labels, average=None, zero_division=0
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
for idx in range(num_labels):
|
| 204 |
+
class_name = id2label[idx]
|
| 205 |
+
print(f"\n Class {class_name}:")
|
| 206 |
+
print(f" Precision: {per_class_precision[idx]:.4f}")
|
| 207 |
+
print(f" Recall: {per_class_recall[idx]:.4f}")
|
| 208 |
+
print(f" F1 Score: {per_class_f1[idx]:.4f}")
|
| 209 |
+
print(f" Support: {int(support[idx])}")
|
| 210 |
+
|
| 211 |
+
# Classification report
|
| 212 |
+
print("\n" + "="*80)
|
| 213 |
+
print("DETAILED CLASSIFICATION REPORT")
|
| 214 |
+
print("="*80)
|
| 215 |
+
target_names = [str(id2label[i]) for i in range(num_labels)]
|
| 216 |
+
print(classification_report(true_labels, pred_labels, target_names=target_names, zero_division=0))
|
| 217 |
+
|
| 218 |
+
# Confusion matrix
|
| 219 |
+
print("\n" + "="*80)
|
| 220 |
+
print("CONFUSION MATRIX")
|
| 221 |
+
print("="*80)
|
| 222 |
+
cm = confusion_matrix(true_labels, pred_labels)
|
| 223 |
+
|
| 224 |
+
# Print confusion matrix with labels
|
| 225 |
+
print("\nTrue \\ Predicted", end="")
|
| 226 |
+
for i in range(num_labels):
|
| 227 |
+
print(f"\t{id2label[i]}", end="")
|
| 228 |
+
print()
|
| 229 |
+
|
| 230 |
+
for i in range(num_labels):
|
| 231 |
+
print(f"{id2label[i]:<15}", end="")
|
| 232 |
+
for j in range(num_labels):
|
| 233 |
+
print(f"\t{cm[i][j]}", end="")
|
| 234 |
+
print()
|
| 235 |
+
|
| 236 |
+
# Save confusion matrix to CSV
|
| 237 |
+
cm_df = pd.DataFrame(
|
| 238 |
+
cm,
|
| 239 |
+
index=[str(id2label[i]) for i in range(num_labels)],
|
| 240 |
+
columns=[str(id2label[i]) for i in range(num_labels)]
|
| 241 |
+
)
|
| 242 |
+
cm_df.to_csv("./confusion_matrix_dziribert.csv")
|
| 243 |
+
print("\nConfusion matrix saved to: ./confusion_matrix_dziribert.csv")
|
| 244 |
+
|
| 245 |
+
# -------------------------------------------------------------------
|
| 246 |
+
# 7. Show sample predictions
|
| 247 |
+
# -------------------------------------------------------------------
|
| 248 |
+
print("\n" + "="*80)
|
| 249 |
+
print("SAMPLE PREDICTIONS")
|
| 250 |
+
print("="*80)
|
| 251 |
+
|
| 252 |
+
# Show first 5 predictions
|
| 253 |
+
num_samples = min(5, len(test_df))
|
| 254 |
+
for i in range(num_samples):
|
| 255 |
+
print(f"\nSample {i+1}:")
|
| 256 |
+
print(f"Text: {test_df['Commentaire client'].iloc[i]}")
|
| 257 |
+
if has_labels:
|
| 258 |
+
print(f"True Class: {test_df['Class'].iloc[i]}")
|
| 259 |
+
print(f"Predicted Class: {predicted_classes[i]}")
|
| 260 |
+
print(f"Confidence: {test_df['Confidence'].iloc[i]:.4f}")
|
| 261 |
+
print(f"Probabilities:")
|
| 262 |
+
for idx, class_name in id2label.items():
|
| 263 |
+
print(f" Class {class_name}: {all_probabilities[i][idx]:.4f}")
|
| 264 |
+
|
| 265 |
+
print("\n" + "="*80)
|
| 266 |
+
print("Inference completed successfully!")
|
| 267 |
+
print("="*80)
|
Code/inference_gemma3.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Gemma 3 12B - Fast Inference for Classification
|
| 5 |
+
|
| 6 |
+
Load the fine-tuned Gemma 3 model and run inference on test set.
|
| 7 |
+
Uses batch processing for faster inference.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python inference_gemma3.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
import torch
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 21 |
+
from peft import PeftModel
|
| 22 |
+
|
| 23 |
+
# ---------------------------
# Paths & Config
# ---------------------------
# Held-out test split of the telecom comment dataset (local huggingface_hub cache path).
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
BASE_MODEL = "google/gemma-3-4b-it"  # Must match training base model
ADAPTER_PATH = "./gemma3_classification_ft"  # LoRA adapter path

# Maximum token count per prompt; longer prompts are truncated at tokenization time.
MAX_LENGTH = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

# Enable TF32 for A100
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Name of the CSV column holding the raw customer comment text.
text_col = "Commentaire client"
|
| 39 |
+
|
| 40 |
+
# ===========================================================================
|
| 41 |
+
# System Prompt and Few-Shot Examples (same as training)
|
| 42 |
+
# ===========================================================================
|
| 43 |
+
SYSTEM_PROMPT = """You are an expert Algerian linguist and data labeler. Your task is to classify customer comments from Algérie Télécom's social media into one of 9 specific categories.
|
| 44 |
+
|
| 45 |
+
## CLASSES (DETAILED DESCRIPTIONS)
|
| 46 |
+
- **Class 1 (Wish/Positive Anticipation):** Comments expressing a wish, a hopeful anticipation, or general positive feedback/appreciation for future services or offers.
|
| 47 |
+
- **Class 2 (Complaint: Equipment/Supply):** Comments complaining about the lack, unavailability, or delay in the supply of necessary equipment (e.g., modems, fiber optics devices).
|
| 48 |
+
- **Class 3 (Complaint: Marketing/Advertising):** Comments criticizing advertisements, marketing campaigns, or their lack of realism/meaning.
|
| 49 |
+
- **Class 4 (Complaint: Installation/Deployment):** Comments about delays, stoppages, or failure in service installation, network expansion, or fiber optics deployment (e.g., digging issues).
|
| 50 |
+
- **Class 5 (Inquiry/Request for Information):** Comments asking for eligibility, connection dates, service status, coverage details, or specific contact information.
|
| 51 |
+
- **Class 6 (Complaint: Technical Support/Intervention):** Comments regarding delays in repair interventions, issues with technical staff competence, or unsatisfactory customer service agency visits.
|
| 52 |
+
- **Class 7 (Pricing/Service Enhancement):** Comments focused on pricing, requests for cost reduction, or suggestions for general service/app functionality enhancements.
|
| 53 |
+
- **Class 8 (Complaint: Total Service Outage/Disconnection):** Comments indicating a complete, sustained loss of service (e.g., no phone, no internet, total disconnection).
|
| 54 |
+
- **Class 9 (Complaint: Service Performance/Quality):** Comments about technical issues impacting performance (e.g., slow speed, high latency, broken website/portal, coverage claims).
|
| 55 |
+
|
| 56 |
+
Respond with ONLY the class number (1-9). Do not include any explanation."""
|
| 57 |
+
|
| 58 |
+
FEW_SHOT_STRING = """
|
| 59 |
+
Comment: إن شاء الله يكون عرض صحاب 300 و 500 ميجا فيبر ياربي
|
| 60 |
+
Class: 1
|
| 61 |
+
|
| 62 |
+
Comment: الف مبروووك..
|
| 63 |
+
Class: 1
|
| 64 |
+
|
| 65 |
+
Comment: - إتصالات الجزائر شكرا اتمنى لكم دوام الصحة والعافية
|
| 66 |
+
Class: 1
|
| 67 |
+
|
| 68 |
+
Comment: C une fierté de faire partie de cette grande entreprise Algérienne de haute technologie et haute qualité
|
| 69 |
+
Class: 1
|
| 70 |
+
|
| 71 |
+
Comment: اتمنى لكم مزيد من التألق
|
| 72 |
+
Class: 1
|
| 73 |
+
|
| 74 |
+
Comment: زعما جابو المودام ؟
|
| 75 |
+
Class: 2
|
| 76 |
+
|
| 77 |
+
Comment: وفرو أجهزة مودام الباقي ساهل !
|
| 78 |
+
Class: 2
|
| 79 |
+
|
| 80 |
+
Comment: واش الفايدة تع العرض هذا هو اصلا لي مودام مهوش متوفر رنا قريب عام وحنا ستناو في جد موام هذا
|
| 81 |
+
Class: 2
|
| 82 |
+
|
| 83 |
+
Comment: Depuis un an et demi qu'on a installé w ma kan walou
|
| 84 |
+
Class: 2
|
| 85 |
+
|
| 86 |
+
Comment: قتلتونا بلكذب المودام غير متوفر عندي 4 أشهر ملي حطيت الطلب في ولاية خنشلة و مزال ماجابوش المودام
|
| 87 |
+
Class: 2
|
| 88 |
+
|
| 89 |
+
Comment: عندكم احساس و لا شريوه كما قالو خوتنا لمصريين
|
| 90 |
+
Class: 3
|
| 91 |
+
|
| 92 |
+
Comment: Kamel Dahmane الفايبر؟ مستحيل كامل عاجبتهم
|
| 93 |
+
Class: 3
|
| 94 |
+
|
| 95 |
+
Comment: ههههه نخلص مليون عادي كون يركبونا الفيبر 😂😂😂😂😂 كرهنا من 144p
|
| 96 |
+
Class: 3
|
| 97 |
+
|
| 98 |
+
Comment: إشهار بدون معنه
|
| 99 |
+
Class: 3
|
| 100 |
+
|
| 101 |
+
Comment: المشروع متوقف منذ اشهر
|
| 102 |
+
Class: 4
|
| 103 |
+
|
| 104 |
+
Comment: نتمنى تكملو في ايسطو وهران في اقرب وقت رانا نعانو مع ADSL
|
| 105 |
+
Class: 4
|
| 106 |
+
|
| 107 |
+
Comment: Fibre كاش واحد وصلوله الفيبر؟
|
| 108 |
+
Class: 4
|
| 109 |
+
|
| 110 |
+
Comment: ما هو الجديد وانا مزال ماعنديش الفيبر رغم الطلب ولالحاح
|
| 111 |
+
Class: 4
|
| 112 |
+
|
| 113 |
+
Comment: علبة الفيبر راكبة في الحي و لكن لا يوجد توصيل للمنزل للان
|
| 114 |
+
Class: 4
|
| 115 |
+
|
| 116 |
+
Comment: modem
|
| 117 |
+
Class: 5
|
| 118 |
+
|
| 119 |
+
Comment: يعني كي نطلعها ثلاثون ميغا كارطة تاع مائة الف قداه تحكملي؟
|
| 120 |
+
Class: 5
|
| 121 |
+
|
| 122 |
+
Comment: سآل الأماكن لي ما فيهاش الألياف البصرية إذا جابولنا الألياف السرعة تكون محدودة كيما ف ADSL؟
|
| 123 |
+
Class: 5
|
| 124 |
+
|
| 125 |
+
Comment: ماعرف كاش خبر على ايدوم 4G ماعرف تبقى قرد العش
|
| 126 |
+
Class: 5
|
| 127 |
+
|
| 128 |
+
Comment: هل متوفرة في حي عدل 1046 مسكن دويرة
|
| 129 |
+
Class: 5
|
| 130 |
+
|
| 131 |
+
Comment: عرض 20 ميجا نحيوه مدام مش قادرين تعطيونا حقنا
|
| 132 |
+
Class: 6
|
| 133 |
+
|
| 134 |
+
Comment: 4 سنوات وحنا نخلصو فالدار ماشفنا حتى bonus
|
| 135 |
+
Class: 6
|
| 136 |
+
|
| 137 |
+
Comment: لماذا التغيير في الرقم بدون تغيير سرعة التدفق هل من أجل الإشهار وفقط انا غير من 50 ميغا إلا 200 ميغا نظريا تغيرت وفي الواقع بقت قياس أقل من 50 ميغا
|
| 138 |
+
Class: 6
|
| 139 |
+
|
| 140 |
+
Comment: انا طلعت تدفق انترنات من 15 الى 20 عبر تطبيق my idoom لاكن سرعة لم تتغير
|
| 141 |
+
Class: 6
|
| 142 |
+
|
| 143 |
+
Comment: نقصوا الاسعار بزااااف غالية
|
| 144 |
+
Class: 7
|
| 145 |
+
|
| 146 |
+
Comment: علاه ماديروش في التطبيق خاصية التوقيف المؤقت للانترانات
|
| 147 |
+
Class: 7
|
| 148 |
+
|
| 149 |
+
Comment: وفرونا من بعد اي ساهلة
|
| 150 |
+
Class: 7
|
| 151 |
+
|
| 152 |
+
Comment: لازم ترجعو اتصال بتطبيقات الدفع بلا انترنت و مجاني ريقلوها يا اتصالات الجزائر
|
| 153 |
+
Class: 7
|
| 154 |
+
|
| 155 |
+
Comment: Promotion fin d'année ADSL idoom
|
| 156 |
+
Class: 7
|
| 157 |
+
|
| 158 |
+
Comment: رانا بلا تلفون ولا انترنت
|
| 159 |
+
Class: 8
|
| 160 |
+
|
| 161 |
+
Comment: ثلاثة اشهر بلا انترنت
|
| 162 |
+
Class: 8
|
| 163 |
+
|
| 164 |
+
Comment: votre site espace client ne fonctionne pas pourquoi?
|
| 165 |
+
Class: 8
|
| 166 |
+
|
| 167 |
+
Comment: ما عندنا الانترنيت ما نخلصوها من الدار
|
| 168 |
+
Class: 8
|
| 169 |
+
|
| 170 |
+
Comment: مشكل في 1.200جيق فيبر مدام نوكيا مخرج الانترنت 1جيق فقط كفاش راح تحلو هذا مشكل ومشكل ثاني فضاء الزبون ميمشيش مندو شهر
|
| 171 |
+
Class: 8
|
| 172 |
+
|
| 173 |
+
Comment: فضاء الزبون علاه منقدروش نسجلو فيه
|
| 174 |
+
Class: 9
|
| 175 |
+
|
| 176 |
+
Comment: هل موقع فضاء الزبون متوقف
|
| 177 |
+
Class: 9
|
| 178 |
+
|
| 179 |
+
Comment: ماراهيش توصل الفاتورة لا عن طريق الإيميل ولا عن طريق فضاء الزبون
|
| 180 |
+
Class: 9
|
| 181 |
+
|
| 182 |
+
Comment: فضاء الزبون قرابة 20 يوم متوقف!!!!!!؟؟؟؟؟
|
| 183 |
+
Class: 9
|
| 184 |
+
|
| 185 |
+
Comment: برج الكيفان اظنها من العاصمة خارج تغطيتكم....احشموا بركاو بلا كذب....طلعنا الصواريخ للفضاء....بصح بالكذب....
|
| 186 |
+
Class: 9"""
|
| 187 |
+
|
| 188 |
+
# ===========================================================================
# Text Preprocessing
# ===========================================================================
# Emoji/symbol ranges stripped from comments. Compiled once at module load:
# the original recompiled this pattern on every call, which is pure overhead
# when cleaning thousands of comments in the prompt-preparation loop.
_EMOJI_PATTERN = re.compile("["
                   u"\U0001F600-\U0001F64F"
                   u"\U0001F300-\U0001F5FF"
                   u"\U0001F680-\U0001F6FF"
                   u"\U0001F1E0-\U0001F1FF"
                   u"\U00002702-\U000027B0"
                   u"\U000024C2-\U0001F251"
                   u"\U0001f926-\U0001f937"
                   u"\U00010000-\U0010ffff"
                   u"\u2640-\u2642"
                   u"\u2600-\u2B55"
                   u"\u200d"
                   u"\u23cf"
                   u"\u23e9"
                   u"\u231a"
                   u"\ufe0f"
                   u"\u3030"
                   "]+", flags=re.UNICODE)


def preprocess_text(text):
    """Preprocess text: remove tatweel, emojis, URLs, phone numbers.

    Args:
        text: Raw comment text. Any non-string input yields "".

    Returns:
        Cleaned, whitespace-normalized string. Substitution order matters:
        URLs/e-mails are removed first so their digits are not half-eaten
        by the phone-number patterns.
    """
    if not isinstance(text, str):
        return ""

    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove email addresses
    text = re.sub(r'\S+@\S+', '', text)

    # Remove phone numbers (generic digit runs, then Algerian mobile/landline forms)
    text = re.sub(r'[\+]?[(]?[0-9]{1,4}[)]?[-\s\./0-9]{6,}', '', text)
    text = re.sub(r'\b0[567]\d{8}\b', '', text)
    text = re.sub(r'\b0[23]\d{7,8}\b', '', text)

    # Remove mentions
    text = re.sub(r'@\w+', '', text)

    # Remove Arabic tatweel (decorative elongation character)
    text = re.sub(r'ـ+', '', text)

    # Remove emojis (pattern hoisted to module level above)
    text = _EMOJI_PATTERN.sub('', text)

    # Remove platform names
    text = re.sub(r'Algérie Télécom - إتصالات الجزائر', '', text, flags=re.IGNORECASE)
    text = re.sub(r'Algérie Télécom', '', text, flags=re.IGNORECASE)
    text = re.sub(r'إتصالات الجزائر', '', text)

    # Collapse runs of 4+ identical characters down to 3 ("soooooo" -> "sooo")
    text = re.sub(r'(.)\1{3,}', r'\1\1\1', text)

    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text
|
| 246 |
+
|
| 247 |
+
def format_prompt(comment):
    """Build the user-turn prompt: few-shot examples, then the comment to label."""
    sections = (
        "Here are some examples of how to classify comments:",
        "",
        FEW_SHOT_STRING,
        "",
        "Now classify this comment:",
        f"Comment: {comment}",
        "Class:",
    )
    return "\n".join(sections)
|
| 257 |
+
|
| 258 |
+
def create_inference_prompt(comment, tokenizer):
    """Return the full chat-formatted prompt string for one comment."""
    cleaned = preprocess_text(comment)

    # Single user turn: the system instructions are folded into the user content.
    conversation = [{
        "role": "user",
        "content": SYSTEM_PROMPT + "\n\n" + format_prompt(cleaned),
    }]

    # Apply chat template with generation prompt so the model continues
    # from the assistant position.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
|
| 274 |
+
|
| 275 |
+
def extract_class(generated_text):
    """Parse the predicted class label (1-9) out of the model's raw output.

    Finds the first standalone digit 1-9 in *generated_text* and returns it
    as an int; returns 1 (the original default) when nothing parseable is
    found. The previous version wrapped this in a bare ``except:`` that
    silently swallowed every error; the only realistic failure was a
    non-string argument, which is now handled explicitly.

    Args:
        generated_text: Decoded generation output. Non-string input yields 1.

    Returns:
        int in 1..9.
    """
    if not isinstance(generated_text, str):
        return 1  # defensive default, mirrors the old except-path behavior
    match = re.search(r'\b([1-9])\b', generated_text)
    return int(match.group(1)) if match else 1
|
| 284 |
+
|
| 285 |
+
# ===========================================================================
# Main Inference
# ===========================================================================
print("\n" + "="*70)
print("Gemma 3 4B - Fast Batch Inference")
print("="*70 + "\n")

# Load tokenizer for the base model (adapter does not change the vocab).
print(f"Loading tokenizer from: {BASE_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

# Some tokenizers ship without a pad token; reuse EOS so batch padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = "left"  # Left padding for batch generation

# Load base model
print(f"Loading base model from: {BASE_MODEL}")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation="eager",  # Flash attention not available
)

# Load LoRA adapter on top of the base weights.
print(f"Loading LoRA adapter from: {ADAPTER_PATH}")
model = PeftModel.from_pretrained(model, ADAPTER_PATH)
model.eval()

print(f"\nModel loaded successfully!")

# Load test data
print(f"\nLoading test data from: {TEST_FILE}")
test_df = pd.read_csv(TEST_FILE)
print(f"Test samples: {len(test_df)}")

# Batch inference settings - use smaller batch for long prompts
BATCH_SIZE = 8  # Small batch to avoid OOM with long few-shot prompts

# Prepare all prompts first (chat-template formatting is CPU-only work).
print("\nPreparing prompts...")
all_prompts = []
for i in tqdm(range(len(test_df)), desc="Preparing"):
    comment = str(test_df.iloc[i][text_col])
    prompt = create_inference_prompt(comment, tokenizer)
    all_prompts.append(prompt)

# Run batch inference
print(f"\nRunning batch inference (batch_size={BATCH_SIZE})...")
all_preds = []
num_batches = (len(all_prompts) + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division

with torch.no_grad():
    torch.cuda.empty_cache()  # Clear cache before inference

    for batch_idx in tqdm(range(num_batches), desc="Predicting"):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(all_prompts))
        batch_prompts = all_prompts[start_idx:end_idx]

        # Tokenize batch with padding
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding=True,  # Pad to longest in batch
        )
        # NOTE(review): input_lengths is computed but never used — the decode
        # loop below uses the padded width instead. Candidate for removal.
        input_lengths = [len(ids) for ids in inputs["input_ids"]]
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate batch (greedy decoding)
        outputs = model.generate(
            **inputs,
            max_new_tokens=3,  # Just need 1 digit
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode each sequence in the batch
        for seq_idx, output_ids in enumerate(outputs):
            # Get only the new tokens (after the input). With left padding,
            # every row of the batch has the same prompt width.
            input_len = inputs["input_ids"].shape[1]  # All padded to same length
            generated_tokens = output_ids[input_len:]
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            # Extract class
            pred_class = extract_class(generated_text)
            all_preds.append(pred_class)

        # Clear cache periodically
        if batch_idx % 50 == 0:
            torch.cuda.empty_cache()

# Save predictions
test_df["Predicted_Class"] = all_preds
output_file = "test_predictions_gemma3.csv"
test_df.to_csv(output_file, index=False)
print(f"\nPredictions saved to: {output_file}")

# Show sample predictions (first 10, truncated for display)
print("\nSample predictions:")
for i in range(min(10, len(test_df))):
    text = str(test_df.iloc[i][text_col])
    text_display = text[:60] + "..." if len(text) > 60 else text
    pred = test_df.iloc[i]["Predicted_Class"]
    print(f" [{i+1}] Class {pred}: {text_display}")

# Class distribution
print("\nPrediction distribution:")
pred_counts = test_df["Predicted_Class"].value_counts().sort_index()
for class_label, count in pred_counts.items():
    pct = count / len(test_df) * 100
    print(f" Class {class_label}: {count:>5} samples ({pct:>5.1f}%)")

print("\n" + "="*70)
print("INFERENCE COMPLETE!")
print("="*70)
print(f"\nAdapter path: {ADAPTER_PATH}")
print(f"Test samples: {len(test_df)}")
print(f"Output file: {output_file}")
|
Code/inference_marbertv2.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Inference script for trained MARBERTv2 telecom classification model.
|
| 6 |
+
|
| 7 |
+
Loads the trained model and runs predictions on test.csv
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForSequenceClassification,
|
| 18 |
+
)
|
| 19 |
+
from sklearn.metrics import (
|
| 20 |
+
accuracy_score,
|
| 21 |
+
f1_score,
|
| 22 |
+
precision_recall_fscore_support,
|
| 23 |
+
classification_report,
|
| 24 |
+
confusion_matrix,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# -------------------------------------------------------------------
# 1. Paths & config
# -------------------------------------------------------------------
# Held-out test split (local huggingface_hub cache path).
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
MODEL_DIR = "./telecom_marbertv2_final"  # fine-tuned checkpoint directory
OUTPUT_FILE = "./test_predictions.csv"

MAX_LENGTH = 256  # tokenizer truncation length
BATCH_SIZE = 64   # inference batch size

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# -------------------------------------------------------------------
# 3. Load test data
# -------------------------------------------------------------------
print(f"Loading test data from: {TEST_FILE}")
test_df = pd.read_csv(TEST_FILE)
print(f"Test samples: {len(test_df)}")
print(f"Columns: {test_df.columns.tolist()}")

# Check if test data has labels; metrics below are only computed when
# the ground-truth "Class" column is present.
has_labels = "Class" in test_df.columns
if has_labels:
    print("Test data contains labels - will compute metrics")
else:
    print("Test data has no labels - will only generate predictions")
|
| 55 |
+
|
| 56 |
+
# -------------------------------------------------------------------
# 4. Load model and tokenizer
# -------------------------------------------------------------------
print(f"\nLoading model from: {MODEL_DIR}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)


def _labels_from_test_df(error_message):
    """Fallback label mapping inferred from the test set's "Class" column.

    The original script duplicated this logic verbatim in two branches
    (missing config.json vs. config.json without id2label), differing only
    in the error message — deduplicated here.

    Raises:
        ValueError(error_message) when the test set has no labels either.

    Returns:
        (id2label, label2id, num_labels) tuple.
    """
    if not has_labels:
        raise ValueError(error_message)
    unique_classes = sorted(test_df["Class"].unique())
    l2i = {label: idx for idx, label in enumerate(unique_classes)}
    i2l = {idx: label for label, idx in l2i.items()}
    return i2l, l2i, len(unique_classes)


# Load config to get label mappings
config_path = os.path.join(MODEL_DIR, "config.json")
if os.path.exists(config_path):
    import json
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    if 'id2label' in config_data:
        # JSON keys are always strings; cast back to int ids.
        id2label = {int(k): v for k, v in config_data['id2label'].items()}
        # Create label2id with both string and int keys for robustness
        label2id = {}
        for k, v in id2label.items():
            label2id[v] = k  # string key -> int value
            try:
                label2id[int(v)] = k  # int key -> int value (if label is numeric)
            except (ValueError, TypeError):
                pass
        num_labels = len(id2label)
    else:
        # Fallback: infer from test data if available
        id2label, label2id, num_labels = _labels_from_test_df(
            "Cannot determine number of labels without config or test labels"
        )
else:
    # Fallback: infer from test data if available
    id2label, label2id, num_labels = _labels_from_test_df(
        "Cannot find config.json and test data has no labels"
    )

print(f"Number of classes: {num_labels}")
print(f"Label mapping: {id2label}")

# Load model using AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)
model.eval()
print("Model loaded successfully!")
|
| 112 |
+
|
| 113 |
+
# -------------------------------------------------------------------
# 5. Run inference
# -------------------------------------------------------------------
print("\nRunning inference...")

all_predictions = []    # predicted label ids (ints)
all_probabilities = []  # per-sample softmax probability vectors

# Process in batches for efficiency
for i in range(0, len(test_df), BATCH_SIZE):
    batch_texts = test_df["Commentaire client"].iloc[i:i+BATCH_SIZE].tolist()

    # Tokenize
    inputs = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)

    all_predictions.extend(predictions.cpu().numpy())
    all_probabilities.extend(probs.cpu().numpy())

    # Progress report every 10 batches
    if (i // BATCH_SIZE + 1) % 10 == 0:
        print(f"Processed {i + len(batch_texts)}/{len(test_df)} samples...")

print(f"Inference complete! Processed {len(all_predictions)} samples")

# -------------------------------------------------------------------
# 6. Save predictions
# -------------------------------------------------------------------
# Convert predictions to class names
predicted_classes = [id2label[pred] for pred in all_predictions]

# Add predictions to dataframe
test_df["Predicted_Class"] = predicted_classes
test_df["Predicted_Label_ID"] = all_predictions

# Add probability for each class
for idx, class_name in id2label.items():
    test_df[f"Prob_{class_name}"] = [probs[idx] for probs in all_probabilities]

# Add confidence (max probability)
test_df["Confidence"] = [max(probs) for probs in all_probabilities]

# Save results
test_df.to_csv(OUTPUT_FILE, index=False)
print(f"\nPredictions saved to: {OUTPUT_FILE}")
|
| 169 |
+
|
| 170 |
+
# -------------------------------------------------------------------
# 7. Compute metrics (if labels available)
# -------------------------------------------------------------------
if has_labels:
    print("\n" + "="*80)
    print("EVALUATION METRICS")
    print("="*80)

    # Convert true labels to indices
    true_labels = test_df["Class"].map(label2id).values
    pred_labels = np.array(all_predictions)

    # Overall metrics
    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"\nAccuracy: {accuracy:.4f}")

    # Weighted metrics (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='weighted', zero_division=0
    )
    print(f"\nWeighted Metrics:")
    print(f" Precision: {precision_w:.4f}")
    print(f" Recall: {recall_w:.4f}")
    print(f" F1 Score: {f1_w:.4f}")

    # Macro metrics (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro', zero_division=0
    )
    print(f"\nMacro Metrics:")
    print(f" Precision: {precision_m:.4f}")
    print(f" Recall: {recall_m:.4f}")
    print(f" F1 Score: {f1_m:.4f}")

    # Per-class metrics
    print(f"\nPer-Class Metrics:")
    per_class_f1 = f1_score(true_labels, pred_labels, average=None, zero_division=0)
    per_class_precision, per_class_recall, _, support = precision_recall_fscore_support(
        true_labels, pred_labels, average=None, zero_division=0
    )

    for idx in range(num_labels):
        class_name = id2label[idx]
        print(f"\n {class_name}:")
        print(f" Precision: {per_class_precision[idx]:.4f}")
        print(f" Recall: {per_class_recall[idx]:.4f}")
        print(f" F1 Score: {per_class_f1[idx]:.4f}")
        print(f" Support: {int(support[idx])}")

    # Classification report
    print("\n" + "="*80)
    print("DETAILED CLASSIFICATION REPORT")
    print("="*80)
    target_names = [id2label[i] for i in range(num_labels)]
    print(classification_report(true_labels, pred_labels, target_names=target_names, zero_division=0))

    # Confusion matrix
    print("\n" + "="*80)
    print("CONFUSION MATRIX")
    print("="*80)
    cm = confusion_matrix(true_labels, pred_labels)

    # Print confusion matrix with labels (header row, then one row per true class).
    # NOTE(review): id2label[i][:8] assumes string labels (config id2label path);
    # int labels from the test-data fallback would fail this slice — confirm.
    print("\nTrue \\ Predicted", end="")
    for i in range(num_labels):
        print(f"\t{id2label[i][:8]}", end="")
    print()

    for i in range(num_labels):
        print(f"{id2label[i][:15]:<15}", end="")
        for j in range(num_labels):
            print(f"\t{cm[i][j]}", end="")
        print()

    # Save confusion matrix to CSV
    cm_df = pd.DataFrame(
        cm,
        index=[id2label[i] for i in range(num_labels)],
        columns=[id2label[i] for i in range(num_labels)]
    )
    cm_df.to_csv("./confusion_matrix.csv")
    print("\nConfusion matrix saved to: ./confusion_matrix.csv")

# -------------------------------------------------------------------
# 8. Show sample predictions
# -------------------------------------------------------------------
print("\n" + "="*80)
print("SAMPLE PREDICTIONS")
print("="*80)

# Show first 5 predictions
num_samples = min(5, len(test_df))
for i in range(num_samples):
    print(f"\nSample {i+1}:")
    print(f"Text: {test_df['Commentaire client'].iloc[i]}")
    if has_labels:
        print(f"True Class: {test_df['Class'].iloc[i]}")
    print(f"Predicted Class: {predicted_classes[i]}")
    print(f"Confidence: {test_df['Confidence'].iloc[i]:.4f}")
    print(f"Probabilities:")
    for idx, class_name in id2label.items():
        print(f" {class_name}: {all_probabilities[i][idx]:.4f}")

print("\n" + "="*80)
print("Inference completed successfully!")
print("="*80)
|
Code/inference_marbertv2_cpt_ft.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Inference script for CPT+finetuned MARBERTv2 telecom classification model.
|
| 6 |
+
|
| 7 |
+
Loads the model from ./telecom_marbertv2_cpt_ft and runs predictions on test.csv
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForSequenceClassification,
|
| 18 |
+
AutoConfig,
|
| 19 |
+
)
|
| 20 |
+
from sklearn.metrics import (
|
| 21 |
+
accuracy_score,
|
| 22 |
+
f1_score,
|
| 23 |
+
precision_recall_fscore_support,
|
| 24 |
+
classification_report,
|
| 25 |
+
confusion_matrix,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# -------------------------------------------------------------------
|
| 29 |
+
# 1. Paths & config
|
| 30 |
+
# -------------------------------------------------------------------
|
| 31 |
+
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
|
| 32 |
+
MODEL_DIR = "./telecom_marbertv2_cpt_ft"
|
| 33 |
+
OUTPUT_FILE = "./test_predictions_telecom_marbertv2_cpt_ft.csv"
|
| 34 |
+
|
| 35 |
+
MAX_LENGTH = 256
|
| 36 |
+
BATCH_SIZE = 64
|
| 37 |
+
|
| 38 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
print(f"Using device: {device}")
|
| 40 |
+
|
| 41 |
+
# -------------------------------------------------------------------
|
| 42 |
+
# 2. Load test data
|
| 43 |
+
# -------------------------------------------------------------------
|
| 44 |
+
print(f"Loading test data from: {TEST_FILE}")
test_df = pd.read_csv(TEST_FILE)
print(f"Test samples: {len(test_df)}")
print(f"Columns: {test_df.columns.tolist()}")

# Metrics are only possible when the ground-truth "Class" column is present.
has_labels = "Class" in test_df.columns
print(
    "Test data contains labels - will compute metrics"
    if has_labels
    else "Test data has no labels - will only generate predictions"
)
|
| 55 |
+
|
| 56 |
+
# -------------------------------------------------------------------
|
| 57 |
+
# 3. Load model and tokenizer
|
| 58 |
+
# -------------------------------------------------------------------
|
| 59 |
+
print(f"\nLoading model from: {MODEL_DIR}")
|
| 60 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
| 61 |
+
|
| 62 |
+
# -------------------------------------------------------------------
# Resolve label mappings: prefer the saved config.json, fall back to
# the test set's own "Class" column.
# FIX: the original duplicated the fallback block verbatim in two
# separate else-branches; consolidated into a single path.
# -------------------------------------------------------------------
import json


def _label_maps_from_test_df():
    """Derive (id2label, label2id) from the test set's "Class" column."""
    unique_classes = sorted(test_df["Class"].unique())
    l2i = {label: idx for idx, label in enumerate(unique_classes)}
    i2l = {idx: label for label, idx in l2i.items()}
    return i2l, l2i


config_path = os.path.join(MODEL_DIR, "config.json")
id2label = None
label2id = None

if os.path.exists(config_path):
    with open(config_path, "r") as f:
        config_data = json.load(f)
    if "id2label" in config_data:
        # config.json serializes ids as strings; classes here are numeric
        # (1-9), hence the int() coercion on both key and value.
        id2label = {int(k): int(v) for k, v in config_data["id2label"].items()}
        # Register both int and str keys so mapping raw CSV labels later is
        # robust regardless of how pandas typed the column.
        label2id = {}
        for k, v in id2label.items():
            label2id[v] = k
            label2id[str(v)] = k

if id2label is None:
    # Fallback: infer the mapping from the test labels when possible.
    if has_labels:
        id2label, label2id = _label_maps_from_test_df()
    else:
        raise ValueError(
            "Cannot determine number of labels without config or test labels"
        )

num_labels = len(id2label)
print(f"Number of classes: {num_labels}")
print(f"Label mapping: {id2label}")
|
| 98 |
+
|
| 99 |
+
# Load the fine-tuned classifier weights and switch to inference mode.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model = model.to(device)
model.eval()  # disable dropout etc. for deterministic predictions
print("Model loaded successfully!")
|
| 104 |
+
|
| 105 |
+
# -------------------------------------------------------------------
|
| 106 |
+
# 4. Run inference
|
| 107 |
+
# -------------------------------------------------------------------
|
| 108 |
+
print("\nRunning inference...")

all_predictions = []
all_probabilities = []

# Batched forward passes keep GPU memory bounded on large test sets.
for batch_num, start in enumerate(range(0, len(test_df), BATCH_SIZE), start=1):
    chunk = test_df["Commentaire client"].iloc[start:start + BATCH_SIZE].tolist()

    encoded = tokenizer(
        chunk,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        batch_logits = model(**encoded).logits
        batch_probs = torch.softmax(batch_logits, dim=-1)
        batch_preds = torch.argmax(batch_logits, dim=-1)

    all_predictions.extend(batch_preds.cpu().numpy())
    all_probabilities.extend(batch_probs.cpu().numpy())

    # Periodic progress report every 10 batches.
    if batch_num % 10 == 0:
        print(f"Processed {start + len(chunk)}/{len(test_df)} samples...")

print(f"Inference complete! Processed {len(all_predictions)} samples")
|
| 140 |
+
|
| 141 |
+
# -------------------------------------------------------------------
|
| 142 |
+
# 5. Save predictions
|
| 143 |
+
# -------------------------------------------------------------------
|
| 144 |
+
# Map model indices back to the original class ids (1-9).
predicted_classes = [id2label[p] for p in all_predictions]

test_df["Predicted_Class"] = predicted_classes
test_df["Predicted_Label_ID"] = all_predictions

# One probability column per class, plus an overall confidence column
# (the max class probability of each row).
for label_idx, label_name in id2label.items():
    test_df[f"Prob_{label_name}"] = [row[label_idx] for row in all_probabilities]
test_df["Confidence"] = [max(row) for row in all_probabilities]

test_df.to_csv(OUTPUT_FILE, index=False)
print(f"\nPredictions saved to: {OUTPUT_FILE}")
|
| 161 |
+
|
| 162 |
+
# -------------------------------------------------------------------
|
| 163 |
+
# 6. Compute metrics (if labels available)
|
| 164 |
+
# -------------------------------------------------------------------
|
| 165 |
+
if has_labels:
    print("\n" + "="*80)
    print("EVALUATION METRICS")
    print("="*80)

    # Map ground-truth classes through label2id (which carries both int and
    # str keys) into model-index space before comparing with predictions.
    true_labels = test_df["Class"].map(label2id).values
    pred_labels = np.array(all_predictions)

    # FIX: pin the label order for every per-class sklearn call below.
    # Without labels=..., sklearn only returns entries for classes that
    # actually occur in this test set, which (a) misaligns per-class arrays
    # with id2label and (b) makes the confusion matrix smaller than
    # num_labels x num_labels, crashing the nested printing loops.
    label_order = list(range(num_labels))

    # Overall metrics
    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"\nAccuracy: {accuracy:.4f}")

    # Weighted metrics (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='weighted', zero_division=0
    )
    print(f"\nWeighted Metrics:")
    print(f"  Precision: {precision_w:.4f}")
    print(f"  Recall: {recall_w:.4f}")
    print(f"  F1 Score: {f1_w:.4f}")

    # Macro metrics (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro', zero_division=0
    )
    print(f"\nMacro Metrics:")
    print(f"  Precision: {precision_m:.4f}")
    print(f"  Recall: {recall_m:.4f}")
    print(f"  F1 Score: {f1_m:.4f}")

    # Per-class metrics (aligned with id2label via label_order)
    print(f"\nPer-Class Metrics:")
    per_class_precision, per_class_recall, per_class_f1, support = (
        precision_recall_fscore_support(
            true_labels, pred_labels, labels=label_order,
            average=None, zero_division=0,
        )
    )

    for idx in label_order:
        class_name = id2label[idx]
        print(f"\n  Class {class_name}:")
        print(f"    Precision: {per_class_precision[idx]:.4f}")
        print(f"    Recall: {per_class_recall[idx]:.4f}")
        print(f"    F1 Score: {per_class_f1[idx]:.4f}")
        print(f"    Support: {int(support[idx])}")

    # Classification report
    print("\n" + "="*80)
    print("DETAILED CLASSIFICATION REPORT")
    print("="*80)
    target_names = [str(id2label[i]) for i in label_order]
    print(classification_report(
        true_labels, pred_labels, labels=label_order,
        target_names=target_names, zero_division=0,
    ))

    # Confusion matrix — labels=... forces a full num_labels x num_labels
    # matrix even when some classes are absent from the test set.
    print("\n" + "="*80)
    print("CONFUSION MATRIX")
    print("="*80)
    cm = confusion_matrix(true_labels, pred_labels, labels=label_order)

    # Print confusion matrix with labels
    print("\nTrue \\ Predicted", end="")
    for i in label_order:
        print(f"\t{id2label[i]}", end="")
    print()

    for i in label_order:
        print(f"{id2label[i]:<15}", end="")
        for j in label_order:
            print(f"\t{cm[i][j]}", end="")
        print()

    # Save confusion matrix to CSV
    cm_df = pd.DataFrame(
        cm,
        index=[str(id2label[i]) for i in label_order],
        columns=[str(id2label[i]) for i in label_order],
    )
    cm_df.to_csv("./confusion_matrix_marbertv2_cpt_ft.csv")
    print("\nConfusion matrix saved to: ./confusion_matrix_marbertv2_cpt_ft.csv")
|
| 244 |
+
|
| 245 |
+
# -------------------------------------------------------------------
|
| 246 |
+
# 7. Show sample predictions
|
| 247 |
+
# -------------------------------------------------------------------
|
| 248 |
+
banner = "=" * 80
print("\n" + banner)
print("SAMPLE PREDICTIONS (CPT+Finetuned MARBERTv2)")
print(banner)

# Dump the first few rows so a quick eyeball check of the output is possible.
for sample_idx in range(min(5, len(test_df))):
    print(f"\nSample {sample_idx + 1}:")
    print(f"Text: {test_df['Commentaire client'].iloc[sample_idx]}")
    if has_labels:
        print(f"True Class: {test_df['Class'].iloc[sample_idx]}")
    print(f"Predicted Class: {predicted_classes[sample_idx]}")
    print(f"Confidence: {test_df['Confidence'].iloc[sample_idx]:.4f}")
    print(f"Probabilities:")
    for label_idx, label_name in id2label.items():
        print(f"  Class {label_name}: {all_probabilities[sample_idx][label_idx]:.4f}")

print("\n" + banner)
print("Inference completed successfully!")
print(f"Model used: CPT+Finetuned MARBERTv2 from {MODEL_DIR}")
print(banner)
|
Code/train_arabert.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Fine-tune aubmindlab/bert-large-arabertv2 for Arabic telecom customer comment classification.
|
| 6 |
+
|
| 7 |
+
Dataset (CSV):
|
| 8 |
+
/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/train.csv
|
| 9 |
+
|
| 10 |
+
Columns:
|
| 11 |
+
Commentaire client: str (text)
|
| 12 |
+
Class: int (label - values 1 through 9)
|
| 13 |
+
|
| 14 |
+
Model:
|
| 15 |
+
- AraBERTv2 Large encoder
|
| 16 |
+
- Classification head for multi-class prediction (9 classes)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from inspect import signature
|
| 24 |
+
from datasets import load_dataset
|
| 25 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
| 26 |
+
from transformers import (
|
| 27 |
+
AutoTokenizer,
|
| 28 |
+
AutoModelForSequenceClassification,
|
| 29 |
+
TrainingArguments,
|
| 30 |
+
Trainer,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Slight speed boost on Ampere GPUs
|
| 34 |
+
if hasattr(torch, "set_float32_matmul_precision"):
|
| 35 |
+
torch.set_float32_matmul_precision("high")
|
| 36 |
+
|
| 37 |
+
# -------------------------------------------------------------------
|
| 38 |
+
# 1. Paths & config
|
| 39 |
+
# -------------------------------------------------------------------
|
| 40 |
+
DATA_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/train.csv"
|
| 41 |
+
|
| 42 |
+
MODEL_NAME = "aubmindlab/bert-large-arabertv2"
|
| 43 |
+
OUTPUT_DIR = "./telecom_arabert_final"
|
| 44 |
+
|
| 45 |
+
MAX_LENGTH = 512
|
| 46 |
+
|
| 47 |
+
# Define label mapping - classes are 1-9
|
| 48 |
+
LABEL2ID = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}
|
| 49 |
+
ID2LABEL = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9}
|
| 50 |
+
NUM_LABELS = 9
|
| 51 |
+
|
| 52 |
+
# -------------------------------------------------------------------
|
| 53 |
+
# 2. Dataset loading
|
| 54 |
+
# -------------------------------------------------------------------
|
| 55 |
+
print("Loading telecom dataset from CSV...")
|
| 56 |
+
dataset = load_dataset(
|
| 57 |
+
"csv",
|
| 58 |
+
data_files=DATA_FILE,
|
| 59 |
+
split="train",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
print("Sample example:", dataset[0])
|
| 63 |
+
print(f"Total examples: {len(dataset)}")
|
| 64 |
+
|
| 65 |
+
print(f"Number of classes: {NUM_LABELS}")
|
| 66 |
+
print("Label mapping (class -> model index):", LABEL2ID)
|
| 67 |
+
print("Inverse mapping (model index -> class):", ID2LABEL)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_labels(example):
    """Convert the dataset's class id (1-9) to a zero-based model label (0-8).

    Generalized to accept int, numeric-string, and whole-number float class
    values, since CSV loaders may surface the column as any of these.

    Raises:
        ValueError: if the (normalized) class value is not in 1-9.
    """
    class_val = example["Class"]

    # Normalize to int regardless of how the CSV loader typed the column.
    if isinstance(class_val, str):
        class_val = int(class_val.strip())
    elif isinstance(class_val, float):
        # Reject fractional values rather than silently truncating them.
        if not class_val.is_integer():
            raise ValueError(f"Unknown class: {class_val}. Expected 1-9.")
        class_val = int(class_val)

    if class_val not in LABEL2ID:
        raise ValueError(f"Unknown class: {class_val}. Expected 1-9.")

    example["labels"] = LABEL2ID[class_val]
    return example
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
dataset = dataset.map(encode_labels)
|
| 86 |
+
|
| 87 |
+
# Train/val split (90/10)
|
| 88 |
+
dataset = dataset.train_test_split(test_size=0.1, seed=42)
|
| 89 |
+
train_dataset = dataset["train"]
|
| 90 |
+
eval_dataset = dataset["test"]
|
| 91 |
+
|
| 92 |
+
print("Train size:", len(train_dataset))
|
| 93 |
+
print("Eval size:", len(eval_dataset))
|
| 94 |
+
|
| 95 |
+
# -------------------------------------------------------------------
|
| 96 |
+
# 3. Tokenization
|
| 97 |
+
# -------------------------------------------------------------------
|
| 98 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def preprocess_function(examples):
    """Tokenize a batch of comments to fixed-length model inputs."""
    # Fixed-length padding keeps tensor shapes uniform across batches.
    return tokenizer(
        examples["Commentaire client"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 111 |
+
eval_dataset = eval_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 112 |
+
|
| 113 |
+
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 114 |
+
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 115 |
+
|
| 116 |
+
# -------------------------------------------------------------------
|
| 117 |
+
# 4. Model - Using AutoModelForSequenceClassification
|
| 118 |
+
# -------------------------------------------------------------------
|
| 119 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 120 |
+
MODEL_NAME,
|
| 121 |
+
num_labels=NUM_LABELS,
|
| 122 |
+
id2label=ID2LABEL,
|
| 123 |
+
label2id=LABEL2ID,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
print("Model initialized with classification head")
|
| 127 |
+
print(f"Number of labels: {NUM_LABELS}")
|
| 128 |
+
print(f"Classes: {list(ID2LABEL.values())}")
|
| 129 |
+
|
| 130 |
+
# -------------------------------------------------------------------
|
| 131 |
+
# 5. Metrics
|
| 132 |
+
# -------------------------------------------------------------------
|
| 133 |
+
def compute_metrics(eval_pred):
    """Compute accuracy plus weighted/macro/per-class P/R/F1 for the Trainer.

    Returns a flat dict of floats; per-class F1 entries are keyed
    ``f1_class_<original class id>``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Overall metrics
    metrics = {'accuracy': accuracy_score(labels, predictions)}

    # Weighted average (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )

    # Macro average (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics.update({
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    })

    # FIX: pass labels=range(NUM_LABELS) so the per-class array always has
    # one entry per class in index order. Without it, sklearn drops classes
    # absent from the eval split and every subsequent f1_class_<name> entry
    # would be attributed to the wrong class.
    per_class_f1 = f1_score(
        labels, predictions, labels=list(range(NUM_LABELS)),
        average=None, zero_division=0,
    )
    for idx in range(NUM_LABELS):
        metrics[f'f1_class_{ID2LABEL[idx]}'] = per_class_f1[idx]

    return metrics
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------------------------------------------------
|
| 171 |
+
# 6. TrainingArguments (old/new transformers compatible)
|
| 172 |
+
# -------------------------------------------------------------------
|
| 173 |
+
# Inspect the installed TrainingArguments signature so this script works
# across old and new transformers releases.
ta_sig = signature(TrainingArguments.__init__)
ta_params = set(ta_sig.parameters.keys())

is_bf16_supported = (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
)
use_bf16 = bool(is_bf16_supported)
use_fp16 = not use_bf16  # fall back to fp16 on GPUs without bf16

print(f"bf16 supported: {is_bf16_supported} -> using bf16={use_bf16}, fp16={use_fp16}")

base_kwargs = {
    "output_dir": OUTPUT_DIR,
    "num_train_epochs": 10,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 64,
    # NOTE(review): 1e-3 is ~50x the usual fine-tuning LR for BERT-style
    # encoders (2e-5..5e-5) and risks destroying the pretrained weights —
    # confirm this was intentional before keeping it.
    "learning_rate": 1e-3,
    "weight_decay": 0.03,
    "warmup_ratio": 0.1,
    "logging_steps": 50,
    "save_total_limit": 2,
    "dataloader_num_workers": 4,
}

# Mixed precision flags if supported
if "bf16" in ta_params:
    base_kwargs["bf16"] = use_bf16
if "fp16" in ta_params:
    base_kwargs["fp16"] = use_fp16

# FIX: transformers >= 4.41 renamed evaluation_strategy -> eval_strategy
# (the old name was removed in later releases). Checking only the old name
# silently skipped per-epoch evaluation and best-model loading on new
# versions; accept whichever spelling the installed version supports.
eval_strategy_key = next(
    (k for k in ("evaluation_strategy", "eval_strategy") if k in ta_params),
    None,
)
if eval_strategy_key is not None:
    base_kwargs[eval_strategy_key] = "epoch"
    if "save_strategy" in ta_params:
        base_kwargs["save_strategy"] = "epoch"
    if "logging_strategy" in ta_params:
        base_kwargs["logging_strategy"] = "steps"
    if "load_best_model_at_end" in ta_params:
        base_kwargs["load_best_model_at_end"] = True
    if "metric_for_best_model" in ta_params:
        base_kwargs["metric_for_best_model"] = "f1_weighted"
    if "greater_is_better" in ta_params:
        base_kwargs["greater_is_better"] = True
else:
    print("[TrainingArguments] Old transformers version: no evaluation_strategy argument. Using simple setup.")

# report_to is independent of the strategy support above (the original
# duplicated this line in both branches).
if "report_to" in ta_params:
    base_kwargs["report_to"] = "none"

# Final defensive filter: drop anything this version doesn't accept.
filtered_kwargs = {}
for k, v in base_kwargs.items():
    if k in ta_params:
        filtered_kwargs[k] = v
    else:
        print(f"[TrainingArguments] Skipping unsupported arg: {k}={v}")

training_args = TrainingArguments(**filtered_kwargs)
|
| 233 |
+
|
| 234 |
+
# -------------------------------------------------------------------
|
| 235 |
+
# 7. Trainer
|
| 236 |
+
# -------------------------------------------------------------------
|
| 237 |
+
# Wire model, data, tokenizer and metric function into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# -------------------------------------------------------------------
# 8. Train & eval
# -------------------------------------------------------------------
if __name__ == "__main__":
    print("Starting telecom classification training with AraBERT...")
    trainer.train()

    print("Evaluating on validation split...")
    metrics = trainer.evaluate()
    print("Validation metrics:", metrics)

    print("Saving final model & tokenizer...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    print(f"Label mappings saved in config:")
    print(f" ID to Label: {ID2LABEL}")
    print(f" Label to ID: {LABEL2ID}")

    # Quick smoke test on a few hand-written Arabic comments.
    example_texts = [
        "الخدمة ممتازة جدا وسريعة",
        "سيء للغاية ولا يستجيبون",
        "متوسط الجودة"
    ]

    batch = tokenizer(
        example_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH
    ).to(model.device)

    with torch.no_grad():
        sanity_out = model(**batch)

    sanity_logits = sanity_out.logits.cpu().numpy()
    sanity_preds = np.argmax(sanity_logits, axis=-1)

    print("\nSanity-check predictions:")
    for text, model_idx in zip(example_texts, sanity_preds):
        print(f"Text: {text}")
        print(f" -> Predicted Class: {ID2LABEL[model_idx]}")
        print()

    print("Training complete and model saved to:", OUTPUT_DIR)
|
Code/train_dziribert.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Fine-tune alger-ia/dziribert for Arabic telecom customer comment classification.
|
| 6 |
+
|
| 7 |
+
Dataset (CSV):
|
| 8 |
+
/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/train.csv
|
| 9 |
+
|
| 10 |
+
Columns:
|
| 11 |
+
Commentaire client: str (text)
|
| 12 |
+
Class: int (label - values 1 through 9)
|
| 13 |
+
|
| 14 |
+
Model:
|
| 15 |
+
- DziriBERT encoder (alger-ia/dziribert)
|
| 16 |
+
- Classification head for multi-class prediction (9 classes)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from inspect import signature
|
| 24 |
+
from datasets import load_dataset
|
| 25 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
| 26 |
+
from transformers import (
|
| 27 |
+
AutoTokenizer,
|
| 28 |
+
AutoModelForSequenceClassification,
|
| 29 |
+
TrainingArguments,
|
| 30 |
+
Trainer,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Slight speed boost on Ampere GPUs
|
| 34 |
+
if hasattr(torch, "set_float32_matmul_precision"):
|
| 35 |
+
torch.set_float32_matmul_precision("high")
|
| 36 |
+
|
| 37 |
+
# -------------------------------------------------------------------
|
| 38 |
+
# 1. Paths & config
|
| 39 |
+
# -------------------------------------------------------------------
|
| 40 |
+
DATA_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"
|
| 41 |
+
|
| 42 |
+
MODEL_NAME = "alger-ia/dziribert"
|
| 43 |
+
OUTPUT_DIR = "./telecom_dziribert_final"
|
| 44 |
+
|
| 45 |
+
MAX_LENGTH = 512
|
| 46 |
+
|
| 47 |
+
# Define label mapping - classes are 1-9
|
| 48 |
+
LABEL2ID = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}
|
| 49 |
+
ID2LABEL = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9}
|
| 50 |
+
NUM_LABELS = 9
|
| 51 |
+
|
| 52 |
+
# -------------------------------------------------------------------
|
| 53 |
+
# 2. Dataset loading
|
| 54 |
+
# -------------------------------------------------------------------
|
| 55 |
+
print("Loading telecom dataset from CSV...")
|
| 56 |
+
dataset = load_dataset(
|
| 57 |
+
"csv",
|
| 58 |
+
data_files=DATA_FILE,
|
| 59 |
+
split="train",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
print("Sample example:", dataset[0])
|
| 63 |
+
print(f"Total examples: {len(dataset)}")
|
| 64 |
+
|
| 65 |
+
print(f"Number of classes: {NUM_LABELS}")
|
| 66 |
+
print("Label mapping (class -> model index):", LABEL2ID)
|
| 67 |
+
print("Inverse mapping (model index -> class):", ID2LABEL)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_labels(example):
    """Map a raw class id (1-9) onto the zero-based index (0-8) the model uses."""
    class_val = example["Class"]

    # CSV loaders sometimes surface the column as strings.
    if isinstance(class_val, str):
        class_val = int(class_val)

    if class_val not in LABEL2ID:
        raise ValueError(f"Unknown class: {class_val}. Expected 1-9.")

    example["labels"] = LABEL2ID[class_val]
    return example
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
dataset = dataset.map(encode_labels)
|
| 86 |
+
|
| 87 |
+
# Train/val split (90/10)
|
| 88 |
+
dataset = dataset.train_test_split(test_size=0.1, seed=42)
|
| 89 |
+
train_dataset = dataset["train"]
|
| 90 |
+
eval_dataset = dataset["test"]
|
| 91 |
+
|
| 92 |
+
print("Train size:", len(train_dataset))
|
| 93 |
+
print("Eval size:", len(eval_dataset))
|
| 94 |
+
|
| 95 |
+
# -------------------------------------------------------------------
|
| 96 |
+
# 3. Tokenization
|
| 97 |
+
# -------------------------------------------------------------------
|
| 98 |
+
print(f"\nLoading tokenizer for: {MODEL_NAME}")
|
| 99 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def preprocess_function(examples):
    """Tokenize a batch of comments to fixed-length model inputs."""
    # max_length padding yields uniform tensors across the whole dataset.
    return tokenizer(
        examples["Commentaire client"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 112 |
+
eval_dataset = eval_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 113 |
+
|
| 114 |
+
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 115 |
+
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 116 |
+
|
| 117 |
+
# -------------------------------------------------------------------
|
| 118 |
+
# 4. Model - Using AutoModelForSequenceClassification
|
| 119 |
+
# -------------------------------------------------------------------
|
| 120 |
+
print(f"\nLoading model: {MODEL_NAME}")
|
| 121 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 122 |
+
MODEL_NAME,
|
| 123 |
+
num_labels=NUM_LABELS,
|
| 124 |
+
id2label=ID2LABEL,
|
| 125 |
+
label2id=LABEL2ID,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
print("Model initialized with classification head")
|
| 129 |
+
print(f"Number of labels: {NUM_LABELS}")
|
| 130 |
+
print(f"Classes: {list(ID2LABEL.values())}")
|
| 131 |
+
|
| 132 |
+
# -------------------------------------------------------------------
|
| 133 |
+
# 5. Metrics
|
| 134 |
+
# -------------------------------------------------------------------
|
| 135 |
+
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` tuple provided by ``transformers.Trainer``.

    Returns:
        dict mapping metric names to floats: accuracy, weighted/macro
        precision-recall-F1, and one ``f1_class_<name>`` entry per class.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Overall metrics
    accuracy = accuracy_score(labels, predictions)

    # Weighted average (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )

    # Macro average (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    }

    # Per-class F1 scores.
    # BUG FIX: without an explicit `labels=` argument, f1_score(average=None)
    # returns scores only for classes that appear in labels/predictions, so
    # per_class_f1[idx] could be attributed to the wrong class whenever some
    # class is absent from the eval split. Pinning the full label list keeps
    # array index i aligned with model class index i.
    per_class_f1 = f1_score(
        labels,
        predictions,
        labels=list(range(NUM_LABELS)),
        average=None,
        zero_division=0,
    )
    for idx in range(NUM_LABELS):
        class_name = ID2LABEL[idx]
        metrics[f'f1_class_{class_name}'] = per_class_f1[idx]

    return metrics
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# -------------------------------------------------------------------
|
| 173 |
+
# 6. TrainingArguments (old/new transformers compatible)
|
| 174 |
+
# -------------------------------------------------------------------
|
| 175 |
+
ta_sig = signature(TrainingArguments.__init__)
|
| 176 |
+
ta_params = set(ta_sig.parameters.keys())
|
| 177 |
+
|
| 178 |
+
# Mixed-precision selection: prefer bf16 on GPUs that support it, fall back
# to fp16 on other CUDA GPUs, and disable mixed precision entirely on CPU.
is_bf16_supported = (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
)
use_bf16 = bool(is_bf16_supported)
# BUG FIX: previously `use_fp16 = not use_bf16`, which turned fp16 on even on
# CPU-only machines and makes Trainer fail (fp16 mixed precision requires a
# GPU backend). Only enable fp16 when CUDA is present and bf16 is not.
use_fp16 = torch.cuda.is_available() and not use_bf16

print(f"bf16 supported: {is_bf16_supported} -> using bf16={use_bf16}, fp16={use_fp16}")
|
| 187 |
+
|
| 188 |
+
base_kwargs = {
|
| 189 |
+
"output_dir": OUTPUT_DIR,
|
| 190 |
+
"num_train_epochs": 50,
|
| 191 |
+
"per_device_train_batch_size": 32,
|
| 192 |
+
"per_device_eval_batch_size": 64,
|
| 193 |
+
"learning_rate": 1e-5,
|
| 194 |
+
"weight_decay": 0.01,
|
| 195 |
+
"warmup_ratio": 0.1,
|
| 196 |
+
"logging_steps": 50,
|
| 197 |
+
"save_total_limit": 2,
|
| 198 |
+
"dataloader_num_workers": 4,
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
# Mixed precision flags if supported
|
| 202 |
+
if "bf16" in ta_params:
|
| 203 |
+
base_kwargs["bf16"] = use_bf16
|
| 204 |
+
if "fp16" in ta_params:
|
| 205 |
+
base_kwargs["fp16"] = use_fp16
|
| 206 |
+
|
| 207 |
+
# Handle evaluation_strategy compatibility
|
| 208 |
+
if "evaluation_strategy" in ta_params:
|
| 209 |
+
base_kwargs["evaluation_strategy"] = "epoch"
|
| 210 |
+
if "save_strategy" in ta_params:
|
| 211 |
+
base_kwargs["save_strategy"] = "epoch"
|
| 212 |
+
if "logging_strategy" in ta_params:
|
| 213 |
+
base_kwargs["logging_strategy"] = "steps"
|
| 214 |
+
if "load_best_model_at_end" in ta_params:
|
| 215 |
+
base_kwargs["load_best_model_at_end"] = True
|
| 216 |
+
if "metric_for_best_model" in ta_params:
|
| 217 |
+
base_kwargs["metric_for_best_model"] = "f1_weighted"
|
| 218 |
+
if "greater_is_better" in ta_params:
|
| 219 |
+
base_kwargs["greater_is_better"] = True
|
| 220 |
+
if "report_to" in ta_params:
|
| 221 |
+
base_kwargs["report_to"] = "none"
|
| 222 |
+
else:
|
| 223 |
+
if "report_to" in ta_params:
|
| 224 |
+
base_kwargs["report_to"] = "none"
|
| 225 |
+
print("[TrainingArguments] Old transformers version: no evaluation_strategy argument. Using simple setup.")
|
| 226 |
+
|
| 227 |
+
filtered_kwargs = {}
|
| 228 |
+
for k, v in base_kwargs.items():
|
| 229 |
+
if k in ta_params:
|
| 230 |
+
filtered_kwargs[k] = v
|
| 231 |
+
else:
|
| 232 |
+
print(f"[TrainingArguments] Skipping unsupported arg: {k}={v}")
|
| 233 |
+
|
| 234 |
+
training_args = TrainingArguments(**filtered_kwargs)
|
| 235 |
+
|
| 236 |
+
# -------------------------------------------------------------------
|
| 237 |
+
# 7. Trainer
|
| 238 |
+
# -------------------------------------------------------------------
|
| 239 |
+
trainer = Trainer(
|
| 240 |
+
model=model,
|
| 241 |
+
args=training_args,
|
| 242 |
+
train_dataset=train_dataset,
|
| 243 |
+
eval_dataset=eval_dataset,
|
| 244 |
+
tokenizer=tokenizer,
|
| 245 |
+
compute_metrics=compute_metrics,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# -------------------------------------------------------------------
|
| 249 |
+
# 8. Train & eval
|
| 250 |
+
# -------------------------------------------------------------------
|
| 251 |
+
if __name__ == "__main__":
|
| 252 |
+
print("\nStarting DziriBERT telecom classification training...")
|
| 253 |
+
trainer.train()
|
| 254 |
+
|
| 255 |
+
print("\nEvaluating on validation split...")
|
| 256 |
+
metrics = trainer.evaluate()
|
| 257 |
+
print("Validation metrics:", metrics)
|
| 258 |
+
|
| 259 |
+
print("\nSaving final model & tokenizer...")
|
| 260 |
+
trainer.save_model(OUTPUT_DIR)
|
| 261 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 262 |
+
|
| 263 |
+
# Save label mappings to config
|
| 264 |
+
import json
|
| 265 |
+
config_path = os.path.join(OUTPUT_DIR, "config.json")
|
| 266 |
+
if os.path.exists(config_path):
|
| 267 |
+
with open(config_path, 'r') as f:
|
| 268 |
+
config_data = json.load(f)
|
| 269 |
+
|
| 270 |
+
# Update with actual label mappings
|
| 271 |
+
config_data['id2label'] = {str(idx): str(label) for idx, label in ID2LABEL.items()}
|
| 272 |
+
config_data['label2id'] = {str(label): idx for label, idx in LABEL2ID.items()}
|
| 273 |
+
config_data['num_labels'] = NUM_LABELS
|
| 274 |
+
config_data['problem_type'] = 'single_label_classification'
|
| 275 |
+
|
| 276 |
+
with open(config_path, 'w') as f:
|
| 277 |
+
json.dump(config_data, f, indent=2, ensure_ascii=False)
|
| 278 |
+
|
| 279 |
+
print(f"Saved label mappings to config: {ID2LABEL}")
|
| 280 |
+
|
| 281 |
+
print(f"\nLabel mappings:")
|
| 282 |
+
print(f" ID to Label: {ID2LABEL}")
|
| 283 |
+
print(f" Label to ID: {LABEL2ID}")
|
| 284 |
+
|
| 285 |
+
# Quick sanity-check inference
|
| 286 |
+
example_texts = [
|
| 287 |
+
"الخدمة ممتازة جدا وسريعة",
|
| 288 |
+
"سيء للغاية ولا يستجيبون",
|
| 289 |
+
"متوسط الجودة"
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
inputs = tokenizer(
|
| 293 |
+
example_texts,
|
| 294 |
+
return_tensors="pt",
|
| 295 |
+
padding=True,
|
| 296 |
+
truncation=True,
|
| 297 |
+
max_length=MAX_LENGTH
|
| 298 |
+
).to(model.device)
|
| 299 |
+
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
outputs = model(**inputs)
|
| 302 |
+
|
| 303 |
+
logits = outputs.logits.cpu().numpy()
|
| 304 |
+
predictions = np.argmax(logits, axis=-1)
|
| 305 |
+
|
| 306 |
+
print("\nSanity-check predictions:")
|
| 307 |
+
for text, pred_idx in zip(example_texts, predictions):
|
| 308 |
+
pred_class = ID2LABEL[pred_idx]
|
| 309 |
+
print(f"Text: {text}")
|
| 310 |
+
print(f" -> Predicted Class: {pred_class}")
|
| 311 |
+
print()
|
| 312 |
+
|
| 313 |
+
print("Training complete and model saved to:", OUTPUT_DIR)
|
Code/train_marbertv2.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Fine-tune UBC-NLP/MARBERTv2 for Arabic telecom customer comment classification.
|
| 6 |
+
|
| 7 |
+
Dataset (CSV):
|
| 8 |
+
/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/train.csv
|
| 9 |
+
|
| 10 |
+
Columns:
|
| 11 |
+
Commentaire client: str (text)
|
| 12 |
+
Class: int (label - values 1 through 9)
|
| 13 |
+
|
| 14 |
+
Model:
|
| 15 |
+
- MARBERTv2 encoder
|
| 16 |
+
- Classification head for multi-class prediction (9 classes)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from inspect import signature
|
| 24 |
+
from datasets import load_dataset
|
| 25 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
| 26 |
+
from transformers import (
|
| 27 |
+
AutoTokenizer,
|
| 28 |
+
AutoModelForSequenceClassification,
|
| 29 |
+
TrainingArguments,
|
| 30 |
+
Trainer,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Slight speed boost on Ampere GPUs
|
| 34 |
+
if hasattr(torch, "set_float32_matmul_precision"):
|
| 35 |
+
torch.set_float32_matmul_precision("high")
|
| 36 |
+
|
| 37 |
+
# -------------------------------------------------------------------
|
| 38 |
+
# 1. Paths & config
|
| 39 |
+
# -------------------------------------------------------------------
|
| 40 |
+
DATA_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--labelds/snapshots/48f016fd5987875b0e9f79d0689cef2ec3b2ce0b/train.csv"
|
| 41 |
+
|
| 42 |
+
MODEL_NAME = "UBC-NLP/MARBERTv2"
|
| 43 |
+
OUTPUT_DIR = "./telecom_marbertv2_final"
|
| 44 |
+
|
| 45 |
+
MAX_LENGTH = 256
|
| 46 |
+
|
| 47 |
+
# Define label mapping - classes are 1-9
|
| 48 |
+
LABEL2ID = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}
|
| 49 |
+
ID2LABEL = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9}
|
| 50 |
+
NUM_LABELS = 9
|
| 51 |
+
|
| 52 |
+
# -------------------------------------------------------------------
|
| 53 |
+
# 2. Dataset loading
|
| 54 |
+
# -------------------------------------------------------------------
|
| 55 |
+
print("Loading telecom dataset from CSV...")
|
| 56 |
+
dataset = load_dataset(
|
| 57 |
+
"csv",
|
| 58 |
+
data_files=DATA_FILE,
|
| 59 |
+
split="train",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
print("Sample example:", dataset[0])
|
| 63 |
+
print(f"Total examples: {len(dataset)}")
|
| 64 |
+
|
| 65 |
+
print(f"Number of classes: {NUM_LABELS}")
|
| 66 |
+
print("Label mapping (class -> model index):", LABEL2ID)
|
| 67 |
+
print("Inverse mapping (model index -> class):", ID2LABEL)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_labels(example):
    """Map the raw class value (1-9, possibly stored as a string) to a 0-based label index."""
    raw = example["Class"]

    # CSV loading can yield the class column as text; normalize to int first.
    class_val = int(raw) if isinstance(raw, str) else raw

    if class_val not in LABEL2ID:
        raise ValueError(f"Unknown class: {class_val}. Expected 1-9.")

    example["labels"] = LABEL2ID[class_val]
    return example
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
dataset = dataset.map(encode_labels)
|
| 86 |
+
|
| 87 |
+
# Train/val split (90/10)
|
| 88 |
+
dataset = dataset.train_test_split(test_size=0.1, seed=42)
|
| 89 |
+
train_dataset = dataset["train"]
|
| 90 |
+
eval_dataset = dataset["test"]
|
| 91 |
+
|
| 92 |
+
print("Train size:", len(train_dataset))
|
| 93 |
+
print("Eval size:", len(eval_dataset))
|
| 94 |
+
|
| 95 |
+
# -------------------------------------------------------------------
|
| 96 |
+
# 3. Tokenization
|
| 97 |
+
# -------------------------------------------------------------------
|
| 98 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def preprocess_function(examples):
    """Turn the raw comment text into padded, truncated token sequences."""
    return tokenizer(
        examples["Commentaire client"],
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
    )
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 111 |
+
eval_dataset = eval_dataset.map(preprocess_function, batched=True, num_proc=4)
|
| 112 |
+
|
| 113 |
+
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 114 |
+
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
|
| 115 |
+
|
| 116 |
+
# -------------------------------------------------------------------
|
| 117 |
+
# 4. Model - Using AutoModelForSequenceClassification
|
| 118 |
+
# -------------------------------------------------------------------
|
| 119 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 120 |
+
MODEL_NAME,
|
| 121 |
+
num_labels=NUM_LABELS,
|
| 122 |
+
id2label=ID2LABEL,
|
| 123 |
+
label2id=LABEL2ID,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
print("Model initialized with classification head")
|
| 127 |
+
print(f"Number of labels: {NUM_LABELS}")
|
| 128 |
+
print(f"Classes: {list(ID2LABEL.values())}")
|
| 129 |
+
|
| 130 |
+
# -------------------------------------------------------------------
|
| 131 |
+
# 5. Metrics
|
| 132 |
+
# -------------------------------------------------------------------
|
| 133 |
+
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` tuple provided by ``transformers.Trainer``.

    Returns:
        dict mapping metric names to floats: accuracy, weighted/macro
        precision-recall-F1, and one ``f1_class_<name>`` entry per class.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Overall metrics
    accuracy = accuracy_score(labels, predictions)

    # Weighted average (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )

    # Macro average (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'f1_weighted': f1_w,
        'f1_macro': f1_m,
        'precision_weighted': precision_w,
        'recall_weighted': recall_w,
        'precision_macro': precision_m,
        'recall_macro': recall_m,
    }

    # Per-class F1 scores.
    # BUG FIX: without an explicit `labels=` argument, f1_score(average=None)
    # returns scores only for classes that appear in labels/predictions, so
    # per_class_f1[idx] could be attributed to the wrong class whenever some
    # class is absent from the eval split. Pinning the full label list keeps
    # array index i aligned with model class index i.
    per_class_f1 = f1_score(
        labels,
        predictions,
        labels=list(range(NUM_LABELS)),
        average=None,
        zero_division=0,
    )
    for idx in range(NUM_LABELS):
        class_name = ID2LABEL[idx]
        metrics[f'f1_class_{class_name}'] = per_class_f1[idx]

    return metrics
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------------------------------------------------
|
| 171 |
+
# 6. TrainingArguments (old/new transformers compatible)
|
| 172 |
+
# -------------------------------------------------------------------
|
| 173 |
+
ta_sig = signature(TrainingArguments.__init__)
|
| 174 |
+
ta_params = set(ta_sig.parameters.keys())
|
| 175 |
+
|
| 176 |
+
# Mixed-precision selection: prefer bf16 on GPUs that support it, fall back
# to fp16 on other CUDA GPUs, and disable mixed precision entirely on CPU.
is_bf16_supported = (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
)
use_bf16 = bool(is_bf16_supported)
# BUG FIX: previously `use_fp16 = not use_bf16`, which turned fp16 on even on
# CPU-only machines and makes Trainer fail (fp16 mixed precision requires a
# GPU backend). Only enable fp16 when CUDA is present and bf16 is not.
use_fp16 = torch.cuda.is_available() and not use_bf16

print(f"bf16 supported: {is_bf16_supported} -> using bf16={use_bf16}, fp16={use_fp16}")
|
| 185 |
+
|
| 186 |
+
base_kwargs = {
|
| 187 |
+
"output_dir": OUTPUT_DIR,
|
| 188 |
+
"num_train_epochs": 10,
|
| 189 |
+
"per_device_train_batch_size": 32,
|
| 190 |
+
"per_device_eval_batch_size": 64,
|
| 191 |
+
"learning_rate": 1e-4,
|
| 192 |
+
"weight_decay": 0.02,
|
| 193 |
+
"warmup_ratio": 0.1,
|
| 194 |
+
"logging_steps": 50,
|
| 195 |
+
"save_total_limit": 2,
|
| 196 |
+
"dataloader_num_workers": 4,
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
# Mixed precision flags if supported
|
| 200 |
+
if "bf16" in ta_params:
|
| 201 |
+
base_kwargs["bf16"] = use_bf16
|
| 202 |
+
if "fp16" in ta_params:
|
| 203 |
+
base_kwargs["fp16"] = use_fp16
|
| 204 |
+
|
| 205 |
+
# Handle evaluation_strategy compatibility
|
| 206 |
+
if "evaluation_strategy" in ta_params:
|
| 207 |
+
base_kwargs["evaluation_strategy"] = "epoch"
|
| 208 |
+
if "save_strategy" in ta_params:
|
| 209 |
+
base_kwargs["save_strategy"] = "epoch"
|
| 210 |
+
if "logging_strategy" in ta_params:
|
| 211 |
+
base_kwargs["logging_strategy"] = "steps"
|
| 212 |
+
if "load_best_model_at_end" in ta_params:
|
| 213 |
+
base_kwargs["load_best_model_at_end"] = True
|
| 214 |
+
if "metric_for_best_model" in ta_params:
|
| 215 |
+
base_kwargs["metric_for_best_model"] = "f1_weighted"
|
| 216 |
+
if "greater_is_better" in ta_params:
|
| 217 |
+
base_kwargs["greater_is_better"] = True
|
| 218 |
+
if "report_to" in ta_params:
|
| 219 |
+
base_kwargs["report_to"] = "none"
|
| 220 |
+
else:
|
| 221 |
+
if "report_to" in ta_params:
|
| 222 |
+
base_kwargs["report_to"] = "none"
|
| 223 |
+
print("[TrainingArguments] Old transformers version: no evaluation_strategy argument. Using simple setup.")
|
| 224 |
+
|
| 225 |
+
filtered_kwargs = {}
|
| 226 |
+
for k, v in base_kwargs.items():
|
| 227 |
+
if k in ta_params:
|
| 228 |
+
filtered_kwargs[k] = v
|
| 229 |
+
else:
|
| 230 |
+
print(f"[TrainingArguments] Skipping unsupported arg: {k}={v}")
|
| 231 |
+
|
| 232 |
+
training_args = TrainingArguments(**filtered_kwargs)
|
| 233 |
+
|
| 234 |
+
# -------------------------------------------------------------------
|
| 235 |
+
# 7. Trainer
|
| 236 |
+
# -------------------------------------------------------------------
|
| 237 |
+
trainer = Trainer(
|
| 238 |
+
model=model,
|
| 239 |
+
args=training_args,
|
| 240 |
+
train_dataset=train_dataset,
|
| 241 |
+
eval_dataset=eval_dataset,
|
| 242 |
+
tokenizer=tokenizer,
|
| 243 |
+
compute_metrics=compute_metrics,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# -------------------------------------------------------------------
|
| 247 |
+
# 8. Train & eval
|
| 248 |
+
# -------------------------------------------------------------------
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
print("Starting telecom classification training...")
|
| 251 |
+
trainer.train()
|
| 252 |
+
|
| 253 |
+
print("Evaluating on validation split...")
|
| 254 |
+
metrics = trainer.evaluate()
|
| 255 |
+
print("Validation metrics:", metrics)
|
| 256 |
+
|
| 257 |
+
print("Saving final model & tokenizer...")
|
| 258 |
+
trainer.save_model(OUTPUT_DIR)
|
| 259 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 260 |
+
|
| 261 |
+
print(f"Label mappings saved in config:")
|
| 262 |
+
print(f" ID to Label: {ID2LABEL}")
|
| 263 |
+
print(f" Label to ID: {LABEL2ID}")
|
| 264 |
+
|
| 265 |
+
# Quick sanity-check inference
|
| 266 |
+
example_texts = [
|
| 267 |
+
"الخدمة ممتازة جدا وسريعة",
|
| 268 |
+
"سيء للغاية ولا يستجيبون",
|
| 269 |
+
"متوسط الجودة"
|
| 270 |
+
]
|
| 271 |
+
|
| 272 |
+
inputs = tokenizer(
|
| 273 |
+
example_texts,
|
| 274 |
+
return_tensors="pt",
|
| 275 |
+
padding=True,
|
| 276 |
+
truncation=True,
|
| 277 |
+
max_length=MAX_LENGTH
|
| 278 |
+
).to(model.device)
|
| 279 |
+
|
| 280 |
+
with torch.no_grad():
|
| 281 |
+
outputs = model(**inputs)
|
| 282 |
+
|
| 283 |
+
logits = outputs.logits.cpu().numpy()
|
| 284 |
+
predictions = np.argmax(logits, axis=-1)
|
| 285 |
+
|
| 286 |
+
print("\nSanity-check predictions:")
|
| 287 |
+
for text, pred_idx in zip(example_texts, predictions):
|
| 288 |
+
pred_class = ID2LABEL[pred_idx]
|
| 289 |
+
print(f"Text: {text}")
|
| 290 |
+
print(f" -> Predicted Class: {pred_class}")
|
| 291 |
+
print()
|
| 292 |
+
|
| 293 |
+
print("Training complete and model saved to:", OUTPUT_DIR)
|
Code/voting.ipynb
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "f4d51cb4",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": null,
|
| 16 |
+
"id": "344fbcef",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"df_gemma1 = pd.read_csv('test_predictions_gemma3.csv')\n",
|
| 21 |
+
"df_camel = pd.read_csv('test_predictions_camelbert_cpt_ftbeforesleep.csv')\n",
|
| 22 |
+
"df_arabert = pd.read_csv('test_predictions_arabert_full_pipeline.csv')\n",
|
| 23 |
+
"df_dziribert = pd.read_csv('test_predictions_dziribert.csv')\n",
|
| 24 |
+
"df_marbert= pd.read_csv('test_predictions_marbertv2_cpt_ft5.csv')\n",
|
| 25 |
+
"df_gemma2 = pd.read_csv('test_predictions_gemma3-2nd.csv')"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"id": "29486bd8",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"# Perform soft voting with weights\n",
|
| 36 |
+
"# Give more weight to df_gemma1 (weight = 2), others get weight = 1\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# Define weights for each model\n",
|
| 39 |
+
"weights = {\n",
|
| 40 |
+
" 'df_gemma1': 2.0,\n",
|
| 41 |
+
" 'df_camel': 1.0,\n",
|
| 42 |
+
" 'df_arabert': 1.0,\n",
|
| 43 |
+
" 'df_dziribert': 1.0,\n",
|
| 44 |
+
" 'df_gemma2': 1.0,\n",
|
| 45 |
+
" 'df_marbert': 1.0\n",
|
| 46 |
+
"}\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"# Create a combined dataframe with id and Commentaire client from the first dataframe\n",
|
| 49 |
+
"result_df = df_gemma1[['id', 'Réseau Social', 'Commentaire client']].copy()\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"# Initialize a dictionary to store weighted vote counts for each class\n",
|
| 52 |
+
"from collections import defaultdict\n",
|
| 53 |
+
"import numpy as np\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"# For each row, calculate weighted votes\n",
|
| 56 |
+
"final_predictions = []\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"for idx in range(len(df_gemma1)):\n",
|
| 59 |
+
" vote_counts = defaultdict(float)\n",
|
| 60 |
+
" \n",
|
| 61 |
+
" # Add weighted votes from each model\n",
|
| 62 |
+
" vote_counts[df_gemma1.iloc[idx]['Predicted_Class']] += weights['df_gemma1']\n",
|
| 63 |
+
" vote_counts[df_camel.iloc[idx]['Predicted_Class']] += weights['df_camel']\n",
|
| 64 |
+
" vote_counts[df_arabert.iloc[idx]['Predicted_Class']] += weights['df_arabert']\n",
|
| 65 |
+
" vote_counts[df_dziribert.iloc[idx]['Predicted_Class']] += weights['df_dziribert']\n",
|
| 66 |
+
" vote_counts[df_gemma2.iloc[idx]['Predicted_Class']] += weights['df_gemma2']\n",
|
| 67 |
+
" vote_counts[df_marbert.iloc[idx]['Predicted_Class']] += weights['df_marbert']\n",
|
| 68 |
+
" \n",
|
| 69 |
+
" # Select class with highest weighted vote\n",
|
| 70 |
+
" final_prediction = max(vote_counts.items(), key=lambda x: x[1])[0]\n",
|
| 71 |
+
" final_predictions.append(final_prediction)\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"# Add predictions to result dataframe\n",
|
| 74 |
+
"result_df['Predicted_Class'] = final_predictions\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# Display statistics\n",
|
| 77 |
+
"print(f\"Total samples: {len(result_df)}\")\n",
|
| 78 |
+
"print(f\"\\nClass distribution:\")\n",
|
| 79 |
+
"print(result_df['Predicted_Class'].value_counts().sort_index())\n",
|
| 80 |
+
"print(f\"\\nWeight configuration:\")\n",
|
| 81 |
+
"for model, weight in weights.items():\n",
|
| 82 |
+
" print(f\" {model}: {weight}\")\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# Display first few rows\n",
|
| 85 |
+
"print(f\"\\nFirst 5 predictions:\")\n",
|
| 86 |
+
"result_df.head()\n"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "code",
|
| 91 |
+
"execution_count": null,
|
| 92 |
+
"id": "543f2936",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"# Save results to CSV (only id and Class)\n",
|
| 97 |
+
"output_filename = 'test_predictions_weighted_voting_ensemble.csv'\n",
|
| 98 |
+
"output_df = result_df[['id', 'Predicted_Class']].copy()\n",
|
| 99 |
+
"output_df.rename(columns={'Predicted_Class': 'Class'}, inplace=True)\n",
|
| 100 |
+
"output_df.to_csv(output_filename, index=False)\n",
|
| 101 |
+
"print(f\"Results saved to: {output_filename}\")\n",
|
| 102 |
+
"print(f\"Columns in output: id, Class\")\n"
|
| 103 |
+
]
|
| 104 |
+
}
|
| 105 |
+
],
|
| 106 |
+
"metadata": {
|
| 107 |
+
"kernelspec": {
|
| 108 |
+
"display_name": ".venv",
|
| 109 |
+
"language": "python",
|
| 110 |
+
"name": "python3"
|
| 111 |
+
},
|
| 112 |
+
"language_info": {
|
| 113 |
+
"name": "python",
|
| 114 |
+
"version": "3.12.3"
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"nbformat": 4,
|
| 118 |
+
"nbformat_minor": 5
|
| 119 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# Requirements for Arabic LLM Training & Inference
|
| 3 |
+
# ============================================================================
|
| 4 |
+
# Install: pip install -r requirements.txt
|
| 5 |
+
# ============================================================================
|
| 6 |
+
|
| 7 |
+
# Core ML Frameworks
|
| 8 |
+
torch>=2.1.0
|
| 9 |
+
transformers>=4.45.0
|
| 10 |
+
accelerate>=0.25.0
|
| 11 |
+
datasets>=2.16.0
|
| 12 |
+
|
| 13 |
+
# Parameter-Efficient Fine-Tuning
|
| 14 |
+
peft>=0.7.0
|
| 15 |
+
|
| 16 |
+
# Tokenizers
|
| 17 |
+
tokenizers>=0.15.0
|
| 18 |
+
sentencepiece>=0.1.99
|
| 19 |
+
|
| 20 |
+
# Data Processing
|
| 21 |
+
pandas>=2.0.0
|
| 22 |
+
numpy>=1.24.0
|
| 23 |
+
scikit-learn>=1.3.0
|
| 24 |
+
|
| 25 |
+
# Progress & Logging
|
| 26 |
+
tqdm>=4.66.0
|
| 27 |
+
|
| 28 |
+
# Evaluation Metrics
|
| 29 |
+
evaluate>=0.4.0
|
| 30 |
+
seqeval>=1.2.2
|
| 31 |
+
|
| 32 |
+
# Optional: Weights & Biases for experiment tracking
|
| 33 |
+
# wandb>=0.16.0
|
| 34 |
+
|
| 35 |
+
# Optional: Flash Attention 2 (requires CUDA)
|
| 36 |
+
# flash-attn>=2.3.0
|
| 37 |
+
|
| 38 |
+
# Optional: BitsAndBytes for quantization
|
| 39 |
+
# bitsandbytes>=0.41.0
|
| 40 |
+
|
| 41 |
+
# Jupyter Support (optional)
|
| 42 |
+
# jupyter>=1.0.0
|
| 43 |
+
# ipykernel>=6.25.0
|
| 44 |
+
|
| 45 |
+
# Hugging Face Hub
|
| 46 |
+
huggingface-hub>=0.19.0
|
| 47 |
+
|
| 48 |
+
# SafeTensors for efficient model loading
|
| 49 |
+
safetensors>=0.4.0
|
| 50 |
+
|
| 51 |
+
# Regex for text preprocessing
|
| 52 |
+
regex>=2023.10.0
|
| 53 |
+
|
| 54 |
+
|
telecom_camelbert_cpt_ft/checkpoint-3500/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": "1",
|
| 14 |
+
"1": "2",
|
| 15 |
+
"2": "3",
|
| 16 |
+
"3": "4",
|
| 17 |
+
"4": "5",
|
| 18 |
+
"5": "6",
|
| 19 |
+
"6": "7",
|
| 20 |
+
"7": "8",
|
| 21 |
+
"8": "9"
|
| 22 |
+
},
|
| 23 |
+
"initializer_range": 0.02,
|
| 24 |
+
"intermediate_size": 3072,
|
| 25 |
+
"label2id": {
|
| 26 |
+
"0": 1,
|
| 27 |
+
"1": 2,
|
| 28 |
+
"2": 3,
|
| 29 |
+
"3": 4,
|
| 30 |
+
"4": 5,
|
| 31 |
+
"5": 6,
|
| 32 |
+
"6": 7,
|
| 33 |
+
"7": 8,
|
| 34 |
+
"8": 9
|
| 35 |
+
},
|
| 36 |
+
"layer_norm_eps": 1e-12,
|
| 37 |
+
"max_position_embeddings": 512,
|
| 38 |
+
"model_type": "bert",
|
| 39 |
+
"num_attention_heads": 12,
|
| 40 |
+
"num_hidden_layers": 12,
|
| 41 |
+
"pad_token_id": 0,
|
| 42 |
+
"position_embedding_type": "absolute",
|
| 43 |
+
"problem_type": "single_label_classification",
|
| 44 |
+
"transformers_version": "4.57.3",
|
| 45 |
+
"type_vocab_size": 2,
|
| 46 |
+
"use_cache": true,
|
| 47 |
+
"vocab_size": 30000
|
| 48 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3500/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e2a6a9d45f9658c97abc75e064c8148f35c6b1cd6f9740d212d898d79b77f591
|
| 3 |
+
size 436376588
|
telecom_camelbert_cpt_ft/checkpoint-3500/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c98dfe00bf6e8edda2b26395a77170bcec1a4cfaa60b9bf81982603ae218a218
|
| 3 |
+
size 872877451
|
telecom_camelbert_cpt_ft/checkpoint-3500/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:16c4b35aee0bbaeee4ff40ea0ec57a691c91ae4126d7f2611b0621dcea397d18
|
| 3 |
+
size 14645
|
telecom_camelbert_cpt_ft/checkpoint-3500/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fba3938712732ea9a2d4d892486d311187e6b050bac015163f3186f8c0ea04f4
|
| 3 |
+
size 1465
|
telecom_camelbert_cpt_ft/checkpoint-3500/special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": {
|
| 3 |
+
"content": "[CLS]",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"mask_token": {
|
| 10 |
+
"content": "[MASK]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"sep_token": {
|
| 24 |
+
"content": "[SEP]",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "[UNK]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3500/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_camelbert_cpt_ft/checkpoint-3500/tokenizer_config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": false,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"full_tokenizer_file": null,
|
| 50 |
+
"mask_token": "[MASK]",
|
| 51 |
+
"max_length": 512,
|
| 52 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 53 |
+
"never_split": null,
|
| 54 |
+
"pad_to_multiple_of": null,
|
| 55 |
+
"pad_token": "[PAD]",
|
| 56 |
+
"pad_token_type_id": 0,
|
| 57 |
+
"padding_side": "right",
|
| 58 |
+
"sep_token": "[SEP]",
|
| 59 |
+
"stride": 0,
|
| 60 |
+
"strip_accents": null,
|
| 61 |
+
"tokenize_chinese_chars": true,
|
| 62 |
+
"tokenizer_class": "BertTokenizer",
|
| 63 |
+
"truncation_side": "right",
|
| 64 |
+
"truncation_strategy": "longest_first",
|
| 65 |
+
"unk_token": "[UNK]"
|
| 66 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3500/trainer_state.json
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 68.62745098039215,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 3500,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.9803921568627451,
|
| 14 |
+
"grad_norm": 2.6230194568634033,
|
| 15 |
+
"learning_rate": 2.7450980392156867e-06,
|
| 16 |
+
"loss": 2.2283,
|
| 17 |
+
"step": 50
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 1.9607843137254903,
|
| 21 |
+
"grad_norm": 3.465364933013916,
|
| 22 |
+
"learning_rate": 5.546218487394959e-06,
|
| 23 |
+
"loss": 2.1335,
|
| 24 |
+
"step": 100
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 2.9411764705882355,
|
| 28 |
+
"grad_norm": 4.075287818908691,
|
| 29 |
+
"learning_rate": 8.34733893557423e-06,
|
| 30 |
+
"loss": 1.8493,
|
| 31 |
+
"step": 150
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 3.9215686274509802,
|
| 35 |
+
"grad_norm": 4.183288097381592,
|
| 36 |
+
"learning_rate": 1.1148459383753503e-05,
|
| 37 |
+
"loss": 1.3306,
|
| 38 |
+
"step": 200
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 4.901960784313726,
|
| 42 |
+
"grad_norm": 6.412795066833496,
|
| 43 |
+
"learning_rate": 1.3949579831932774e-05,
|
| 44 |
+
"loss": 0.8739,
|
| 45 |
+
"step": 250
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 5.882352941176471,
|
| 49 |
+
"grad_norm": 9.716330528259277,
|
| 50 |
+
"learning_rate": 1.6750700280112046e-05,
|
| 51 |
+
"loss": 0.6337,
|
| 52 |
+
"step": 300
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 6.862745098039216,
|
| 56 |
+
"grad_norm": 6.941835403442383,
|
| 57 |
+
"learning_rate": 1.9551820728291318e-05,
|
| 58 |
+
"loss": 0.4314,
|
| 59 |
+
"step": 350
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 7.8431372549019605,
|
| 63 |
+
"grad_norm": 3.7662336826324463,
|
| 64 |
+
"learning_rate": 1.973856209150327e-05,
|
| 65 |
+
"loss": 0.2944,
|
| 66 |
+
"step": 400
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 8.823529411764707,
|
| 70 |
+
"grad_norm": 4.981897354125977,
|
| 71 |
+
"learning_rate": 1.9427326486150017e-05,
|
| 72 |
+
"loss": 0.1731,
|
| 73 |
+
"step": 450
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 9.803921568627452,
|
| 77 |
+
"grad_norm": 9.845929145812988,
|
| 78 |
+
"learning_rate": 1.9116090880796766e-05,
|
| 79 |
+
"loss": 0.1273,
|
| 80 |
+
"step": 500
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 10.784313725490197,
|
| 84 |
+
"grad_norm": 8.743799209594727,
|
| 85 |
+
"learning_rate": 1.880485527544351e-05,
|
| 86 |
+
"loss": 0.077,
|
| 87 |
+
"step": 550
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 11.764705882352942,
|
| 91 |
+
"grad_norm": 0.8690173625946045,
|
| 92 |
+
"learning_rate": 1.849361967009026e-05,
|
| 93 |
+
"loss": 0.0603,
|
| 94 |
+
"step": 600
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 12.745098039215687,
|
| 98 |
+
"grad_norm": 0.22123952209949493,
|
| 99 |
+
"learning_rate": 1.8182384064737007e-05,
|
| 100 |
+
"loss": 0.0304,
|
| 101 |
+
"step": 650
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 13.72549019607843,
|
| 105 |
+
"grad_norm": 0.263492226600647,
|
| 106 |
+
"learning_rate": 1.7871148459383755e-05,
|
| 107 |
+
"loss": 0.0263,
|
| 108 |
+
"step": 700
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 14.705882352941176,
|
| 112 |
+
"grad_norm": 0.9709652066230774,
|
| 113 |
+
"learning_rate": 1.7559912854030504e-05,
|
| 114 |
+
"loss": 0.0307,
|
| 115 |
+
"step": 750
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 15.686274509803921,
|
| 119 |
+
"grad_norm": 0.03851320594549179,
|
| 120 |
+
"learning_rate": 1.724867724867725e-05,
|
| 121 |
+
"loss": 0.0237,
|
| 122 |
+
"step": 800
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 16.666666666666668,
|
| 126 |
+
"grad_norm": 3.8813209533691406,
|
| 127 |
+
"learning_rate": 1.6937441643323997e-05,
|
| 128 |
+
"loss": 0.0244,
|
| 129 |
+
"step": 850
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 17.647058823529413,
|
| 133 |
+
"grad_norm": 0.14205335080623627,
|
| 134 |
+
"learning_rate": 1.6626206037970745e-05,
|
| 135 |
+
"loss": 0.0103,
|
| 136 |
+
"step": 900
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 18.627450980392158,
|
| 140 |
+
"grad_norm": 0.03730874881148338,
|
| 141 |
+
"learning_rate": 1.631497043261749e-05,
|
| 142 |
+
"loss": 0.0123,
|
| 143 |
+
"step": 950
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 19.607843137254903,
|
| 147 |
+
"grad_norm": 0.047113899141550064,
|
| 148 |
+
"learning_rate": 1.6003734827264242e-05,
|
| 149 |
+
"loss": 0.0203,
|
| 150 |
+
"step": 1000
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 20.58823529411765,
|
| 154 |
+
"grad_norm": 0.019405728206038475,
|
| 155 |
+
"learning_rate": 1.5692499221910987e-05,
|
| 156 |
+
"loss": 0.0105,
|
| 157 |
+
"step": 1050
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 21.568627450980394,
|
| 161 |
+
"grad_norm": 0.022607196122407913,
|
| 162 |
+
"learning_rate": 1.5381263616557735e-05,
|
| 163 |
+
"loss": 0.0133,
|
| 164 |
+
"step": 1100
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 22.54901960784314,
|
| 168 |
+
"grad_norm": 0.14680148661136627,
|
| 169 |
+
"learning_rate": 1.5070028011204482e-05,
|
| 170 |
+
"loss": 0.012,
|
| 171 |
+
"step": 1150
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 23.529411764705884,
|
| 175 |
+
"grad_norm": 0.07145881652832031,
|
| 176 |
+
"learning_rate": 1.475879240585123e-05,
|
| 177 |
+
"loss": 0.011,
|
| 178 |
+
"step": 1200
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 24.509803921568626,
|
| 182 |
+
"grad_norm": 0.017874594777822495,
|
| 183 |
+
"learning_rate": 1.4447556800497977e-05,
|
| 184 |
+
"loss": 0.0121,
|
| 185 |
+
"step": 1250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 25.49019607843137,
|
| 189 |
+
"grad_norm": 0.014184300787746906,
|
| 190 |
+
"learning_rate": 1.4136321195144727e-05,
|
| 191 |
+
"loss": 0.0082,
|
| 192 |
+
"step": 1300
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 26.470588235294116,
|
| 196 |
+
"grad_norm": 0.4630926847457886,
|
| 197 |
+
"learning_rate": 1.3825085589791474e-05,
|
| 198 |
+
"loss": 0.0089,
|
| 199 |
+
"step": 1350
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 27.45098039215686,
|
| 203 |
+
"grad_norm": 1.217696189880371,
|
| 204 |
+
"learning_rate": 1.351384998443822e-05,
|
| 205 |
+
"loss": 0.0037,
|
| 206 |
+
"step": 1400
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"epoch": 28.431372549019606,
|
| 210 |
+
"grad_norm": 0.036678045988082886,
|
| 211 |
+
"learning_rate": 1.3202614379084969e-05,
|
| 212 |
+
"loss": 0.0072,
|
| 213 |
+
"step": 1450
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"epoch": 29.41176470588235,
|
| 217 |
+
"grad_norm": 0.014452925883233547,
|
| 218 |
+
"learning_rate": 1.2891378773731715e-05,
|
| 219 |
+
"loss": 0.0047,
|
| 220 |
+
"step": 1500
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"epoch": 30.392156862745097,
|
| 224 |
+
"grad_norm": 0.016204573214054108,
|
| 225 |
+
"learning_rate": 1.2580143168378462e-05,
|
| 226 |
+
"loss": 0.0053,
|
| 227 |
+
"step": 1550
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"epoch": 31.372549019607842,
|
| 231 |
+
"grad_norm": 0.01764889620244503,
|
| 232 |
+
"learning_rate": 1.2268907563025212e-05,
|
| 233 |
+
"loss": 0.0049,
|
| 234 |
+
"step": 1600
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"epoch": 32.35294117647059,
|
| 238 |
+
"grad_norm": 0.01166362501680851,
|
| 239 |
+
"learning_rate": 1.1957671957671959e-05,
|
| 240 |
+
"loss": 0.0054,
|
| 241 |
+
"step": 1650
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"epoch": 33.333333333333336,
|
| 245 |
+
"grad_norm": 0.0751669779419899,
|
| 246 |
+
"learning_rate": 1.1646436352318707e-05,
|
| 247 |
+
"loss": 0.0048,
|
| 248 |
+
"step": 1700
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"epoch": 34.31372549019608,
|
| 252 |
+
"grad_norm": 0.01111924834549427,
|
| 253 |
+
"learning_rate": 1.1335200746965454e-05,
|
| 254 |
+
"loss": 0.0064,
|
| 255 |
+
"step": 1750
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"epoch": 35.294117647058826,
|
| 259 |
+
"grad_norm": 0.040332596749067307,
|
| 260 |
+
"learning_rate": 1.10239651416122e-05,
|
| 261 |
+
"loss": 0.0061,
|
| 262 |
+
"step": 1800
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"epoch": 36.27450980392157,
|
| 266 |
+
"grad_norm": 0.015221468172967434,
|
| 267 |
+
"learning_rate": 1.0712729536258948e-05,
|
| 268 |
+
"loss": 0.0051,
|
| 269 |
+
"step": 1850
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"epoch": 37.254901960784316,
|
| 273 |
+
"grad_norm": 0.010177787393331528,
|
| 274 |
+
"learning_rate": 1.0401493930905697e-05,
|
| 275 |
+
"loss": 0.0056,
|
| 276 |
+
"step": 1900
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"epoch": 38.23529411764706,
|
| 280 |
+
"grad_norm": 0.006494768429547548,
|
| 281 |
+
"learning_rate": 1.0090258325552445e-05,
|
| 282 |
+
"loss": 0.0038,
|
| 283 |
+
"step": 1950
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"epoch": 39.21568627450981,
|
| 287 |
+
"grad_norm": 0.021074198186397552,
|
| 288 |
+
"learning_rate": 9.779022720199192e-06,
|
| 289 |
+
"loss": 0.0037,
|
| 290 |
+
"step": 2000
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"epoch": 40.19607843137255,
|
| 294 |
+
"grad_norm": 0.7565334439277649,
|
| 295 |
+
"learning_rate": 9.467787114845938e-06,
|
| 296 |
+
"loss": 0.0075,
|
| 297 |
+
"step": 2050
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"epoch": 41.1764705882353,
|
| 301 |
+
"grad_norm": 0.007625663187354803,
|
| 302 |
+
"learning_rate": 9.156551509492687e-06,
|
| 303 |
+
"loss": 0.0025,
|
| 304 |
+
"step": 2100
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"epoch": 42.15686274509804,
|
| 308 |
+
"grad_norm": 0.006594022735953331,
|
| 309 |
+
"learning_rate": 8.845315904139435e-06,
|
| 310 |
+
"loss": 0.0041,
|
| 311 |
+
"step": 2150
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"epoch": 43.13725490196079,
|
| 315 |
+
"grad_norm": 0.5956095457077026,
|
| 316 |
+
"learning_rate": 8.534080298786182e-06,
|
| 317 |
+
"loss": 0.0052,
|
| 318 |
+
"step": 2200
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"epoch": 44.11764705882353,
|
| 322 |
+
"grad_norm": 0.005290526431053877,
|
| 323 |
+
"learning_rate": 8.22284469343293e-06,
|
| 324 |
+
"loss": 0.0037,
|
| 325 |
+
"step": 2250
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"epoch": 45.09803921568628,
|
| 329 |
+
"grad_norm": 0.12142275273799896,
|
| 330 |
+
"learning_rate": 7.911609088079677e-06,
|
| 331 |
+
"loss": 0.0049,
|
| 332 |
+
"step": 2300
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"epoch": 46.07843137254902,
|
| 336 |
+
"grad_norm": 0.512083888053894,
|
| 337 |
+
"learning_rate": 7.600373482726424e-06,
|
| 338 |
+
"loss": 0.0038,
|
| 339 |
+
"step": 2350
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"epoch": 47.05882352941177,
|
| 343 |
+
"grad_norm": 0.009066535159945488,
|
| 344 |
+
"learning_rate": 7.2891378773731725e-06,
|
| 345 |
+
"loss": 0.0038,
|
| 346 |
+
"step": 2400
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"epoch": 48.03921568627451,
|
| 350 |
+
"grad_norm": 0.0053984676487743855,
|
| 351 |
+
"learning_rate": 6.97790227201992e-06,
|
| 352 |
+
"loss": 0.0037,
|
| 353 |
+
"step": 2450
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"epoch": 49.01960784313726,
|
| 357 |
+
"grad_norm": 0.0097350450232625,
|
| 358 |
+
"learning_rate": 6.666666666666667e-06,
|
| 359 |
+
"loss": 0.0031,
|
| 360 |
+
"step": 2500
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 50.0,
|
| 364 |
+
"grad_norm": 0.00987960398197174,
|
| 365 |
+
"learning_rate": 6.355431061313415e-06,
|
| 366 |
+
"loss": 0.0039,
|
| 367 |
+
"step": 2550
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"epoch": 50.98039215686274,
|
| 371 |
+
"grad_norm": 0.2973344624042511,
|
| 372 |
+
"learning_rate": 6.0441954559601625e-06,
|
| 373 |
+
"loss": 0.0044,
|
| 374 |
+
"step": 2600
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"epoch": 51.96078431372549,
|
| 378 |
+
"grad_norm": 0.009365350939333439,
|
| 379 |
+
"learning_rate": 5.732959850606909e-06,
|
| 380 |
+
"loss": 0.0036,
|
| 381 |
+
"step": 2650
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"epoch": 52.94117647058823,
|
| 385 |
+
"grad_norm": 0.005942502524703741,
|
| 386 |
+
"learning_rate": 5.4217242452536574e-06,
|
| 387 |
+
"loss": 0.0038,
|
| 388 |
+
"step": 2700
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"epoch": 53.92156862745098,
|
| 392 |
+
"grad_norm": 0.008297664113342762,
|
| 393 |
+
"learning_rate": 5.110488639900405e-06,
|
| 394 |
+
"loss": 0.0034,
|
| 395 |
+
"step": 2750
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"epoch": 54.90196078431372,
|
| 399 |
+
"grad_norm": 0.7482114434242249,
|
| 400 |
+
"learning_rate": 4.799253034547152e-06,
|
| 401 |
+
"loss": 0.0037,
|
| 402 |
+
"step": 2800
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"epoch": 55.88235294117647,
|
| 406 |
+
"grad_norm": 0.00625429954379797,
|
| 407 |
+
"learning_rate": 4.4880174291939e-06,
|
| 408 |
+
"loss": 0.0026,
|
| 409 |
+
"step": 2850
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"epoch": 56.86274509803921,
|
| 413 |
+
"grad_norm": 0.008189293555915356,
|
| 414 |
+
"learning_rate": 4.176781823840647e-06,
|
| 415 |
+
"loss": 0.0036,
|
| 416 |
+
"step": 2900
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"epoch": 57.84313725490196,
|
| 420 |
+
"grad_norm": 0.006718316115438938,
|
| 421 |
+
"learning_rate": 3.865546218487396e-06,
|
| 422 |
+
"loss": 0.0035,
|
| 423 |
+
"step": 2950
|
| 424 |
+
},
|
| 425 |
+
{
|
| 426 |
+
"epoch": 58.8235294117647,
|
| 427 |
+
"grad_norm": 0.024499941617250443,
|
| 428 |
+
"learning_rate": 3.5543106131341428e-06,
|
| 429 |
+
"loss": 0.0027,
|
| 430 |
+
"step": 3000
|
| 431 |
+
},
|
| 432 |
+
{
|
| 433 |
+
"epoch": 59.80392156862745,
|
| 434 |
+
"grad_norm": 0.007647670805454254,
|
| 435 |
+
"learning_rate": 3.2430750077808903e-06,
|
| 436 |
+
"loss": 0.0039,
|
| 437 |
+
"step": 3050
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"epoch": 60.78431372549019,
|
| 441 |
+
"grad_norm": 0.0057996599934995174,
|
| 442 |
+
"learning_rate": 2.931839402427638e-06,
|
| 443 |
+
"loss": 0.0036,
|
| 444 |
+
"step": 3100
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"epoch": 61.76470588235294,
|
| 448 |
+
"grad_norm": 0.004913662560284138,
|
| 449 |
+
"learning_rate": 2.6206037970743852e-06,
|
| 450 |
+
"loss": 0.0038,
|
| 451 |
+
"step": 3150
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"epoch": 62.745098039215684,
|
| 455 |
+
"grad_norm": 0.004377233795821667,
|
| 456 |
+
"learning_rate": 2.309368191721133e-06,
|
| 457 |
+
"loss": 0.0021,
|
| 458 |
+
"step": 3200
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"epoch": 63.72549019607843,
|
| 462 |
+
"grad_norm": 0.00702957296743989,
|
| 463 |
+
"learning_rate": 1.9981325863678806e-06,
|
| 464 |
+
"loss": 0.0044,
|
| 465 |
+
"step": 3250
|
| 466 |
+
},
|
| 467 |
+
{
|
| 468 |
+
"epoch": 64.70588235294117,
|
| 469 |
+
"grad_norm": 0.004603248089551926,
|
| 470 |
+
"learning_rate": 1.6868969810146283e-06,
|
| 471 |
+
"loss": 0.0039,
|
| 472 |
+
"step": 3300
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"epoch": 65.68627450980392,
|
| 476 |
+
"grad_norm": 0.009002058766782284,
|
| 477 |
+
"learning_rate": 1.3756613756613758e-06,
|
| 478 |
+
"loss": 0.0018,
|
| 479 |
+
"step": 3350
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"epoch": 66.66666666666667,
|
| 483 |
+
"grad_norm": 0.005587454419583082,
|
| 484 |
+
"learning_rate": 1.0644257703081233e-06,
|
| 485 |
+
"loss": 0.004,
|
| 486 |
+
"step": 3400
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"epoch": 67.6470588235294,
|
| 490 |
+
"grad_norm": 0.007010480388998985,
|
| 491 |
+
"learning_rate": 7.531901649548709e-07,
|
| 492 |
+
"loss": 0.0026,
|
| 493 |
+
"step": 3450
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"epoch": 68.62745098039215,
|
| 497 |
+
"grad_norm": 0.004527617245912552,
|
| 498 |
+
"learning_rate": 4.419545596016184e-07,
|
| 499 |
+
"loss": 0.0043,
|
| 500 |
+
"step": 3500
|
| 501 |
+
}
|
| 502 |
+
],
|
| 503 |
+
"logging_steps": 50,
|
| 504 |
+
"max_steps": 3570,
|
| 505 |
+
"num_input_tokens_seen": 0,
|
| 506 |
+
"num_train_epochs": 70,
|
| 507 |
+
"save_steps": 500,
|
| 508 |
+
"stateful_callbacks": {
|
| 509 |
+
"TrainerControl": {
|
| 510 |
+
"args": {
|
| 511 |
+
"should_epoch_stop": false,
|
| 512 |
+
"should_evaluate": false,
|
| 513 |
+
"should_log": false,
|
| 514 |
+
"should_save": true,
|
| 515 |
+
"should_training_stop": false
|
| 516 |
+
},
|
| 517 |
+
"attributes": {}
|
| 518 |
+
}
|
| 519 |
+
},
|
| 520 |
+
"total_flos": 2.909454409554739e+16,
|
| 521 |
+
"train_batch_size": 32,
|
| 522 |
+
"trial_name": null,
|
| 523 |
+
"trial_params": null
|
| 524 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3500/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b5f9538422ac3ff516fbd5394c01499e3df51a8944c66ec759cd64011b08382
|
| 3 |
+
size 5841
|
telecom_camelbert_cpt_ft/checkpoint-3500/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_camelbert_cpt_ft/checkpoint-3570/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": "1",
|
| 14 |
+
"1": "2",
|
| 15 |
+
"2": "3",
|
| 16 |
+
"3": "4",
|
| 17 |
+
"4": "5",
|
| 18 |
+
"5": "6",
|
| 19 |
+
"6": "7",
|
| 20 |
+
"7": "8",
|
| 21 |
+
"8": "9"
|
| 22 |
+
},
|
| 23 |
+
"initializer_range": 0.02,
|
| 24 |
+
"intermediate_size": 3072,
|
| 25 |
+
"label2id": {
|
| 26 |
+
"0": 1,
|
| 27 |
+
"1": 2,
|
| 28 |
+
"2": 3,
|
| 29 |
+
"3": 4,
|
| 30 |
+
"4": 5,
|
| 31 |
+
"5": 6,
|
| 32 |
+
"6": 7,
|
| 33 |
+
"7": 8,
|
| 34 |
+
"8": 9
|
| 35 |
+
},
|
| 36 |
+
"layer_norm_eps": 1e-12,
|
| 37 |
+
"max_position_embeddings": 512,
|
| 38 |
+
"model_type": "bert",
|
| 39 |
+
"num_attention_heads": 12,
|
| 40 |
+
"num_hidden_layers": 12,
|
| 41 |
+
"pad_token_id": 0,
|
| 42 |
+
"position_embedding_type": "absolute",
|
| 43 |
+
"problem_type": "single_label_classification",
|
| 44 |
+
"transformers_version": "4.57.3",
|
| 45 |
+
"type_vocab_size": 2,
|
| 46 |
+
"use_cache": true,
|
| 47 |
+
"vocab_size": 30000
|
| 48 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3570/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6fd4a9bdb6b95896f2d56f3c3b16e3a65a00020c2a659d7c25b29533d04b2ebe
|
| 3 |
+
size 436376588
|
telecom_camelbert_cpt_ft/checkpoint-3570/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea3f001a73b720b183d52fdf114366fb67b446a305b0ba6f34d49835c2d8efb7
|
| 3 |
+
size 872877451
|
telecom_camelbert_cpt_ft/checkpoint-3570/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f9c5ee45d5ae8a5fde2e385c98fb73e0e572ce4d71b29f086ecc189bb285563
|
| 3 |
+
size 14645
|
telecom_camelbert_cpt_ft/checkpoint-3570/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffad826366e2d47eed8cd887fe4778e4084716a2369d14b2d2b195e111ed1c29
|
| 3 |
+
size 1465
|
telecom_camelbert_cpt_ft/checkpoint-3570/special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": {
|
| 3 |
+
"content": "[CLS]",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"mask_token": {
|
| 10 |
+
"content": "[MASK]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"sep_token": {
|
| 24 |
+
"content": "[SEP]",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "[UNK]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3570/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_camelbert_cpt_ft/checkpoint-3570/tokenizer_config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": false,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"full_tokenizer_file": null,
|
| 50 |
+
"mask_token": "[MASK]",
|
| 51 |
+
"max_length": 512,
|
| 52 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 53 |
+
"never_split": null,
|
| 54 |
+
"pad_to_multiple_of": null,
|
| 55 |
+
"pad_token": "[PAD]",
|
| 56 |
+
"pad_token_type_id": 0,
|
| 57 |
+
"padding_side": "right",
|
| 58 |
+
"sep_token": "[SEP]",
|
| 59 |
+
"stride": 0,
|
| 60 |
+
"strip_accents": null,
|
| 61 |
+
"tokenize_chinese_chars": true,
|
| 62 |
+
"tokenizer_class": "BertTokenizer",
|
| 63 |
+
"truncation_side": "right",
|
| 64 |
+
"truncation_strategy": "longest_first",
|
| 65 |
+
"unk_token": "[UNK]"
|
| 66 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3570/trainer_state.json
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 70.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 3570,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.9803921568627451,
|
| 14 |
+
"grad_norm": 2.6230194568634033,
|
| 15 |
+
"learning_rate": 2.7450980392156867e-06,
|
| 16 |
+
"loss": 2.2283,
|
| 17 |
+
"step": 50
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 1.9607843137254903,
|
| 21 |
+
"grad_norm": 3.465364933013916,
|
| 22 |
+
"learning_rate": 5.546218487394959e-06,
|
| 23 |
+
"loss": 2.1335,
|
| 24 |
+
"step": 100
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 2.9411764705882355,
|
| 28 |
+
"grad_norm": 4.075287818908691,
|
| 29 |
+
"learning_rate": 8.34733893557423e-06,
|
| 30 |
+
"loss": 1.8493,
|
| 31 |
+
"step": 150
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 3.9215686274509802,
|
| 35 |
+
"grad_norm": 4.183288097381592,
|
| 36 |
+
"learning_rate": 1.1148459383753503e-05,
|
| 37 |
+
"loss": 1.3306,
|
| 38 |
+
"step": 200
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 4.901960784313726,
|
| 42 |
+
"grad_norm": 6.412795066833496,
|
| 43 |
+
"learning_rate": 1.3949579831932774e-05,
|
| 44 |
+
"loss": 0.8739,
|
| 45 |
+
"step": 250
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 5.882352941176471,
|
| 49 |
+
"grad_norm": 9.716330528259277,
|
| 50 |
+
"learning_rate": 1.6750700280112046e-05,
|
| 51 |
+
"loss": 0.6337,
|
| 52 |
+
"step": 300
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 6.862745098039216,
|
| 56 |
+
"grad_norm": 6.941835403442383,
|
| 57 |
+
"learning_rate": 1.9551820728291318e-05,
|
| 58 |
+
"loss": 0.4314,
|
| 59 |
+
"step": 350
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 7.8431372549019605,
|
| 63 |
+
"grad_norm": 3.7662336826324463,
|
| 64 |
+
"learning_rate": 1.973856209150327e-05,
|
| 65 |
+
"loss": 0.2944,
|
| 66 |
+
"step": 400
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 8.823529411764707,
|
| 70 |
+
"grad_norm": 4.981897354125977,
|
| 71 |
+
"learning_rate": 1.9427326486150017e-05,
|
| 72 |
+
"loss": 0.1731,
|
| 73 |
+
"step": 450
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 9.803921568627452,
|
| 77 |
+
"grad_norm": 9.845929145812988,
|
| 78 |
+
"learning_rate": 1.9116090880796766e-05,
|
| 79 |
+
"loss": 0.1273,
|
| 80 |
+
"step": 500
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 10.784313725490197,
|
| 84 |
+
"grad_norm": 8.743799209594727,
|
| 85 |
+
"learning_rate": 1.880485527544351e-05,
|
| 86 |
+
"loss": 0.077,
|
| 87 |
+
"step": 550
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 11.764705882352942,
|
| 91 |
+
"grad_norm": 0.8690173625946045,
|
| 92 |
+
"learning_rate": 1.849361967009026e-05,
|
| 93 |
+
"loss": 0.0603,
|
| 94 |
+
"step": 600
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 12.745098039215687,
|
| 98 |
+
"grad_norm": 0.22123952209949493,
|
| 99 |
+
"learning_rate": 1.8182384064737007e-05,
|
| 100 |
+
"loss": 0.0304,
|
| 101 |
+
"step": 650
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 13.72549019607843,
|
| 105 |
+
"grad_norm": 0.263492226600647,
|
| 106 |
+
"learning_rate": 1.7871148459383755e-05,
|
| 107 |
+
"loss": 0.0263,
|
| 108 |
+
"step": 700
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 14.705882352941176,
|
| 112 |
+
"grad_norm": 0.9709652066230774,
|
| 113 |
+
"learning_rate": 1.7559912854030504e-05,
|
| 114 |
+
"loss": 0.0307,
|
| 115 |
+
"step": 750
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 15.686274509803921,
|
| 119 |
+
"grad_norm": 0.03851320594549179,
|
| 120 |
+
"learning_rate": 1.724867724867725e-05,
|
| 121 |
+
"loss": 0.0237,
|
| 122 |
+
"step": 800
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 16.666666666666668,
|
| 126 |
+
"grad_norm": 3.8813209533691406,
|
| 127 |
+
"learning_rate": 1.6937441643323997e-05,
|
| 128 |
+
"loss": 0.0244,
|
| 129 |
+
"step": 850
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 17.647058823529413,
|
| 133 |
+
"grad_norm": 0.14205335080623627,
|
| 134 |
+
"learning_rate": 1.6626206037970745e-05,
|
| 135 |
+
"loss": 0.0103,
|
| 136 |
+
"step": 900
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 18.627450980392158,
|
| 140 |
+
"grad_norm": 0.03730874881148338,
|
| 141 |
+
"learning_rate": 1.631497043261749e-05,
|
| 142 |
+
"loss": 0.0123,
|
| 143 |
+
"step": 950
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 19.607843137254903,
|
| 147 |
+
"grad_norm": 0.047113899141550064,
|
| 148 |
+
"learning_rate": 1.6003734827264242e-05,
|
| 149 |
+
"loss": 0.0203,
|
| 150 |
+
"step": 1000
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 20.58823529411765,
|
| 154 |
+
"grad_norm": 0.019405728206038475,
|
| 155 |
+
"learning_rate": 1.5692499221910987e-05,
|
| 156 |
+
"loss": 0.0105,
|
| 157 |
+
"step": 1050
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 21.568627450980394,
|
| 161 |
+
"grad_norm": 0.022607196122407913,
|
| 162 |
+
"learning_rate": 1.5381263616557735e-05,
|
| 163 |
+
"loss": 0.0133,
|
| 164 |
+
"step": 1100
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 22.54901960784314,
|
| 168 |
+
"grad_norm": 0.14680148661136627,
|
| 169 |
+
"learning_rate": 1.5070028011204482e-05,
|
| 170 |
+
"loss": 0.012,
|
| 171 |
+
"step": 1150
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 23.529411764705884,
|
| 175 |
+
"grad_norm": 0.07145881652832031,
|
| 176 |
+
"learning_rate": 1.475879240585123e-05,
|
| 177 |
+
"loss": 0.011,
|
| 178 |
+
"step": 1200
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 24.509803921568626,
|
| 182 |
+
"grad_norm": 0.017874594777822495,
|
| 183 |
+
"learning_rate": 1.4447556800497977e-05,
|
| 184 |
+
"loss": 0.0121,
|
| 185 |
+
"step": 1250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 25.49019607843137,
|
| 189 |
+
"grad_norm": 0.014184300787746906,
|
| 190 |
+
"learning_rate": 1.4136321195144727e-05,
|
| 191 |
+
"loss": 0.0082,
|
| 192 |
+
"step": 1300
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 26.470588235294116,
|
| 196 |
+
"grad_norm": 0.4630926847457886,
|
| 197 |
+
"learning_rate": 1.3825085589791474e-05,
|
| 198 |
+
"loss": 0.0089,
|
| 199 |
+
"step": 1350
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 27.45098039215686,
|
| 203 |
+
"grad_norm": 1.217696189880371,
|
| 204 |
+
"learning_rate": 1.351384998443822e-05,
|
| 205 |
+
"loss": 0.0037,
|
| 206 |
+
"step": 1400
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"epoch": 28.431372549019606,
|
| 210 |
+
"grad_norm": 0.036678045988082886,
|
| 211 |
+
"learning_rate": 1.3202614379084969e-05,
|
| 212 |
+
"loss": 0.0072,
|
| 213 |
+
"step": 1450
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"epoch": 29.41176470588235,
|
| 217 |
+
"grad_norm": 0.014452925883233547,
|
| 218 |
+
"learning_rate": 1.2891378773731715e-05,
|
| 219 |
+
"loss": 0.0047,
|
| 220 |
+
"step": 1500
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"epoch": 30.392156862745097,
|
| 224 |
+
"grad_norm": 0.016204573214054108,
|
| 225 |
+
"learning_rate": 1.2580143168378462e-05,
|
| 226 |
+
"loss": 0.0053,
|
| 227 |
+
"step": 1550
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"epoch": 31.372549019607842,
|
| 231 |
+
"grad_norm": 0.01764889620244503,
|
| 232 |
+
"learning_rate": 1.2268907563025212e-05,
|
| 233 |
+
"loss": 0.0049,
|
| 234 |
+
"step": 1600
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"epoch": 32.35294117647059,
|
| 238 |
+
"grad_norm": 0.01166362501680851,
|
| 239 |
+
"learning_rate": 1.1957671957671959e-05,
|
| 240 |
+
"loss": 0.0054,
|
| 241 |
+
"step": 1650
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"epoch": 33.333333333333336,
|
| 245 |
+
"grad_norm": 0.0751669779419899,
|
| 246 |
+
"learning_rate": 1.1646436352318707e-05,
|
| 247 |
+
"loss": 0.0048,
|
| 248 |
+
"step": 1700
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"epoch": 34.31372549019608,
|
| 252 |
+
"grad_norm": 0.01111924834549427,
|
| 253 |
+
"learning_rate": 1.1335200746965454e-05,
|
| 254 |
+
"loss": 0.0064,
|
| 255 |
+
"step": 1750
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"epoch": 35.294117647058826,
|
| 259 |
+
"grad_norm": 0.040332596749067307,
|
| 260 |
+
"learning_rate": 1.10239651416122e-05,
|
| 261 |
+
"loss": 0.0061,
|
| 262 |
+
"step": 1800
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"epoch": 36.27450980392157,
|
| 266 |
+
"grad_norm": 0.015221468172967434,
|
| 267 |
+
"learning_rate": 1.0712729536258948e-05,
|
| 268 |
+
"loss": 0.0051,
|
| 269 |
+
"step": 1850
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"epoch": 37.254901960784316,
|
| 273 |
+
"grad_norm": 0.010177787393331528,
|
| 274 |
+
"learning_rate": 1.0401493930905697e-05,
|
| 275 |
+
"loss": 0.0056,
|
| 276 |
+
"step": 1900
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"epoch": 38.23529411764706,
|
| 280 |
+
"grad_norm": 0.006494768429547548,
|
| 281 |
+
"learning_rate": 1.0090258325552445e-05,
|
| 282 |
+
"loss": 0.0038,
|
| 283 |
+
"step": 1950
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"epoch": 39.21568627450981,
|
| 287 |
+
"grad_norm": 0.021074198186397552,
|
| 288 |
+
"learning_rate": 9.779022720199192e-06,
|
| 289 |
+
"loss": 0.0037,
|
| 290 |
+
"step": 2000
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"epoch": 40.19607843137255,
|
| 294 |
+
"grad_norm": 0.7565334439277649,
|
| 295 |
+
"learning_rate": 9.467787114845938e-06,
|
| 296 |
+
"loss": 0.0075,
|
| 297 |
+
"step": 2050
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"epoch": 41.1764705882353,
|
| 301 |
+
"grad_norm": 0.007625663187354803,
|
| 302 |
+
"learning_rate": 9.156551509492687e-06,
|
| 303 |
+
"loss": 0.0025,
|
| 304 |
+
"step": 2100
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"epoch": 42.15686274509804,
|
| 308 |
+
"grad_norm": 0.006594022735953331,
|
| 309 |
+
"learning_rate": 8.845315904139435e-06,
|
| 310 |
+
"loss": 0.0041,
|
| 311 |
+
"step": 2150
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"epoch": 43.13725490196079,
|
| 315 |
+
"grad_norm": 0.5956095457077026,
|
| 316 |
+
"learning_rate": 8.534080298786182e-06,
|
| 317 |
+
"loss": 0.0052,
|
| 318 |
+
"step": 2200
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"epoch": 44.11764705882353,
|
| 322 |
+
"grad_norm": 0.005290526431053877,
|
| 323 |
+
"learning_rate": 8.22284469343293e-06,
|
| 324 |
+
"loss": 0.0037,
|
| 325 |
+
"step": 2250
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"epoch": 45.09803921568628,
|
| 329 |
+
"grad_norm": 0.12142275273799896,
|
| 330 |
+
"learning_rate": 7.911609088079677e-06,
|
| 331 |
+
"loss": 0.0049,
|
| 332 |
+
"step": 2300
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"epoch": 46.07843137254902,
|
| 336 |
+
"grad_norm": 0.512083888053894,
|
| 337 |
+
"learning_rate": 7.600373482726424e-06,
|
| 338 |
+
"loss": 0.0038,
|
| 339 |
+
"step": 2350
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"epoch": 47.05882352941177,
|
| 343 |
+
"grad_norm": 0.009066535159945488,
|
| 344 |
+
"learning_rate": 7.2891378773731725e-06,
|
| 345 |
+
"loss": 0.0038,
|
| 346 |
+
"step": 2400
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"epoch": 48.03921568627451,
|
| 350 |
+
"grad_norm": 0.0053984676487743855,
|
| 351 |
+
"learning_rate": 6.97790227201992e-06,
|
| 352 |
+
"loss": 0.0037,
|
| 353 |
+
"step": 2450
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"epoch": 49.01960784313726,
|
| 357 |
+
"grad_norm": 0.0097350450232625,
|
| 358 |
+
"learning_rate": 6.666666666666667e-06,
|
| 359 |
+
"loss": 0.0031,
|
| 360 |
+
"step": 2500
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 50.0,
|
| 364 |
+
"grad_norm": 0.00987960398197174,
|
| 365 |
+
"learning_rate": 6.355431061313415e-06,
|
| 366 |
+
"loss": 0.0039,
|
| 367 |
+
"step": 2550
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"epoch": 50.98039215686274,
|
| 371 |
+
"grad_norm": 0.2973344624042511,
|
| 372 |
+
"learning_rate": 6.0441954559601625e-06,
|
| 373 |
+
"loss": 0.0044,
|
| 374 |
+
"step": 2600
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"epoch": 51.96078431372549,
|
| 378 |
+
"grad_norm": 0.009365350939333439,
|
| 379 |
+
"learning_rate": 5.732959850606909e-06,
|
| 380 |
+
"loss": 0.0036,
|
| 381 |
+
"step": 2650
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"epoch": 52.94117647058823,
|
| 385 |
+
"grad_norm": 0.005942502524703741,
|
| 386 |
+
"learning_rate": 5.4217242452536574e-06,
|
| 387 |
+
"loss": 0.0038,
|
| 388 |
+
"step": 2700
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"epoch": 53.92156862745098,
|
| 392 |
+
"grad_norm": 0.008297664113342762,
|
| 393 |
+
"learning_rate": 5.110488639900405e-06,
|
| 394 |
+
"loss": 0.0034,
|
| 395 |
+
"step": 2750
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"epoch": 54.90196078431372,
|
| 399 |
+
"grad_norm": 0.7482114434242249,
|
| 400 |
+
"learning_rate": 4.799253034547152e-06,
|
| 401 |
+
"loss": 0.0037,
|
| 402 |
+
"step": 2800
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"epoch": 55.88235294117647,
|
| 406 |
+
"grad_norm": 0.00625429954379797,
|
| 407 |
+
"learning_rate": 4.4880174291939e-06,
|
| 408 |
+
"loss": 0.0026,
|
| 409 |
+
"step": 2850
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"epoch": 56.86274509803921,
|
| 413 |
+
"grad_norm": 0.008189293555915356,
|
| 414 |
+
"learning_rate": 4.176781823840647e-06,
|
| 415 |
+
"loss": 0.0036,
|
| 416 |
+
"step": 2900
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"epoch": 57.84313725490196,
|
| 420 |
+
"grad_norm": 0.006718316115438938,
|
| 421 |
+
"learning_rate": 3.865546218487396e-06,
|
| 422 |
+
"loss": 0.0035,
|
| 423 |
+
"step": 2950
|
| 424 |
+
},
|
| 425 |
+
{
|
| 426 |
+
"epoch": 58.8235294117647,
|
| 427 |
+
"grad_norm": 0.024499941617250443,
|
| 428 |
+
"learning_rate": 3.5543106131341428e-06,
|
| 429 |
+
"loss": 0.0027,
|
| 430 |
+
"step": 3000
|
| 431 |
+
},
|
| 432 |
+
{
|
| 433 |
+
"epoch": 59.80392156862745,
|
| 434 |
+
"grad_norm": 0.007647670805454254,
|
| 435 |
+
"learning_rate": 3.2430750077808903e-06,
|
| 436 |
+
"loss": 0.0039,
|
| 437 |
+
"step": 3050
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"epoch": 60.78431372549019,
|
| 441 |
+
"grad_norm": 0.0057996599934995174,
|
| 442 |
+
"learning_rate": 2.931839402427638e-06,
|
| 443 |
+
"loss": 0.0036,
|
| 444 |
+
"step": 3100
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"epoch": 61.76470588235294,
|
| 448 |
+
"grad_norm": 0.004913662560284138,
|
| 449 |
+
"learning_rate": 2.6206037970743852e-06,
|
| 450 |
+
"loss": 0.0038,
|
| 451 |
+
"step": 3150
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"epoch": 62.745098039215684,
|
| 455 |
+
"grad_norm": 0.004377233795821667,
|
| 456 |
+
"learning_rate": 2.309368191721133e-06,
|
| 457 |
+
"loss": 0.0021,
|
| 458 |
+
"step": 3200
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"epoch": 63.72549019607843,
|
| 462 |
+
"grad_norm": 0.00702957296743989,
|
| 463 |
+
"learning_rate": 1.9981325863678806e-06,
|
| 464 |
+
"loss": 0.0044,
|
| 465 |
+
"step": 3250
|
| 466 |
+
},
|
| 467 |
+
{
|
| 468 |
+
"epoch": 64.70588235294117,
|
| 469 |
+
"grad_norm": 0.004603248089551926,
|
| 470 |
+
"learning_rate": 1.6868969810146283e-06,
|
| 471 |
+
"loss": 0.0039,
|
| 472 |
+
"step": 3300
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"epoch": 65.68627450980392,
|
| 476 |
+
"grad_norm": 0.009002058766782284,
|
| 477 |
+
"learning_rate": 1.3756613756613758e-06,
|
| 478 |
+
"loss": 0.0018,
|
| 479 |
+
"step": 3350
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"epoch": 66.66666666666667,
|
| 483 |
+
"grad_norm": 0.005587454419583082,
|
| 484 |
+
"learning_rate": 1.0644257703081233e-06,
|
| 485 |
+
"loss": 0.004,
|
| 486 |
+
"step": 3400
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"epoch": 67.6470588235294,
|
| 490 |
+
"grad_norm": 0.007010480388998985,
|
| 491 |
+
"learning_rate": 7.531901649548709e-07,
|
| 492 |
+
"loss": 0.0026,
|
| 493 |
+
"step": 3450
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"epoch": 68.62745098039215,
|
| 497 |
+
"grad_norm": 0.004527617245912552,
|
| 498 |
+
"learning_rate": 4.419545596016184e-07,
|
| 499 |
+
"loss": 0.0043,
|
| 500 |
+
"step": 3500
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"epoch": 69.6078431372549,
|
| 504 |
+
"grad_norm": 0.004667927045375109,
|
| 505 |
+
"learning_rate": 1.3071895424836603e-07,
|
| 506 |
+
"loss": 0.0021,
|
| 507 |
+
"step": 3550
|
| 508 |
+
}
|
| 509 |
+
],
|
| 510 |
+
"logging_steps": 50,
|
| 511 |
+
"max_steps": 3570,
|
| 512 |
+
"num_input_tokens_seen": 0,
|
| 513 |
+
"num_train_epochs": 70,
|
| 514 |
+
"save_steps": 500,
|
| 515 |
+
"stateful_callbacks": {
|
| 516 |
+
"TrainerControl": {
|
| 517 |
+
"args": {
|
| 518 |
+
"should_epoch_stop": false,
|
| 519 |
+
"should_evaluate": false,
|
| 520 |
+
"should_log": false,
|
| 521 |
+
"should_save": true,
|
| 522 |
+
"should_training_stop": true
|
| 523 |
+
},
|
| 524 |
+
"attributes": {}
|
| 525 |
+
}
|
| 526 |
+
},
|
| 527 |
+
"total_flos": 2.967289854262272e+16,
|
| 528 |
+
"train_batch_size": 32,
|
| 529 |
+
"trial_name": null,
|
| 530 |
+
"trial_params": null
|
| 531 |
+
}
|
telecom_camelbert_cpt_ft/checkpoint-3570/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b5f9538422ac3ff516fbd5394c01499e3df51a8944c66ec759cd64011b08382
|
| 3 |
+
size 5841
|
telecom_camelbert_cpt_ft/checkpoint-3570/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_camelbert_cpt_ft/config.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": "1",
|
| 14 |
+
"1": "2",
|
| 15 |
+
"2": "3",
|
| 16 |
+
"3": "4",
|
| 17 |
+
"4": "5",
|
| 18 |
+
"5": "6",
|
| 19 |
+
"6": "7",
|
| 20 |
+
"7": "8",
|
| 21 |
+
"8": "9"
|
| 22 |
+
},
|
| 23 |
+
"initializer_range": 0.02,
|
| 24 |
+
"intermediate_size": 3072,
|
| 25 |
+
"label2id": {
|
| 26 |
+
"0": 1,
|
| 27 |
+
"1": 2,
|
| 28 |
+
"2": 3,
|
| 29 |
+
"3": 4,
|
| 30 |
+
"4": 5,
|
| 31 |
+
"5": 6,
|
| 32 |
+
"6": 7,
|
| 33 |
+
"7": 8,
|
| 34 |
+
"8": 9
|
| 35 |
+
},
|
| 36 |
+
"layer_norm_eps": 1e-12,
|
| 37 |
+
"max_position_embeddings": 512,
|
| 38 |
+
"model_type": "bert",
|
| 39 |
+
"num_attention_heads": 12,
|
| 40 |
+
"num_hidden_layers": 12,
|
| 41 |
+
"pad_token_id": 0,
|
| 42 |
+
"position_embedding_type": "absolute",
|
| 43 |
+
"problem_type": "single_label_classification",
|
| 44 |
+
"transformers_version": "4.57.3",
|
| 45 |
+
"type_vocab_size": 2,
|
| 46 |
+
"use_cache": true,
|
| 47 |
+
"vocab_size": 30000,
|
| 48 |
+
"num_labels": 9
|
| 49 |
+
}
|
telecom_camelbert_cpt_ft/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6fd4a9bdb6b95896f2d56f3c3b16e3a65a00020c2a659d7c25b29533d04b2ebe
|
| 3 |
+
size 436376588
|
telecom_camelbert_cpt_ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": {
|
| 3 |
+
"content": "[CLS]",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"mask_token": {
|
| 10 |
+
"content": "[MASK]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"sep_token": {
|
| 24 |
+
"content": "[SEP]",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "[UNK]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
telecom_camelbert_cpt_ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_camelbert_cpt_ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": false,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"full_tokenizer_file": null,
|
| 50 |
+
"mask_token": "[MASK]",
|
| 51 |
+
"max_length": 512,
|
| 52 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 53 |
+
"never_split": null,
|
| 54 |
+
"pad_to_multiple_of": null,
|
| 55 |
+
"pad_token": "[PAD]",
|
| 56 |
+
"pad_token_type_id": 0,
|
| 57 |
+
"padding_side": "right",
|
| 58 |
+
"sep_token": "[SEP]",
|
| 59 |
+
"stride": 0,
|
| 60 |
+
"strip_accents": null,
|
| 61 |
+
"tokenize_chinese_chars": true,
|
| 62 |
+
"tokenizer_class": "BertTokenizer",
|
| 63 |
+
"truncation_side": "right",
|
| 64 |
+
"truncation_strategy": "longest_first",
|
| 65 |
+
"unk_token": "[UNK]"
|
| 66 |
+
}
|
telecom_camelbert_cpt_ft/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b5f9538422ac3ff516fbd5394c01499e3df51a8944c66ec759cd64011b08382
|
| 3 |
+
size 5841
|
telecom_camelbert_cpt_ft/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_dziribert_final/checkpoint-8000/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": 1,
|
| 14 |
+
"1": 2,
|
| 15 |
+
"2": 3,
|
| 16 |
+
"3": 4,
|
| 17 |
+
"4": 5,
|
| 18 |
+
"5": 6,
|
| 19 |
+
"6": 7,
|
| 20 |
+
"7": 8,
|
| 21 |
+
"8": 9
|
| 22 |
+
},
|
| 23 |
+
"initializer_range": 0.02,
|
| 24 |
+
"intermediate_size": 3072,
|
| 25 |
+
"label2id": {
|
| 26 |
+
"1": 0,
|
| 27 |
+
"2": 1,
|
| 28 |
+
"3": 2,
|
| 29 |
+
"4": 3,
|
| 30 |
+
"5": 4,
|
| 31 |
+
"6": 5,
|
| 32 |
+
"7": 6,
|
| 33 |
+
"8": 7,
|
| 34 |
+
"9": 8
|
| 35 |
+
},
|
| 36 |
+
"layer_norm_eps": 1e-12,
|
| 37 |
+
"max_position_embeddings": 512,
|
| 38 |
+
"model_type": "bert",
|
| 39 |
+
"num_attention_heads": 12,
|
| 40 |
+
"num_hidden_layers": 12,
|
| 41 |
+
"pad_token_id": 0,
|
| 42 |
+
"position_embedding_type": "absolute",
|
| 43 |
+
"problem_type": "single_label_classification",
|
| 44 |
+
"transformers_version": "4.57.3",
|
| 45 |
+
"type_vocab_size": 2,
|
| 46 |
+
"use_cache": true,
|
| 47 |
+
"vocab_size": 50000
|
| 48 |
+
}
|
telecom_dziribert_final/checkpoint-8000/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f1b3eaaad3d15a8ef1c5277c33a0400c15e4ff0af900705aa9144de6f07a699
|
| 3 |
+
size 497816604
|
telecom_dziribert_final/checkpoint-8000/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12fe1e0aa55d8f2aab0ef48874b83f62f443e5e661e80f6a8cbb1102d6235a15
|
| 3 |
+
size 995757451
|
telecom_dziribert_final/checkpoint-8000/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fd0507e2d7b94351111d5badfee1b3ed774bb8916d5cdd863b137ee35c00408
|
| 3 |
+
size 14645
|
telecom_dziribert_final/checkpoint-8000/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18caf96b2e25a0f89da808b5c5aed5fae7eea1994ec0fe0a96700f37665bfa6e
|
| 3 |
+
size 1465
|
telecom_dziribert_final/checkpoint-8000/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
telecom_dziribert_final/checkpoint-8000/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
telecom_dziribert_final/checkpoint-8000/tokenizer_config.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": true,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "[MASK]",
|
| 50 |
+
"max_len": 512,
|
| 51 |
+
"model_max_length": 512,
|
| 52 |
+
"never_split": null,
|
| 53 |
+
"pad_token": "[PAD]",
|
| 54 |
+
"sep_token": "[SEP]",
|
| 55 |
+
"strip_accents": null,
|
| 56 |
+
"tokenize_chinese_chars": true,
|
| 57 |
+
"tokenizer_class": "BertTokenizer",
|
| 58 |
+
"unk_token": "[UNK]"
|
| 59 |
+
}
|