# Source: commit ac9ddbb ("added files") by Sky-Blue-da-ba-dee.
"""Module for evaluating models on test set."""
import argparse
import json
import os
import time
import dagshub
import joblib
import mlflow
import numpy as np
import pandas as pd
from setfit import SetFitModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .utils import load_dataset_splits, parse_labels_column
# Category names per language. Position i in each list names column i of the
# multi-hot label / prediction matrices scored in evaluate_and_benchmark.
LABELS = dict(
    java=[
        "summary",
        "Ownership",
        "Expand",
        "usage",
        "Pointer",
        "deprecation",
        "rational",
    ],
    python=["Usage", "Parameters", "DevelopmentNotes", "Expand", "Summary"],
    pharo=[
        "Keyimplementationpoints",
        "Example",
        "Responsibilities",
        "Intent",
        "Keymessages",
        "Collaborators",
    ],
)
# Device used to place the transformer model and its tokenized inputs:
# CUDA when PyTorch reports an available GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Runs at import time. With mlflow=True, DagsHub is presumably expected to route
# the MLflow runs logged below to this repository's tracking server — confirm
# against the dagshub client documentation.
dagshub.init(mlflow=True, repo_owner="se4ai2526-uniba", repo_name="TheClouds")
# Model types this module knows how to load and run.
_SUPPORTED_MODEL_TYPES = ("setfit", "random_forest", "transformer")


def _load_test_split(lang, data_path):
    """Load the ``{lang}_test`` split through ``load_dataset_splits``.

    Shared by the SetFit and Random Forest branches, which read the same split.

    Returns:
        ``(texts, y_true)``: the ``combo`` column as a list of strings, and the
        parsed ``labels`` column as an integer NumPy array.
    """
    ds = load_dataset_splits(base_dir=data_path, langs=[lang])
    eval_df = parse_labels_column(ds[f"{lang}_test"])
    texts = eval_df["combo"].astype(str).tolist()
    y_true = np.array(eval_df["labels"].tolist(), dtype=int)
    return texts, y_true


def _load_transformer_test_csv(lang, data_path):
    """Read ``{data_path}/{lang}_test.csv`` for the transformer branch.

    If the CSV has no ``combo`` column, it is built as
    ``"<comment_sentence> | <class>"``.

    Returns:
        ``(texts, y_true)`` in the same form as ``_load_test_split``.

    Raises:
        FileNotFoundError: if the CSV does not exist.
    """
    test_csv_path = os.path.join(data_path, f"{lang}_test.csv")
    if not os.path.exists(test_csv_path):
        raise FileNotFoundError(f"Test CSV for transformer not found: {test_csv_path}")
    df_test = parse_labels_column(pd.read_csv(test_csv_path))
    if "combo" not in df_test.columns:
        df_test["combo"] = (
            df_test["comment_sentence"].astype(str) + " | " + df_test["class"].astype(str)
        )
    texts = df_test["combo"].astype(str).tolist()
    y_true = np.array(df_test["labels"].tolist(), dtype=int)
    return texts, y_true


def _sync_device():
    """Wait for queued CUDA work to finish; a no-op when running on CPU."""
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()


def _timed_runs(predict, n_runs):
    """Call ``predict()`` ``n_runs`` times (``n_runs >= 1``).

    CUDA kernels are launched asynchronously, so the device is synchronized
    before starting and after stopping the clock. Otherwise queued GPU work
    would fall outside the measured interval. ``perf_counter`` is used because
    it is monotonic and high-resolution, unlike ``time.time``.

    Returns:
        ``(last_output, average_seconds_per_call)``.
    """
    _sync_device()
    begin = time.perf_counter()
    for _ in range(n_runs):
        output = predict()
    _sync_device()
    return output, (time.perf_counter() - begin) / n_runs


