| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from datetime import datetime, timezone |
| from pathlib import Path |
|
|
| import joblib |
| import pandas as pd |
| from sentence_transformers import SentenceTransformer |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.metrics import classification_report |
| from sklearn.model_selection import train_test_split |
| from sklearn.preprocessing import LabelEncoder |
|
|
| from .labels import LABELS |
|
|
|
|
# Default location of the CSV dataset, used when no dataset path is supplied.
DEFAULT_DATASET = Path("data/sentence_function_dataset.csv")
# Default directory that receives the trained artifacts.
DEFAULT_OUTPUT_DIR = Path("artifacts/sentence-function-classifier")
# Default sentence-transformers model name used to embed the sentences.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
def load_dataset(path: Path) -> pd.DataFrame:
    """Read, clean and validate the labelled sentence dataset stored at *path*.

    The CSV must provide ``text`` and ``label`` columns. Rows missing either
    value are discarded, both columns are whitespace-stripped, labels are
    lower-cased and rows whose text ends up empty are removed.

    Returns:
        The cleaned rows, shuffled with a fixed seed and re-indexed from zero.

    Raises:
        ValueError: If a required column is absent, if a label outside
            ``LABELS`` appears, or if any label in ``LABELS`` has fewer than
            two examples.
    """
    frame = pd.read_csv(path)

    absent = {"text", "label"} - set(frame.columns)
    if absent:
        missing = ", ".join(sorted(absent))
        raise ValueError(f"Dataset is missing required columns: {missing}")

    # Normalise the two columns before any validation looks at them.
    frame = frame.dropna(subset=["text", "label"]).copy()
    frame["text"] = frame["text"].astype(str).str.strip()
    frame["label"] = frame["label"].astype(str).str.strip().str.lower()
    frame = frame[frame["text"] != ""]

    unexpected = sorted(set(frame["label"]).difference(LABELS))
    if unexpected:
        raise ValueError(f"Unknown labels found: {', '.join(unexpected)}")

    # Re-indexing on LABELS makes labels with no rows at all show up as 0.
    per_label = frame["label"].value_counts().reindex(LABELS, fill_value=0)
    scarce = per_label[per_label < 2]
    if not scarce.empty:
        too_small = ", ".join(scarce.index)
        raise ValueError(f"Each label needs at least two examples. Too small: {too_small}")

    shuffled = frame.sample(frac=1.0, random_state=42)
    return shuffled.reset_index(drop=True)
