Gemma_models / Code /inference_camelbert.py
houssamboukhalfa's picture
Upload folder using huggingface_hub
f8dd4fe verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Inference script for the CPT+finetuned CAMeLBERT telecom classification model.
Loads the model from ./telecom_camelbert_cpt_ft and runs predictions on the test CSV set in TEST_FILE.
"""
import os
import numpy as np
import pandas as pd
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoConfig,
)
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_recall_fscore_support,
classification_report,
confusion_matrix,
)
# -------------------------------------------------------------------
# 1. Paths & config
# -------------------------------------------------------------------
# Input CSV, read from a local Hugging Face hub cache snapshot of the dataset.
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
# Directory holding the fine-tuned checkpoint (tokenizer, config.json, weights).
MODEL_DIR = "./telecom_camelbert_cpt_ft"
# Destination CSV for the per-row predictions.
OUTPUT_FILE = "./test_predictions_camelbert_cpt_ft.csv"
# Tokenizer truncation limit (in tokens) and number of rows per forward pass.
MAX_LENGTH = 256
BATCH_SIZE = 64

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# -------------------------------------------------------------------
# 2. Load test data
# -------------------------------------------------------------------
print(f"Loading test data from: {TEST_FILE}")
test_df = pd.read_csv(TEST_FILE)
row_count = test_df.shape[0]
print(f"Test samples: {row_count}")
print(f"Columns: {list(test_df.columns)}")

# Ground-truth labels are optional: the presence of a "Class" column decides
# whether evaluation metrics are computed later in the script.
has_labels = "Class" in test_df.columns
print(
    "Test data contains labels - will compute metrics"
    if has_labels
    else "Test data has no labels - will only generate predictions"
)
# -------------------------------------------------------------------
# 3. Load model and tokenizer
# -------------------------------------------------------------------
import json

print(f"\nLoading model from: {MODEL_DIR}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Read the id2label mapping straight from the checkpoint's config.json, if any.
config_path = os.path.join(MODEL_DIR, "config.json")
config_exists = os.path.exists(config_path)
config_id2label = None
if config_exists:
    with open(config_path, "r", encoding="utf-8") as f:
        config_id2label = json.load(f).get("id2label")

if config_id2label:
    # JSON object keys are always strings, so the indices are converted to int.
    # Class names are converted to int when they are numeric and otherwise kept
    # as-is, so a checkpoint with non-numeric names (e.g. "LABEL_0") does not
    # crash the script.
    id2label = {}
    label2id = {}
    for raw_idx, raw_label in config_id2label.items():
        idx = int(raw_idx)
        try:
            label = int(raw_label)
        except (TypeError, ValueError):
            label = raw_label
        id2label[idx] = label
        # Register both the native and the string form of the label so the
        # lookup works whichever way the "Class" column was parsed.
        label2id[label] = idx
        label2id[str(label)] = idx
    num_labels = len(id2label)
elif has_labels:
    # Fallback: derive the mapping from the sorted unique classes of the test set.
    unique_classes = sorted(test_df["Class"].unique())
    label2id = {label: idx for idx, label in enumerate(unique_classes)}
    id2label = {idx: label for label, idx in label2id.items()}
    num_labels = len(unique_classes)
elif config_exists:
    raise ValueError("Cannot determine number of labels without config or test labels")
else:
    raise ValueError("Cannot find config.json and test data has no labels")

print(f"Number of classes: {num_labels}")
print(f"Label mapping: {id2label}")

# Load the classifier, move it to the selected device and switch to eval mode.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model = model.to(device)
model.eval()
print("Model loaded successfully!")
# -------------------------------------------------------------------
# 4. Run inference
# -------------------------------------------------------------------
print("\nRunning inference...")
all_predictions = []
all_probabilities = []

# pandas reads empty CSV cells as NaN (a float), which the tokenizer does not
# accept as text input; normalise every comment to a string once, up front.
comment_texts = test_df["Commentaire client"].fillna("").astype(str).tolist()
total_samples = len(comment_texts)

# Process in batches of BATCH_SIZE rows.
for batch_number, start in enumerate(range(0, total_samples, BATCH_SIZE), start=1):
    batch_texts = comment_texts[start:start + BATCH_SIZE]

    # Tokenize: pad to the longest text in the batch, truncate at MAX_LENGTH.
    inputs = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)

    # Accumulate one predicted index and one probability row per sample.
    all_predictions.extend(predictions.cpu().numpy())
    all_probabilities.extend(probs.cpu().numpy())

    # Progress report every 10 batches.
    if batch_number % 10 == 0:
        print(f"Processed {start + len(batch_texts)}/{total_samples} samples...")

print(f"Inference complete! Processed {len(all_predictions)} samples")
# -------------------------------------------------------------------
# 5. Save predictions
# -------------------------------------------------------------------
# Translate every predicted index into its class label via id2label.
predicted_classes = [id2label[label_id] for label_id in all_predictions]

# Attach the predicted label and its raw index to the dataframe.
test_df["Predicted_Class"] = predicted_classes
test_df["Predicted_Label_ID"] = all_predictions

# One "Prob_<class>" column per class, read from each sample's probability row.
for class_idx, class_name in id2label.items():
    class_column = [sample_probs[class_idx] for sample_probs in all_probabilities]
    test_df[f"Prob_{class_name}"] = class_column

# Confidence is the largest probability in each sample's row.
test_df["Confidence"] = list(map(max, all_probabilities))

# Write the enriched dataframe to disk without the index column.
test_df.to_csv(OUTPUT_FILE, index=False)
print(f"\nPredictions saved to: {OUTPUT_FILE}")
# -------------------------------------------------------------------
# 6. Compute metrics (if labels available)
# -------------------------------------------------------------------
if has_labels:
    print("\n" + "=" * 80)
    print("EVALUATION METRICS")
    print("=" * 80)

    # Convert true labels to indices. Labels missing from label2id come back
    # as NaN from Series.map, so stop with an explicit message rather than
    # handing NaN to scikit-learn.
    mapped_labels = test_df["Class"].map(label2id)
    unmapped = mapped_labels.isna()
    if unmapped.any():
        unknown = sorted(test_df.loc[unmapped, "Class"].astype(str).unique())
        raise ValueError(
            f"Test labels not present in the model's label mapping: {unknown}"
        )
    true_labels = mapped_labels.astype(int).values
    pred_labels = np.array(all_predictions)

    # Fixed label order shared by every per-class output below. Without it
    # scikit-learn only reports the classes that occur in the data, and
    # position i no longer corresponds to label index i.
    label_ids = list(range(num_labels))
    class_names = [str(id2label[idx]) for idx in label_ids]

    # Overall metrics
    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"\nAccuracy: {accuracy:.4f}")

    # Weighted metrics (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='weighted', zero_division=0
    )
    print("\nWeighted Metrics:")
    print(f" Precision: {precision_w:.4f}")
    print(f" Recall: {recall_w:.4f}")
    print(f" F1 Score: {f1_w:.4f}")

    # Macro metrics (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro', zero_division=0
    )
    print("\nMacro Metrics:")
    print(f" Precision: {precision_m:.4f}")
    print(f" Recall: {recall_m:.4f}")
    print(f" F1 Score: {f1_m:.4f}")

    # Per-class metrics: a single call yields precision, recall, F1 and support.
    print("\nPer-Class Metrics:")
    per_class_precision, per_class_recall, per_class_f1, support = (
        precision_recall_fscore_support(
            true_labels,
            pred_labels,
            labels=label_ids,
            average=None,
            zero_division=0,
        )
    )
    for idx in label_ids:
        print(f"\n Class {id2label[idx]}:")
        print(f" Precision: {per_class_precision[idx]:.4f}")
        print(f" Recall: {per_class_recall[idx]:.4f}")
        print(f" F1 Score: {per_class_f1[idx]:.4f}")
        print(f" Support: {int(support[idx])}")

    # Classification report
    print("\n" + "=" * 80)
    print("DETAILED CLASSIFICATION REPORT")
    print("=" * 80)
    print(
        classification_report(
            true_labels,
            pred_labels,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
    )

    # Confusion matrix (rows: true class, columns: predicted class)
    print("\n" + "=" * 80)
    print("CONFUSION MATRIX")
    print("=" * 80)
    cm = confusion_matrix(true_labels, pred_labels, labels=label_ids)

    # Print confusion matrix with labels
    print("\nTrue \\ Predicted" + "".join(f"\t{id2label[idx]}" for idx in label_ids))
    for idx in label_ids:
        print(f"{id2label[idx]:<15}" + "".join(f"\t{count}" for count in cm[idx]))

    # Save confusion matrix to CSV. The file is named after this script's
    # model so it does not overwrite the MARBERTv2 run's confusion matrix.
    cm_path = "./confusion_matrix_camelbert_cpt_ft.csv"
    cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
    cm_df.to_csv(cm_path)
    print(f"\nConfusion matrix saved to: {cm_path}")
# -------------------------------------------------------------------
# 7. Show sample predictions
# -------------------------------------------------------------------
print("\n" + "=" * 80)
# The banner names the model actually loaded from MODEL_DIR (CAMeLBERT).
print("SAMPLE PREDICTIONS (CPT+Finetuned CAMeLBERT)")
print("=" * 80)

# Show the first 5 predictions (fewer if the test set is smaller).
num_samples = min(5, len(test_df))
for i in range(num_samples):
    print(f"\nSample {i+1}:")
    print(f"Text: {test_df['Commentaire client'].iloc[i]}")
    if has_labels:
        print(f"True Class: {test_df['Class'].iloc[i]}")
    print(f"Predicted Class: {predicted_classes[i]}")
    print(f"Confidence: {test_df['Confidence'].iloc[i]:.4f}")
    print("Probabilities:")
    # One line per class, in id2label order.
    for idx, class_name in id2label.items():
        print(f" Class {class_name}: {all_probabilities[i][idx]:.4f}")

print("\n" + "=" * 80)
print("Inference completed successfully!")
print(f"Model used: CPT+Finetuned CAMeLBERT from {MODEL_DIR}")
print("=" * 80)