MarcusBennevall's picture
Upload folder using huggingface_hub
6af0821 verified
from __future__ import annotations
import argparse
import json
from datetime import datetime, timezone
from pathlib import Path
import joblib
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from .labels import LABELS
# Default CSV file holding the labelled sentences; it must provide `text` and
# `label` columns (validated by `load_dataset`).
DEFAULT_DATASET = Path("data/sentence_function_dataset.csv")
# Default directory that receives the trained classifier, label encoder,
# metadata file, and model card.
DEFAULT_OUTPUT_DIR = Path("artifacts/sentence-function-classifier")
# Default SentenceTransformer model name used to embed sentences.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
def load_dataset(path: Path) -> pd.DataFrame:
    """Read, clean, validate, and shuffle the labelled sentence dataset.

    Rows with a missing ``text`` or ``label`` are dropped, both columns are
    stripped of surrounding whitespace, labels are lower-cased, and rows whose
    text is empty after stripping are removed.

    Args:
        path: Location of the CSV file to read.

    Returns:
        The cleaned rows in a deterministic shuffled order (seed 42) with a
        fresh 0..n-1 index.

    Raises:
        ValueError: If the ``text`` or ``label`` column is absent, if a label
            is not one of ``LABELS``, or if any label in ``LABELS`` has fewer
            than two examples.
    """
    frame = pd.read_csv(path)

    absent = {"text", "label"}.difference(frame.columns)
    if absent:
        missing = ", ".join(sorted(absent))
        raise ValueError(f"Dataset is missing required columns: {missing}")

    # Normalise the two columns the training code relies on.
    frame = frame.dropna(subset=["text", "label"]).copy()
    frame["text"] = frame["text"].astype(str).str.strip()
    frame["label"] = frame["label"].astype(str).str.strip().str.lower()
    frame = frame[frame["text"].ne("")]

    unexpected = sorted(set(frame["label"]).difference(LABELS))
    if unexpected:
        listed = ", ".join(unexpected)
        raise ValueError(f"Unknown labels found: {listed}")

    # Reindexing against LABELS makes labels with no rows show up as zero.
    per_label = frame["label"].value_counts().reindex(LABELS, fill_value=0)
    sparse = per_label[per_label < 2]
    if not sparse.empty:
        too_small = ", ".join(sparse.index)
        raise ValueError(f"Each label needs at least two examples. Too small: {too_small}")

    shuffled = frame.sample(frac=1.0, random_state=42)
    return shuffled.reset_index(drop=True)