|
|
|
|
def train(
    dataset_path: Path = DEFAULT_DATASET,
    output_dir: Path = DEFAULT_OUTPUT_DIR,
    embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
    test_size: float = 0.2,
) -> dict[str, object]:
    """Train, evaluate and persist the sentence function classifier.

    The dataset is loaded with ``load_dataset``, split into stratified
    train/test parts, embedded with a ``SentenceTransformer`` and fitted with
    a class-balanced logistic regression. The classifier, the label encoder,
    a ``metadata.json`` file and a model card are written to *output_dir*.

    Args:
        dataset_path: CSV file holding the ``text`` and ``label`` columns.
        output_dir: Directory that receives the artifacts; created if missing.
        embedding_model_name: Name handed to ``SentenceTransformer``.
        test_size: Held-out share forwarded to ``train_test_split``.

    Returns:
        The metadata dictionary that was also written to ``metadata.json``.

    Raises:
        ValueError: Propagated from ``load_dataset`` for an invalid dataset.
    """
    data = load_dataset(dataset_path)

    train_texts, test_texts, train_labels, test_labels = train_test_split(
        data["text"].tolist(),
        data["label"].tolist(),
        test_size=test_size,
        random_state=42,
        stratify=data["label"],
    )

    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(train_labels)
    y_test = label_encoder.transform(test_labels)

    embedding_model = SentenceTransformer(embedding_model_name)
    x_train = embedding_model.encode(train_texts, normalize_embeddings=True, show_progress_bar=True)
    x_test = embedding_model.encode(test_texts, normalize_embeddings=True, show_progress_bar=True)

    classifier = LogisticRegression(
        C=4.0,
        class_weight="balanced",
        max_iter=1000,
        random_state=42,
    )
    classifier.fit(x_train, y_train)

    predictions = classifier.predict(x_test)

    # Pin the reported label set to every class the encoder knows. A rare
    # class can be missing from both the held-out labels and the predictions;
    # without ``labels`` the inferred label set would then be shorter than
    # ``target_names`` and classification_report would raise ValueError.
    class_names = [str(name) for name in label_encoder.classes_]
    class_ids = list(range(len(class_names)))
    report_options = {
        "labels": class_ids,
        "target_names": class_names,
        "zero_division": 0,
    }
    report = classification_report(y_test, predictions, output_dict=True, **report_options)
    # NOTE(review): some scikit-learn releases appear to emit "micro avg"
    # instead of "accuracy" when ``labels`` is passed and is a strict superset
    # of the observed labels; the model card reads "accuracy", so make sure
    # the key is always present.
    report.setdefault("accuracy", float((predictions == y_test).mean()))

    output_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(classifier, output_dir / "classifier.joblib")
    joblib.dump(label_encoder, output_dir / "label_encoder.joblib")

    metadata = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "embedding_model": embedding_model_name,
        "labels": class_names,
        "dataset_path": str(dataset_path),
        "dataset_size": len(data),
        "test_size": test_size,
        "metrics": report,
    }
    (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    write_model_card(output_dir, metadata)

    print(classification_report(y_test, predictions, **report_options))
    print(f"Saved artifacts to {output_dir}")
    return metadata
|
|
|
|
def write_model_card(output_dir: Path, metadata: dict[str, object]) -> None:
    """Render the ``README.md`` model card for a training run.

    Args:
        output_dir: Existing directory in which ``README.md`` is written.
        metadata: Run description; the keys ``metrics`` (containing
            ``accuracy``), ``labels``, ``embedding_model``, ``dataset_size``
            and ``test_size`` are read.
    """
    accuracy = metadata["metrics"]["accuracy"]
    label_list = ", ".join(metadata["labels"])
    encoder_name = metadata["embedding_model"]
    sample_count = metadata["dataset_size"]
    holdout_share = metadata["test_size"]

    # One entry per output line; "" entries are the blank separator lines.
    card_lines = [
        "---",
        "language: en",
        "tags:",
        "- sentence-classification",
        "- sentence-transformers",
        "- text-classification",
        "library_name: scikit-learn",
        "---",
        "",
        "# Sentence Function Classifier",
        "",
        f"This model classifies English sentences as: {label_list}.",
        "",
        f"It embeds sentences with `{encoder_name}` and predicts the final",
        "class with a logistic regression classifier trained on a balanced seed dataset.",
        "",
        "## Intended Use",
        "",
        "This is designed for educational demos and lightweight sentence-function",
        "analysis. It is not intended as a grammar authority for high-stakes assessment.",
        "",
        "## Evaluation",
        "",
        f"- Dataset size: {sample_count}",
        f"- Held-out test split: {holdout_share}",
        f"- Accuracy: {accuracy:.3f}",
        "",
        "The seed dataset is small, so the metrics should be treated as a smoke test",
        "rather than a final benchmark.",
    ]

    readme_path = output_dir / "README.md"
    # The card ends with a single trailing newline.
    readme_path.write_text("\n".join(card_lines) + "\n", encoding="utf-8")
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Returns:
        A namespace with ``dataset``, ``output_dir``, ``embedding_model`` and
        ``test_size`` attributes, each falling back to its default.
    """
    cli = argparse.ArgumentParser(description="Train the sentence function classifier.")

    # Flag name paired with the keyword arguments for add_argument.
    option_table = (
        ("--dataset", {"type": Path, "default": DEFAULT_DATASET}),
        ("--output-dir", {"type": Path, "default": DEFAULT_OUTPUT_DIR}),
        ("--embedding-model", {"default": DEFAULT_EMBEDDING_MODEL}),
        ("--test-size", {"type": float, "default": 0.2}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
|
|
|
|
def main() -> None:
    """Run a training job configured from the command line."""
    options = parse_args()
    # Map the parsed CLI attributes onto train()'s keyword parameters.
    train_kwargs = {
        "dataset_path": options.dataset,
        "output_dir": options.output_dir,
        "embedding_model_name": options.embedding_model,
        "test_size": options.test_size,
    }
    train(**train_kwargs)
|
|
|
|
# Run the training CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|