# kaushik-harsh-99's picture
# initial-upload
# 95f644c
import json
import os
import time
import fasttext
import pandas as pd
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
)
# ============================================================
# CONFIG
# ============================================================
# Input data: the training file handed to fasttext.train_supervised and the
# JSONL splits scored by evaluate_jsonl.
TRAIN_FILE = "fasttext_train.txt"
VALIDATION_JSONL = "dataset/validation.jsonl"
TEST_JSONL = "dataset/test.jsonl"

# Path the trained model is saved to.
MODEL_FILE = "fasttext_language_classifier.bin"

# Hyperparameters, forwarded to fasttext.train_supervised under the keyword
# noted next to each value.
EPOCHS = 25  # epoch=
LR = 0.7  # lr=
DIM = 50  # dim=
WORD_NGRAMS = 3  # wordNgrams=
MINN = 2  # minn=
MAXN = 5  # maxn=
MIN_COUNT = 100  # minCount=
BUCKET = 50000  # bucket=

# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to one thread so that thread= always receives an integer.
THREADS = os.cpu_count() or 1
# ============================================================
# TRAIN
# ============================================================
print("Training FastText...\n")

# Gather every training option in one mapping so the call itself stays short.
train_params = {
    "input": TRAIN_FILE,
    "epoch": EPOCHS,
    "lr": LR,
    "dim": DIM,
    "wordNgrams": WORD_NGRAMS,
    "minn": MINN,
    "maxn": MAXN,
    "minCount": MIN_COUNT,
    "bucket": BUCKET,
    "loss": "softmax",
    "thread": THREADS,
    "verbose": 2,
}

# Time only the training call.
start = time.time()
model = fasttext.train_supervised(**train_params)
elapsed = time.time() - start

print(f"\nTraining completed in {elapsed:.1f}s")
# ============================================================
# LABEL DEBUG
# ============================================================
# Print how many labels the trained model reports, then one label per line.
print("\nLabels found by FastText:")
found_labels = model.labels
print(f"Count: {len(found_labels)}")
for label in found_labels:
    print(label)
# ============================================================
# SAVE MODEL
# ============================================================
# Write the model to disk, then report the size of the resulting file.
model.save_model(MODEL_FILE)
size_mb = os.path.getsize(MODEL_FILE) / (1024 * 1024)

print(f"\nSaved model: {MODEL_FILE}")
print(f"Model size: {size_mb:.2f} MB")
# ============================================================
# EVALUATION
# ============================================================
def evaluate_jsonl(
    model,
    jsonl_file,
    split_name,
):
    """Evaluate ``model`` on one JSONL split and write its report files.

    Every non-blank line of ``jsonl_file`` is parsed as a JSON object that
    must carry a ``"label"`` and a ``"content"`` key. The content is
    whitespace-normalised to a single line, the model's top-1 prediction is
    compared against the label, and two CSV files are written to the current
    working directory:

    * ``{split_name}_classification_report.csv``
    * ``{split_name}_confusion_matrix.csv``

    Args:
        model: Object whose ``predict(text, k=1)`` returns a
            ``(labels, probs)`` pair; the first label may carry the
            ``__label__`` prefix.
        jsonl_file: Path of the UTF-8 JSONL file to score.
        split_name: Name used in log lines and as the CSV file-name prefix.

    Returns:
        The accuracy computed by ``accuracy_score`` over all scored rows.

    Raises:
        ValueError: If ``jsonl_file`` contains no rows to score.

    Malformed JSON and missing keys are not caught; the errors raised by
    ``json.loads`` and the dictionary lookups propagate to the caller.
    """
    label_prefix = "__label__"

    print()
    print(f"Evaluating {split_name}")

    y_true = []
    y_pred = []

    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a stray newline at end of file)
            # instead of failing inside json.loads.
            if not line.strip():
                continue

            row = json.loads(line)

            # Predictions are always strings, so compare like with like even
            # if the dataset stores its labels as numbers.
            true_label = str(row["label"])

            # Collapse every whitespace run (including newlines) to a single
            # space so the model receives one line of text.
            text = " ".join(str(row["content"]).split())

            labels, _probs = model.predict(text, k=1)

            # Strip only the leading marker; str.replace would also remove
            # the marker if it occurred inside the label name itself.
            pred_label = labels[0]
            if pred_label.startswith(label_prefix):
                pred_label = pred_label[len(label_prefix):]

            y_true.append(true_label)
            y_pred.append(pred_label)

            processed = len(y_true)
            if processed % 5000 == 0:
                print(f"Processed {processed:,}")

    # Fail early with a clear message rather than deep inside scikit-learn
    # after part of the output has already been written.
    if not y_true:
        raise ValueError(f"No rows to evaluate in {jsonl_file!r}")

    # ========================================================
    # ACCURACY
    # ========================================================
    accuracy = accuracy_score(y_true, y_pred)

    print()
    print(f"{split_name} Accuracy: {accuracy:.6f}")

    # ========================================================
    # CLASSIFICATION REPORT
    # ========================================================
    # output_dict=True returns unrounded values, so no digits= is passed.
    report = classification_report(
        y_true,
        y_pred,
        output_dict=True,
    )
    report_df = pd.DataFrame(report).transpose()

    report_file = f"{split_name}_classification_report.csv"
    report_df.to_csv(report_file)
    print(f"Saved {report_file}")

    # ========================================================
    # CONFUSION MATRIX
    # ========================================================
    # Use the union of true and predicted labels: with the true labels alone,
    # a prediction of a class that is absent from this split would be dropped
    # from the matrix.
    labels_sorted = sorted(set(y_true) | set(y_pred))

    cm = confusion_matrix(
        y_true,
        y_pred,
        labels=labels_sorted,
    )
    cm_df = pd.DataFrame(
        cm,
        index=labels_sorted,
        columns=labels_sorted,
    )

    cm_file = f"{split_name}_confusion_matrix.csv"
    cm_df.to_csv(cm_file)
    print(f"Saved {cm_file}")

    return accuracy
# ============================================================
# VALIDATION
# ============================================================
# Score the validation split with the freshly trained model.
validation_accuracy = evaluate_jsonl(
    model=model,
    jsonl_file=VALIDATION_JSONL,
    split_name="validation",
)
# ============================================================
# TEST
# ============================================================
# Score the test split in the same way.
test_accuracy = evaluate_jsonl(
    model=model,
    jsonl_file=TEST_JSONL,
    split_name="test",
)
# ============================================================
# SUMMARY
# ============================================================
# One-row record of this run: both accuracies, the main hyperparameters and
# the size of the saved model.
run_record = {
    "validation_accuracy": validation_accuracy,
    "test_accuracy": test_accuracy,
    "epochs": EPOCHS,
    "lr": LR,
    "dim": DIM,
    "word_ngrams": WORD_NGRAMS,
    "min_count": MIN_COUNT,
    "bucket": BUCKET,
    "model_size_mb": size_mb,
}
summary = pd.DataFrame([run_record])
summary.to_csv("fasttext_summary.csv", index=False)

# Final console banner, emitted in a single print call.
rule = "=" * 60
closing_lines = [
    "",
    rule,
    f"Validation Accuracy : {validation_accuracy:.6f}",
    f"Test Accuracy : {test_accuracy:.6f}",
    f"Model Size (MB) : {size_mb:.2f}",
    rule,
    "",
    "Done.",
]
print("\n".join(closing_lines))