def train(
    dataset_path: Path = DEFAULT_DATASET,
    output_dir: Path = DEFAULT_OUTPUT_DIR,
    embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
    test_size: float = 0.2,
) -> dict[str, object]:
    """Train, evaluate, and persist the sentence function classifier.

    The dataset is split with stratification, sentences are embedded with a
    SentenceTransformer model, and a class-weighted logistic regression is
    fitted on the normalised embeddings. The classifier, the label encoder, a
    ``metadata.json`` file, and a ``README.md`` model card are written to
    ``output_dir``, and a text classification report is printed.

    Args:
        dataset_path: CSV file passed to ``load_dataset``.
        output_dir: Directory for the artifacts; created if it does not exist.
        embedding_model_name: Name handed to ``SentenceTransformer``.
        test_size: Held-out portion forwarded to ``train_test_split``.

    Returns:
        The metadata dictionary that is also written to ``metadata.json``.

    Raises:
        ValueError: Propagated from ``load_dataset`` when the dataset fails
            validation.
    """
    data = load_dataset(dataset_path)
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        data["text"].tolist(),
        data["label"].tolist(),
        test_size=test_size,
        random_state=42,
        stratify=data["label"],
    )

    # Fit the encoder on every label in the dataset rather than on the training
    # split alone, so encoding the test labels cannot fail when a rare class
    # ends up on only one side of the split.
    label_encoder = LabelEncoder()
    label_encoder.fit(data["label"])
    y_train = label_encoder.transform(train_labels)
    y_test = label_encoder.transform(test_labels)
    class_names = [str(name) for name in label_encoder.classes_]
    class_ids = list(range(len(class_names)))

    embedding_model = SentenceTransformer(embedding_model_name)
    x_train = embedding_model.encode(train_texts, normalize_embeddings=True, show_progress_bar=True)
    x_test = embedding_model.encode(test_texts, normalize_embeddings=True, show_progress_bar=True)

    classifier = LogisticRegression(
        C=4.0,
        class_weight="balanced",
        max_iter=1000,
        random_state=42,
    )
    classifier.fit(x_train, y_train)
    predictions = classifier.predict(x_test)

    # Passing `labels` explicitly keeps the report aligned with `target_names`
    # even when a class is absent from both the test labels and the
    # predictions; without it scikit-learn raises on the size mismatch.
    report = classification_report(
        y_test,
        predictions,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    # NOTE(review): some scikit-learn releases appear to emit "micro avg"
    # instead of "accuracy" when `labels` is a superset of the observed
    # classes; the model card reads "accuracy", so make sure it is present.
    report.setdefault("accuracy", float((predictions == y_test).mean()))

    output_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(classifier, output_dir / "classifier.joblib")
    joblib.dump(label_encoder, output_dir / "label_encoder.joblib")

    metadata = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "embedding_model": embedding_model_name,
        "labels": class_names,
        "dataset_path": str(dataset_path),
        "dataset_size": len(data),
        "test_size": test_size,
        "metrics": report,
    }
    (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    write_model_card(output_dir, metadata)

    print(
        classification_report(
            y_test,
            predictions,
            labels=class_ids,
            target_names=class_names,
            zero_division=0,
        )
    )
    print(f"Saved artifacts to {output_dir}")
    return metadata
def write_model_card(output_dir: Path, metadata: dict[str, object]) -> None:
    """Render the model card and save it as ``README.md`` in ``output_dir``.

    Args:
        output_dir: Existing directory that receives ``README.md``.
        metadata: Training metadata. The keys read here are ``metrics`` (which
            must contain ``accuracy``), ``labels``, ``embedding_model``,
            ``dataset_size``, and ``test_size``.

    Raises:
        KeyError: If one of the keys listed above is missing.
    """
    accuracy = metadata["metrics"]["accuracy"]
    label_list = ", ".join(metadata["labels"])
    embedding_model = metadata["embedding_model"]
    dataset_size = metadata["dataset_size"]
    test_size = metadata["test_size"]

    # The card opens with YAML front matter followed by the Markdown body;
    # each entry below is one output line.
    card_lines = [
        "---",
        "language: en",
        "tags:",
        "- sentence-classification",
        "- sentence-transformers",
        "- text-classification",
        "library_name: scikit-learn",
        "---",
        "# Sentence Function Classifier",
        f"This model classifies English sentences as: {label_list}.",
        f"It embeds sentences with `{embedding_model}` and predicts the final",
        "class with a logistic regression classifier trained on a balanced seed dataset.",
        "## Intended Use",
        "This is designed for educational demos and lightweight sentence-function",
        "analysis. It is not intended as a grammar authority for high-stakes assessment.",
        "## Evaluation",
        f"- Dataset size: {dataset_size}",
        f"- Held-out test split: {test_size}",
        f"- Accuracy: {accuracy:.3f}",
        "The seed dataset is small, so the metrics should be treated as a smoke test",
        "rather than a final benchmark.",
    ]
    card = "\n".join(card_lines) + "\n"

    readme_path = output_dir / "README.md"
    readme_path.write_text(card, encoding="utf-8")
def parse_args() -> argparse.Namespace:
    """Parse the training command-line options.

    Returns:
        A namespace with ``dataset``, ``output_dir``, ``embedding_model``, and
        ``test_size`` attributes, each falling back to its default when the
        option is not given.
    """
    cli = argparse.ArgumentParser(description="Train the sentence function classifier.")

    # One entry per option: the flag and the keyword arguments it is registered with.
    option_specs = (
        ("--dataset", {"type": Path, "default": DEFAULT_DATASET}),
        ("--output-dir", {"type": Path, "default": DEFAULT_OUTPUT_DIR}),
        ("--embedding-model", {"default": DEFAULT_EMBEDDING_MODEL}),
        ("--test-size", {"type": float, "default": 0.2}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def main() -> None:
    """Command-line entry point: parse the options and run training with them."""
    cli = parse_args()
    # Map the command-line attribute names onto the keyword names `train` expects.
    options = {
        "dataset_path": cli.dataset,
        "output_dir": cli.output_dir,
        "embedding_model_name": cli.embedding_model,
        "test_size": cli.test_size,
    }
    train(**options)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()