# Provenance (scraped from GitHub UI): author DaCrow13,
# commit 225af6a "Deploy to HF Spaces (Clean)"
"""
Pytest configuration and fixtures for behavioral tests.
"""
import pytest
import numpy as np
import joblib
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from hopcroft_skill_classification_tool_competition.config import DATA_PATHS
from hopcroft_skill_classification_tool_competition.features import (
clean_github_text,
get_label_columns,
load_data_from_db
)
@pytest.fixture(scope="session")
def trained_model():
    """Load the trained model for testing.

    Prefers the SMOTE-balanced grid-search model; falls back to the plain
    grid-search baseline. Skips the whole session if neither exists.
    """
    models_dir = Path(DATA_PATHS["models_dir"])
    # Ordered by preference: SMOTE model first, baseline as fallback.
    candidates = [
        models_dir / "random_forest_tfidf_gridsearch_smote.pkl",
        models_dir / "random_forest_tfidf_gridsearch.pkl",
    ]
    for candidate in candidates:
        if candidate.exists():
            return joblib.load(candidate)
    # Same message as before: reports the last (baseline) path tried.
    pytest.skip(f"Model not found at {candidates[-1]}. Please train a model first.")
@pytest.fixture(scope="session")
def tfidf_vectorizer(trained_model):
    """
    Reconstruct the TF-IDF vectorizer used to train the model.

    Note: In a production setting the fitted vectorizer should be saved and
    loaded alongside the model. Here we re-fit one on the training data with
    max_features=1000, matching the feature count the model expects.
    """
    # The vectorizer can only be rebuilt if the extracted features exist.
    features_path = Path(DATA_PATHS["features"])
    if not features_path.exists():
        pytest.skip(f"Features not found at {features_path}. Please run feature extraction first.")
    # Imported lazily so collection doesn't pay the cost when tests are skipped.
    from hopcroft_skill_classification_tool_competition.features import extract_tfidf_features
    try:
        frame = load_data_from_db()
        # max_features=1000 must match the dimensionality used at training time.
        fitted = extract_tfidf_features(frame, max_features=1000)[1]
    except Exception as e:
        pytest.skip(f"Could not load vectorizer: {e}")
    return fitted
@pytest.fixture(scope="session")
def label_names():
    """Get the list of label names from the database."""
    try:
        return get_label_columns(load_data_from_db())
    except Exception as e:
        # No database / schema available: skip rather than fail collection.
        pytest.skip(f"Could not load label names: {e}")
@pytest.fixture
def predict_text(trained_model, tfidf_vectorizer):
    """
    Factory fixture that returns a function to predict skills from raw text.

    Returns:
        Function that takes text and returns predicted label indices
    """
    def _predict(text: str, return_proba: bool = False):
        """
        Predict skills from raw text.

        Args:
            text: Raw GitHub issue text
            return_proba: If True, return probabilities instead of binary predictions

        Returns:
            If return_proba=False: indices of predicted labels (numpy array)
            If return_proba=True: probability matrix (n_samples, n_labels)
        """
        # Clean and vectorize the text into the model's feature space.
        cleaned = clean_github_text(text)
        features = tfidf_vectorizer.transform([cleaned]).toarray()
        if return_proba:
            try:
                # One estimator per label; each predict_proba returns
                # (n_samples, n_classes). Column 1 is P(label present).
                # BUG FIX: the old code did `predict_proba(features)[0][:, 1]`,
                # which collapses to a 1-D array before the 2-D index and
                # raises IndexError — the except clause then silently returned
                # binary predictions instead of probabilities.
                probas = np.array([
                    estimator.predict_proba(features)[:, 1]
                    for estimator in trained_model.estimators_
                ]).T  # -> (n_samples, n_labels)
                return probas
            except Exception:
                # Fallback to binary predictions if probabilities not available
                # (e.g. the model has no per-label estimators_).
                return trained_model.predict(features)
        # Binary multi-label prediction for the single input sample.
        predictions = trained_model.predict(features)[0]
        # Return indices of positive labels
        return np.where(predictions == 1)[0]
    return _predict
@pytest.fixture
def predict_with_labels(predict_text, label_names):
    """
    Factory fixture that returns a function to predict skills with label names.

    Returns:
        Function that takes text and returns list of predicted label names
    """
    def _predict(text: str):
        """
        Predict skill labels from raw text.

        Args:
            text: Raw GitHub issue text

        Returns:
            List of predicted label names
        """
        # Map the predicted label indices back to human-readable names.
        return [label_names[idx] for idx in predict_text(text)]
    return _predict
def pytest_configure(config):
    """Register custom markers."""
    # Data-driven registration: one entry per custom marker.
    marker_lines = (
        "invariance: Tests for invariance (changes should not affect predictions)",
        "directional: Tests for directional expectations (changes should affect predictions predictably)",
        "mft: Minimum Functionality Tests (basic examples with expected outputs)",
        "training: Tests for model training validation (loss, overfitting, devices)",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)