def _gflops_per_run(prof, n_runs):
    """Sum the FLOPs recorded by a ``with_flops=True`` profiler; return GFLOPs per run."""
    return (sum(evt.flops for evt in prof.key_averages()) / 1e9) / n_runs


def _to_numpy(preds):
    """Convert model predictions to a NumPy array, moving torch tensors to CPU first."""
    if isinstance(preds, torch.Tensor):
        return preds.detach().cpu().numpy()
    return np.asarray(preds)


def _per_label_scores(y_true, y_pred, lang, label_names):
    """Compute precision, recall and F1 for every label column.

    Args:
        y_true: 2-D array of ground-truth 0/1 labels, one column per label.
        y_pred: 2-D array of predicted 0/1 labels with the same shape.
        lang: Language tag copied into every row.
        label_names: Names for the label columns, in column order.

    Returns:
        DataFrame with columns ``lan, cat, precision, recall, f1``, one row per
        label. A metric whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: if the arrays are not 2-D with equal shapes, or there are
            more label columns than ``label_names``.
    """
    if y_true.ndim != 2 or y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must be 2-D with equal shapes, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.shape[1] > len(label_names):
        raise ValueError(
            f"{y_true.shape[1]} label columns but only {len(label_names)} names for {lang!r}"
        )
    scores = []
    for name, truth, pred in zip(label_names, y_true.T, y_pred.T):
        tp = int(np.sum((truth == 1) & (pred == 1)))
        fp = int(np.sum((truth == 0) & (pred == 1)))
        fn = int(np.sum((truth == 1) & (pred == 0)))
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0.0
        scores.append(
            {"lan": lang, "cat": name, "precision": precision, "recall": recall, "f1": f1}
        )
    return pd.DataFrame(scores)


