#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Inference script for trained MARBERTv2 telecom classification model.
Loads the trained model and runs predictions on test.csv
"""
import json
import os

import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)
# -------------------------------------------------------------------
# 1. Paths & config
# -------------------------------------------------------------------
TEST_FILE = "/home/houssam-nojoom/.cache/huggingface/hub/datasets--houssamboukhalfa--telecom-ch1/snapshots/be06acac69aa411636dbe0e3bef5f0072e670765/test_file.csv"
MODEL_DIR = "./telecom_marbertv2_final"
OUTPUT_FILE = "./test_predictions.csv"
MAX_LENGTH = 256
BATCH_SIZE = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# -------------------------------------------------------------------
# 2. Load test data
# -------------------------------------------------------------------
print(f"Loading test data from: {TEST_FILE}")
test_df = pd.read_csv(TEST_FILE)
print(f"Test samples: {len(test_df)}")
print(f"Columns: {test_df.columns.tolist()}")
# Check if test data has labels
has_labels = "Class" in test_df.columns
if has_labels:
    print("Test data contains labels - will compute metrics")
else:
    print("Test data has no labels - will only generate predictions")
# -------------------------------------------------------------------
# 3. Load model and tokenizer
# -------------------------------------------------------------------
print(f"\nLoading model from: {MODEL_DIR}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# Load config to get label mappings
config_path = os.path.join(MODEL_DIR, "config.json")
if os.path.exists(config_path):
    with open(config_path, 'r') as f:
        config_data = json.load(f)
    if 'id2label' in config_data:
        id2label = {int(k): v for k, v in config_data['id2label'].items()}
        # Create label2id with both string and int keys for robustness
        label2id = {}
        for k, v in id2label.items():
            label2id[v] = k  # string key -> int value
            try:
                label2id[int(v)] = k  # int key -> int value (if label is numeric)
            except (ValueError, TypeError):
                pass
        num_labels = len(id2label)
    else:
        # Fallback: infer from test data if available
        if has_labels:
            unique_classes = sorted(test_df["Class"].unique())
            label2id = {label: idx for idx, label in enumerate(unique_classes)}
            id2label = {idx: label for label, idx in label2id.items()}
            num_labels = len(unique_classes)
        else:
            raise ValueError("Cannot determine number of labels without config or test labels")
else:
    # Fallback: infer from test data if available
    if has_labels:
        unique_classes = sorted(test_df["Class"].unique())
        label2id = {label: idx for idx, label in enumerate(unique_classes)}
        id2label = {idx: label for label, idx in label2id.items()}
        num_labels = len(unique_classes)
    else:
        raise ValueError("Cannot find config.json and test data has no labels")
print(f"Number of classes: {num_labels}")
print(f"Label mapping: {id2label}")
# Load model using AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)
model.eval()
print("Model loaded successfully!")
# -------------------------------------------------------------------
# 4. Run inference
# -------------------------------------------------------------------
print("\nRunning inference...")
all_predictions = []
all_probabilities = []
# Process in batches for efficiency
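# Each batch is tokenized, moved to the active device, and scored under
# torch.no_grad() (no autograd bookkeeping); softmax converts logits into
# per-class probabilities and argmax picks the top class.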
for i in range(0, len(test_df), BATCH_SIZE):
    batch_texts = test_df["Commentaire client"].iloc[i:i + BATCH_SIZE].tolist()
    # Tokenize
    inputs = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)
    all_predictions.extend(predictions.cpu().numpy())
    all_probabilities.extend(probs.cpu().numpy())
    if (i // BATCH_SIZE + 1) % 10 == 0:
        print(f"Processed {i + len(batch_texts)}/{len(test_df)} samples...")
print(f"Inference complete! Processed {len(all_predictions)} samples")
# -------------------------------------------------------------------
# 5. Save predictions
# -------------------------------------------------------------------
# Convert predictions to class names
predicted_classes = [id2label[pred] for pred in all_predictions]
# Add predictions to dataframe
test_df["Predicted_Class"] = predicted_classes
test_df["Predicted_Label_ID"] = all_predictions
# Add probability for each class
for idx, class_name in id2label.items():
    test_df[f"Prob_{class_name}"] = [probs[idx] for probs in all_probabilities]
# Add confidence (max probability)
test_df["Confidence"] = [max(probs) for probs in all_probabilities]
# Save results
test_df.to_csv(OUTPUT_FILE, index=False)
print(f"\nPredictions saved to: {OUTPUT_FILE}")
# -------------------------------------------------------------------
# 6. Compute metrics (if labels available)
# -------------------------------------------------------------------
if has_labels:
    print("\n" + "="*80)
    print("EVALUATION METRICS")
    print("="*80)
    # Convert true labels to indices
    true_labels = test_df["Class"].map(label2id).values
    pred_labels = np.array(all_predictions)
    # Overall metrics
    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"\nAccuracy: {accuracy:.4f}")
    # Weighted metrics (accounts for class imbalance)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='weighted', zero_division=0
    )
    print("\nWeighted Metrics:")
    print(f"  Precision: {precision_w:.4f}")
    print(f"  Recall:    {recall_w:.4f}")
    print(f"  F1 Score:  {f1_w:.4f}")
    # Macro metrics (treats all classes equally)
    precision_m, recall_m, f1_m, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro', zero_division=0
    )
    print("\nMacro Metrics:")
    print(f"  Precision: {precision_m:.4f}")
    print(f"  Recall:    {recall_m:.4f}")
    print(f"  F1 Score:  {f1_m:.4f}")
    # Per-class metrics (pass labels= explicitly so the arrays stay aligned
    # with id2label even if a class is missing from the test set)
    print("\nPer-Class Metrics:")
    all_label_ids = list(range(num_labels))
    per_class_f1 = f1_score(
        true_labels, pred_labels, labels=all_label_ids, average=None, zero_division=0
    )
    per_class_precision, per_class_recall, _, support = precision_recall_fscore_support(
        true_labels, pred_labels, labels=all_label_ids, average=None, zero_division=0
    )
    for idx in range(num_labels):
        class_name = id2label[idx]
        print(f"\n  {class_name}:")
        print(f"    Precision: {per_class_precision[idx]:.4f}")
        print(f"    Recall:    {per_class_recall[idx]:.4f}")
        print(f"    F1 Score:  {per_class_f1[idx]:.4f}")
        print(f"    Support:   {int(support[idx])}")
    # Classification report
    print("\n" + "="*80)
    print("DETAILED CLASSIFICATION REPORT")
    print("="*80)
    target_names = [id2label[i] for i in range(num_labels)]
    print(classification_report(
        true_labels, pred_labels, labels=all_label_ids,
        target_names=target_names, zero_division=0
    ))
    # Confusion matrix
    print("\n" + "="*80)
    print("CONFUSION MATRIX")
    print("="*80)
    cm = confusion_matrix(true_labels, pred_labels, labels=all_label_ids)
    # Print confusion matrix with labels
    print("\nTrue \\ Predicted", end="")
    for i in range(num_labels):
        print(f"\t{id2label[i][:8]}", end="")
    print()
    for i in range(num_labels):
        print(f"{id2label[i][:15]:<15}", end="")
        for j in range(num_labels):
            print(f"\t{cm[i][j]}", end="")
        print()
    # Save confusion matrix to CSV
    cm_df = pd.DataFrame(
        cm,
        index=[id2label[i] for i in range(num_labels)],
        columns=[id2label[i] for i in range(num_labels)],
    )
    cm_df.to_csv("./confusion_matrix.csv")
    print("\nConfusion matrix saved to: ./confusion_matrix.csv")
# -------------------------------------------------------------------
# 7. Show sample predictions
# -------------------------------------------------------------------
print("\n" + "="*80)
print("SAMPLE PREDICTIONS")
print("="*80)
# Show first 5 predictions
num_samples = min(5, len(test_df))
for i in range(num_samples):
    print(f"\nSample {i + 1}:")
    print(f"Text: {test_df['Commentaire client'].iloc[i]}")
    if has_labels:
        print(f"True Class: {test_df['Class'].iloc[i]}")
    print(f"Predicted Class: {predicted_classes[i]}")
    print(f"Confidence: {test_df['Confidence'].iloc[i]:.4f}")
    print("Probabilities:")
    for idx, class_name in id2label.items():
        print(f"  {class_name}: {all_probabilities[i][idx]:.4f}")
print("\n" + "="*80)
print("Inference completed successfully!")
print("="*80)