aaa / training /train_bloom_classifier.py
work-sejal
Deploy AI service with FastAPI
70ea7be
Raw
History Blame Contribute Delete
10.1 kB
"""Bloom Classifier training pipeline.
Trains a TF-IDF + LogisticRegression model for Bloom taxonomy classification.
Target: bloom_level (6 classes: Remember, Understand, Apply, Analyze, Evaluate, Create).
Features: question_text (TF-IDF).
Primary metric: macro F1.
"""
import logging
from datetime import datetime, timezone
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
classification_report,
confusion_matrix,
f1_score,
)
from sklearn.preprocessing import LabelEncoder
from app.core.config import settings
from app.core.exceptions import TrainingError
from training.base_trainer import BaseTrainer, TrainingResult
# Module-level logger named after this module; BloomClassifierTrainer uses it
# for training, evaluation and baseline-check progress messages.
logger = logging.getLogger(__name__)
class BloomClassifierTrainer(BaseTrainer):
    """TF-IDF + LogisticRegression for Bloom taxonomy classification.

    Target: bloom_level (6 classes: Remember, Understand, Apply, Analyze, Evaluate, Create)
    Features: question_text (TF-IDF)
    Primary metric: macro F1
    """

    # Hyperparameters are defined once and read by train(),
    # _build_training_config() and _build_model_card(), so the recorded
    # config and the model card cannot drift from what was actually fitted.
    _TFIDF_MAX_FEATURES = 8000
    _TFIDF_NGRAM_RANGE = (1, 2)
    _TFIDF_SUBLINEAR_TF = True
    _LOGREG_C = 1.0
    _LOGREG_SOLVER = "lbfgs"
    _LOGREG_MAX_ITER = 1000

    @property
    def model_name(self) -> str:
        """Model identifier written into metrics, config, model card and TrainingErrors."""
        return "bloom_classifier"

    @property
    def model_version(self) -> str:
        """Version tag recorded in metrics.json, training_config.json and the model card."""
        return "bloom_classifier_v2_baseline_001"

    @property
    def table_name(self) -> str:
        """Training-data table name (not read in this file; presumably used by BaseTrainer)."""
        return "training_bloom_classification"

    def train(self, train_df: pd.DataFrame, val_df: pd.DataFrame) -> dict:
        """Fit the TF-IDF vectorizer, label encoder and multinomial LogisticRegression.

        Algorithm:
            1. Fit a TF-IDF vectorizer on train ``question_text``.
            2. Encode ``bloom_level`` labels with a LabelEncoder.
            3. Fit LogisticRegression(C=1.0, solver="lbfgs", max_iter=1000,
               random_state=seed). With lbfgs and >= 3 classes scikit-learn
               fits the multinomial (softmax) loss by default.

        Args:
            train_df: Training split with ``question_text`` and ``bloom_level`` columns.
            val_df: Validation split; unused by this method.

        Returns:
            ``{"model": logreg, "vectorizer": tfidf, "label_encoder": le}``.
        """
        tfidf = TfidfVectorizer(
            max_features=self._TFIDF_MAX_FEATURES,
            ngram_range=self._TFIDF_NGRAM_RANGE,
            sublinear_tf=self._TFIDF_SUBLINEAR_TF,
        )
        X_train = tfidf.fit_transform(train_df["question_text"])

        le = LabelEncoder()
        y_train = le.fit_transform(train_df["bloom_level"])

        # `multi_class` is deprecated since scikit-learn 1.5 (FutureWarning,
        # scheduled for removal). lbfgs already uses the multinomial loss for
        # >= 3 classes, so omitting it yields the same model without the warning.
        logreg = LogisticRegression(
            C=self._LOGREG_C,
            solver=self._LOGREG_SOLVER,
            max_iter=self._LOGREG_MAX_ITER,
            random_state=self._seed,
        )
        logreg.fit(X_train, y_train)
        logger.info(
            "Bloom Classifier trained — %d features, %d classes",
            X_train.shape[1],
            len(le.classes_),
        )
        return {"model": logreg, "vectorizer": tfidf, "label_encoder": le}

    def evaluate(self, artifacts: dict, df: pd.DataFrame, split_name: str) -> dict:
        """Evaluate fitted artifacts on one split.

        Args:
            artifacts: Dict returned by :meth:`train`.
            df: Split with ``question_text`` and ``bloom_level`` columns. Every
                ``bloom_level`` value must have been seen in training
                (``LabelEncoder.transform`` raises ValueError otherwise).
            split_name: Used only in the log message.

        Returns:
            Dict with ``macro_f1`` and ``weighted_f1`` (rounded to 4 dp),
            ``per_class`` precision/recall/f1/support keyed by class name, and
            ``confusion_matrix`` as a nested list whose rows (true) and
            columns (predicted) follow ``label_encoder.classes_`` order.
        """
        model = artifacts["model"]
        tfidf = artifacts["vectorizer"]
        le = artifacts["label_encoder"]

        X = tfidf.transform(df["question_text"])
        y_true = le.transform(df["bloom_level"])
        y_pred = model.predict(X)

        # Pin the label set to every class the encoder knows. Without it, a
        # rare class (e.g. Create) absent from both y_true and y_pred on a
        # small split makes classification_report raise (target_names size
        # mismatch) and shrinks/misaligns the confusion matrix.
        all_labels = np.arange(len(le.classes_))

        # F1 scores keep sklearn's default label set (labels present in
        # y_true or y_pred), so their values are unchanged.
        macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
        weighted_f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)

        # Per-class metrics
        report = classification_report(
            y_true,
            y_pred,
            labels=all_labels,
            target_names=le.classes_,
            output_dict=True,
            zero_division=0,
        )
        per_class = {}
        for class_name in le.classes_:
            if class_name in report:
                row = report[class_name]
                per_class[class_name] = {
                    "precision": round(row["precision"], 4),
                    "recall": round(row["recall"], 4),
                    "f1": round(row["f1-score"], 4),
                    "support": int(row["support"]),
                }

        # Confusion matrix, always len(classes) x len(classes).
        cm = confusion_matrix(y_true, y_pred, labels=all_labels).tolist()

        metrics = {
            "macro_f1": round(macro_f1, 4),
            "weighted_f1": round(weighted_f1, 4),
            "per_class": per_class,
            "confusion_matrix": cm,
        }
        logger.info(
            "%s metrics — macro_f1: %.4f, weighted_f1: %.4f",
            split_name, macro_f1, weighted_f1,
        )
        return metrics

    def _check_baseline(self, metrics: dict) -> None:
        """Require macro F1 to exceed the 1 / num_classes random baseline.

        Uses test-split metrics; falls back to validation metrics when the
        test split has no per-class entries.

        Args:
            metrics: Output of :meth:`_build_metrics`.

        Raises:
            TrainingError: if no per-class metrics are available, or if
                macro F1 <= 1 / num_classes.
        """
        split_metrics = metrics.get("metrics", {})
        test_metrics = split_metrics.get("test", {})
        macro_f1 = test_metrics.get("macro_f1", 0.0)
        num_classes = len(test_metrics.get("per_class", {}))

        # Fallback: use validation metrics if test not available
        if num_classes == 0:
            val_metrics = split_metrics.get("validation", {})
            macro_f1 = val_metrics.get("macro_f1", 0.0)
            num_classes = len(val_metrics.get("per_class", {}))

        if num_classes == 0:
            raise TrainingError(
                "Cannot compute baseline: no classes found in metrics.",
                model_name=self.model_name,
            )

        baseline = 1.0 / num_classes
        if macro_f1 <= baseline:
            raise TrainingError(
                f"Macro F1 ({macro_f1:.4f}) does not exceed "
                f"random baseline ({baseline:.4f} = 1/{num_classes}). "
                f"Model is not better than random.",
                model_name=self.model_name,
            )
        logger.info(
            "Baseline check passed — macro F1 %.4f > baseline %.4f",
            macro_f1, baseline,
        )

    def _build_metrics(
        self,
        val_metrics: dict,
        test_metrics: dict,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        test_df: pd.DataFrame,
    ) -> dict:
        """Assemble full metrics.json content.

        Args:
            val_metrics: :meth:`evaluate` output for the validation split.
            test_metrics: :meth:`evaluate` output for the test split.
            train_df, val_df, test_df: Splits, used only for row counts.

        Returns:
            Dict with identity fields, a UTC ``trained_at`` timestamp, seed,
            split counts, per-split metrics and known limitations.
        """
        return {
            "model_name": self.model_name,
            "model_version": self.model_version,
            "dataset_version": settings.ai_service_version,
            "trained_at": datetime.now(timezone.utc).isoformat(),
            "seed": self._seed,
            "split_counts": {
                "train": len(train_df),
                "validation": len(val_df),
                "test": len(test_df),
            },
            "metrics": {
                "validation": val_metrics,
                "test": test_metrics,
            },
            "limitations": [
                "Trained on synthetic data only.",
                "6 classes with imbalanced distribution — Create (~2%) and Evaluate (~4%) are rare.",
                "Macro F1 is the primary metric; per-class recall may be low for rare classes.",
                "TF-IDF features do not capture semantic similarity beyond n-gram overlap.",
            ],
        }

    def _build_training_config(
        self,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        test_df: pd.DataFrame,
    ) -> dict:
        """Build training_config.json content with the hyperparameters used by train().

        Args:
            train_df, val_df, test_df: Splits, used only for row counts.
        """
        return {
            "model_name": self.model_name,
            "model_version": self.model_version,
            "dataset_version": settings.ai_service_version,
            "seed": self._seed,
            "split_counts": {
                "train": len(train_df),
                "validation": len(val_df),
                "test": len(test_df),
            },
            "hyperparameters": {
                "tfidf_max_features": self._TFIDF_MAX_FEATURES,
                "ngram_range": list(self._TFIDF_NGRAM_RANGE),
                "sublinear_tf": self._TFIDF_SUBLINEAR_TF,
                "logreg_C": self._LOGREG_C,
                "logreg_solver": self._LOGREG_SOLVER,
                "logreg_max_iter": self._LOGREG_MAX_ITER,
                # Effective loss: lbfgs uses multinomial for >= 3 classes.
                "logreg_multi_class": "multinomial",
            },
            "feature_columns": ["question_text"],
            "target_column": "bloom_level",
            "algorithm": "LogisticRegression(multinomial)",
        }

    def _build_model_card(self, metrics: dict) -> str:
        """Render model_card.md content.

        Args:
            metrics: Output of :meth:`_build_metrics`; missing fields render as "N/A".

        Returns:
            Markdown text of the model card.
        """
        val_metrics = metrics.get("metrics", {}).get("validation", {})
        test_metrics = metrics.get("metrics", {}).get("test", {})
        split_counts = metrics.get("split_counts", {})
        ngram_lo, ngram_hi = self._TFIDF_NGRAM_RANGE
        # The markdown below is deliberately flush-left so the rendered card is
        # not indented; trailing backslashes join the split-count lines.
        card = f"""# Model Card: Bloom Classifier
## Model Details
- **Model Name:** {self.model_name}
- **Model Version:** {self.model_version}
- **Algorithm:** TF-IDF + LogisticRegression (multinomial)
- **Framework:** scikit-learn
- **Trained At:** {metrics.get("trained_at", "N/A")}
- **Seed:** {self._seed}
## Intended Use
Automatically classify questions by Bloom's taxonomy cognitive level.
Used in the Bloom classification endpoint to predict one of 6 levels:
Remember, Understand, Apply, Analyze, Evaluate, Create.
## Training Data
- **Source:** training_bloom_classification.csv (synthetic dataset v2)
- **Split Counts:** train={split_counts.get("train", "N/A")}, \
validation={split_counts.get("validation", "N/A")}, \
test={split_counts.get("test", "N/A")}
- **Feature:** question_text (TF-IDF vectorized, max_features={self._TFIDF_MAX_FEATURES}, ngram_range=({ngram_lo},{ngram_hi}))
- **Target:** bloom_level (6 classes)
## Metrics
### Validation Set
- Macro F1: {val_metrics.get("macro_f1", "N/A")}
- Weighted F1: {val_metrics.get("weighted_f1", "N/A")}
### Test Set
- Macro F1: {test_metrics.get("macro_f1", "N/A")}
- Weighted F1: {test_metrics.get("weighted_f1", "N/A")}
## Known Limitations
- Trained on synthetic data only — performance on real classroom questions is unknown.
- Class imbalance: Create (~2%) and Evaluate (~4%) are rare; recall on these classes may be low.
- TF-IDF features do not capture semantic similarity beyond n-gram overlap.
- Macro F1 is the primary metric; accuracy alone would mask poor performance on rare classes.
## Fallback Behavior
When the model is not loaded or confidence is below the threshold (0.55),
the system falls back to keyword heuristic classification:
define/list → Remember; explain → Understand; calculate/use → Apply;
compare/contrast → Analyze; justify → Evaluate; design → Create.
"""
        return card