# Code-Lang-Classifier / FastText / FastText-Test.py
# Uploaded by kaushik-harsh-99 (commit 95f644c, "initial-upload")
import json
import fasttext
import pandas as pd
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
)
# ============================================================
# CONFIG
# ============================================================
# Path handed to fasttext.load_model below.
MODEL_FILE = "fasttext_language_classifier.bin"
# JSONL splits scored by this script.
VALIDATION_FILE = "dataset/validation.jsonl"
TEST_FILE = "dataset/test.jsonl"
# ============================================================
# LOAD MODEL
# ============================================================
# The model is loaded once at import time and bound to the module-level
# name `model`.
print("Loading model...")
model = fasttext.load_model(MODEL_FILE)
print("Model loaded.")
# ============================================================
# EVALUATION
# ============================================================
def evaluate_jsonl(
    model,
    jsonl_file,
    split_name,
):
    """Score ``model`` on one JSONL split and write its report CSVs.

    Every non-blank line of ``jsonl_file`` is parsed as a JSON object; its
    ``"label"`` value is the ground truth and its ``"content"`` value is the
    text handed to ``model.predict``.

    Args:
        model: Classifier exposing ``predict(text, k=1)`` that returns a
            ``(labels, probs)`` pair, where ``labels[0]`` is the top label
            (optionally carrying the ``__label__`` prefix).
        jsonl_file: Path of the UTF-8 JSONL file to evaluate.
        split_name: Name used in progress messages and as the prefix of the
            two CSV file names that are written.

    Returns:
        The value of ``accuracy_score(y_true, y_pred)`` for the split.

    Side effects:
        Prints progress and writes
        ``{split_name}_classification_report.csv`` and
        ``{split_name}_confusion_matrix.csv``.
    """
    label_prefix = "__label__"

    print(f"\nEvaluating {split_name}")
    y_true = []
    y_pred = []

    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            # A blank line (e.g. a trailing empty line at the end of the
            # file) is not a record; json.loads would raise on it.
            if not line.strip():
                continue

            row = json.loads(line)

            # Match FastText training format: collapse every whitespace run
            # (newlines included) into a single space.
            text = " ".join(row["content"].split())

            labels, _probs = model.predict(text, k=1)

            # Strip only the leading prefix; str.replace would also remove
            # any later occurrence inside the label itself.
            pred_label = labels[0]
            if pred_label.startswith(label_prefix):
                pred_label = pred_label[len(label_prefix):]

            y_true.append(row["label"])
            y_pred.append(pred_label)

            processed = len(y_true)
            if processed % 5000 == 0:
                print(f"Processed {processed:,}")

    # ========================================================
    # ACCURACY
    # ========================================================
    acc = accuracy_score(y_true, y_pred)
    print(f"\n{split_name} Accuracy: {acc:.6f}")

    # ========================================================
    # CLASSIFICATION REPORT
    # ========================================================
    report = classification_report(
        y_true,
        y_pred,
        output_dict=True,
        digits=4,
    )
    report_csv = f"{split_name}_classification_report.csv"
    pd.DataFrame(report).transpose().to_csv(report_csv)
    print(f"Saved {report_csv}")

    # ========================================================
    # CONFUSION MATRIX
    # ========================================================
    # Use the union of true and predicted labels: confusion_matrix leaves out
    # samples whose label is not listed in `labels`, so a class that was only
    # ever predicted would otherwise vanish from the matrix.
    labels_sorted = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(
        y_true,
        y_pred,
        labels=labels_sorted,
    )
    cm_csv = f"{split_name}_confusion_matrix.csv"
    pd.DataFrame(
        cm,
        index=labels_sorted,
        columns=labels_sorted,
    ).to_csv(cm_csv)
    print(f"Saved {cm_csv}")

    return acc
# ============================================================
# VALIDATION
# ============================================================
# Score the validation split; report CSVs are prefixed with "validation".
validation_accuracy = evaluate_jsonl(
    model=model,
    jsonl_file=VALIDATION_FILE,
    split_name="validation",
)
# ============================================================
# TEST
# ============================================================
# Score the test split with the same model; CSVs are prefixed with "test".
test_accuracy = evaluate_jsonl(
    model=model,
    jsonl_file=TEST_FILE,
    split_name="test",
)
# ============================================================
# SUMMARY
# ============================================================
# One-row table holding both accuracies, built column by column.
summary = pd.DataFrame(
    {
        "validation_accuracy": [validation_accuracy],
        "test_accuracy": [test_accuracy],
    }
)
summary.to_csv("fasttext_summary.csv", index=False)
print("\nSaved fasttext_summary.csv")

# Final console recap of the two headline numbers.
print("\n==============================")
print(f"Validation Accuracy: {validation_accuracy:.6f}")
print(f"Test Accuracy: {test_accuracy:.6f}")
print("==============================")
print("\nDone.")