def evaluate_and_benchmark(
    lang, model_type, model_path, data_path, metrics_output_path, n_runs=10
):
    """Evaluate a trained model on the test split and benchmark its inference.

    Steps:
        1. Start an MLflow run in the "Model Benchmarking" experiment.
        2. Run inference ``n_runs`` times and record the average wall-clock
           time per run. For the SetFit and transformer models, also record
           the average profiler-reported GFLOPs per run.
        3. Compute per-label precision/recall/F1 and their unweighted
           (macro) averages.
        4. Log the results to MLflow and write them as JSON to
           ``metrics_output_path``.

    Args:
        lang: Language key into ``LABELS`` (e.g. "java", "python", "pharo").
        model_type: One of "setfit", "random_forest", "transformer".
        model_path: Model location. The SetFit and transformer models load
            from this directory; the Random Forest loads from
            ``f"{model_path}.joblib"``.
        data_path: For setfit/random_forest, the base dir passed to
            ``load_dataset_splits``. For transformer, the directory holding
            ``{lang}_test.csv``.
        metrics_output_path: Destination JSON file; parent dirs are created.
        n_runs: Number of timed inference repetitions (default 10, >= 1).

    Raises:
        ValueError: for an unsupported ``model_type``/``lang``, ``n_runs < 1``,
            or mismatched prediction/label shapes.
        FileNotFoundError: if the transformer test CSV is missing.
    """
    # Validate before touching MLflow so bad arguments don't leave empty failed runs.
    if model_type not in _SUPPORTED_MODEL_TYPES:
        raise ValueError(f"Unsupported model_type: {model_type}")
    if lang not in LABELS:
        raise ValueError(f"Unsupported lang: {lang}")
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    mlflow.set_experiment("Model Benchmarking")
    print(f"Starting Evaluation & Benchmarking for language: {lang} and model: {model_type}")
    with mlflow.start_run(run_name=f"evaluation_local_{lang}_{model_type}"):
        mlflow.log_param("language", lang)
        mlflow.log_param("model_type", model_type)
        mlflow.log_param("model_path", model_path)
        mlflow.log_param("data_path", data_path)
        mlflow.log_param("n_runs", n_runs)

        if model_type == "setfit":
            x_eval, y_true = _load_test_split(lang, data_path)
            model = SetFitModel.from_pretrained(model_path)
            with torch.profiler.profile(with_flops=True) as prof:
                y_pred, avg_runtime_sec = _timed_runs(lambda: model(x_eval), n_runs)
            avg_gflops = _gflops_per_run(prof, n_runs)
            # The SetFit output may be a torch tensor (possibly on GPU).
            y_pred = _to_numpy(y_pred)

        elif model_type == "random_forest":
            x_eval, y_true = _load_test_split(lang, data_path)
            model = joblib.load(f"{model_path}.joblib")
            y_pred, avg_runtime_sec = _timed_runs(lambda: model.predict(x_eval), n_runs)
            avg_gflops = 0.0  # not applicable: the torch profiler does not see this model
            y_pred = np.array(y_pred)

        else:  # "transformer" (validated above)
            texts, y_true = _load_transformer_test_csv(lang, data_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path).to(DEVICE)
            model.eval()
            enc = tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=128,  # keep consistent with training config
                return_tensors="pt",
            )
            enc = {k: v.to(DEVICE) for k, v in enc.items()}
            with torch.no_grad():
                with torch.profiler.profile(with_flops=True) as prof:
                    outputs, avg_runtime_sec = _timed_runs(lambda: model(**enc), n_runs)
                # Multi-label decision: each label is predicted independently at 0.5.
                probs = torch.sigmoid(outputs.logits)
            y_pred = (probs > 0.5).long().cpu().numpy()
            avg_gflops = _gflops_per_run(prof, n_runs)

        print(f"Avg runtime in seconds: {avg_runtime_sec:.4f}")
        mlflow.log_metric("avg_runtime_sec", avg_runtime_sec)
        mlflow.log_metric("avg_gflops", avg_gflops)

        lan_scores_df = _per_label_scores(y_true, y_pred, lang, LABELS[lang])
        avg_f1 = float(lan_scores_df["f1"].mean())
        avg_precision = float(lan_scores_df["precision"].mean())
        avg_recall = float(lan_scores_df["recall"].mean())
        mlflow.log_metric("avg_f1_score", avg_f1)
        mlflow.log_metric("avg_precision", avg_precision)
        mlflow.log_metric("avg_recall", avg_recall)

        dvc_metrics = {
            "avg_f1_score": avg_f1,
            "avg_precision": avg_precision,
            "avg_recall": avg_recall,
            "avg_runtime_sec": avg_runtime_sec,
            "avg_gflops": avg_gflops,
        }
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        metrics_dir = os.path.dirname(metrics_output_path)
        if metrics_dir:
            os.makedirs(metrics_dir, exist_ok=True)
        with open(metrics_output_path, "w") as f:
            json.dump(dvc_metrics, f, indent=4)
if __name__ == "__main__":
    # Command-line entry point: evaluate models/<lang>/<model_type> and write
    # reports/metrics/<lang>/<model_type>_metrics.json.
    parser = argparse.ArgumentParser(
        description="Evaluate a trained model on the test split and benchmark its inference."
    )
    # `choices` rejects invalid values at parse time, before any model is
    # loaded or an MLflow run is opened.
    parser.add_argument("--lang", type=str, required=True, choices=list(LABELS))
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        choices=["setfit", "random_forest", "transformer"],
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="data/raw",
        help=(
            "Path to evaluation data. "
            "For setfit/random_forest: base dir with raw CSVs (e.g. data/raw). "
            "For transformer: directory with {lang}_test.csv (e.g. data/processed/transformer)."
        ),
    )
    args = parser.parse_args()
    evaluate_and_benchmark(
        lang=args.lang,
        model_type=args.model_type,
        model_path=f"models/{args.lang}/{args.model_type}",
        data_path=args.data_path,
        metrics_output_path=f"reports/metrics/{args.lang}/{args.model_type}_metrics.json",
    )