Spaces:
Sleeping
Sleeping
| import re | |
| import numpy as np | |
| import pickle | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.model_selection import train_test_split, cross_val_score | |
| from app.core.config import settings | |
| from typing import List, Optional | |
def load_pipeline(path):
    """Load and return the pickled classifier pipeline stored at *path*.

    Args:
        path: filesystem path (``str`` or path-like) of the pickle file.

    Returns:
        Whatever object was pickled into the file.

    Note:
        ``pickle`` can execute arbitrary code while loading, so only point
        this at trusted artefacts produced by your own training run.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)
class Classifier:
    """Multi-head query classifier for type / category / topic / intent.

    Each query is abbreviation-expanded, then featurized as a sentence
    embedding concatenated with its dense TF-IDF vector. One classifier per
    field (``self.models[field]``) scores those features, and the matching
    label encoder maps the winning class id back to a string label.

    ``predict_with_filter`` turns a prediction into a ChromaDB-style
    ``where`` filter built from ``$and`` / ``$or`` clauses.
    """

    def __init__(
        self,
        tfidf,
        abbreviations,
        master_index,
        le_type,
        le_category,
        le_topic,
        le_intent,
        models=None,
        df=None,
    ):
        """Store preprocessing artefacts, load the embedder, load/train models.

        Args:
            tfidf: TF-IDF vectorizer. If it has no ``vocabulary_`` yet it is
                fitted on the first batch passed to ``get_features``.
            abbreviations: mapping of abbreviation -> expansion, applied to
                every query before featurization.
            master_index: mapping of type label -> dict with ``"categories"``
                and ``"topics"`` lists; used to constrain predictions.
            le_type, le_category, le_topic, le_intent: label encoders whose
                ``inverse_transform`` maps model class ids back to labels.
            models: dict of field name -> trained classifier. When omitted,
                new models are trained from ``df``.
            df: training data indexable by ``"question"``, ``"type"``,
                ``"category"``, ``"topic"`` and ``"intent"``.

        Raises:
            ValueError: if neither ``models`` nor ``df`` is provided.
        """
        self.tfidf = tfidf
        self.abbreviations = abbreviations
        self.master_index = master_index
        self.le_type = le_type
        self.le_category = le_category
        self.le_topic = le_topic
        self.le_intent = le_intent

        # Prefer a locally stored copy of the embedding model; otherwise load
        # it by its hub model id.
        model_path = settings.embeddings_path / "mdbr-leaf-mt"
        if model_path.exists():
            self.embedding_model = SentenceTransformer(str(model_path))
        else:
            self.embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-mt")

        # Prediction thresholds: below these, the field is set to None entirely.
        self.threshold = {
            "type": 0.4,
            "category": 0.4,
            "topic": 0.5,
            "intent": 0.6,
        }
        # Filter thresholds: at or above these, the field is used in the
        # ChromaDB filter. Kept separate so "when to predict" and "when to
        # filter" can be tuned independently.
        self.filter_threshold = {
            "type": 0.65,
            "category": 0.65,
            "topic": 0.70,
        }

        if models is not None:
            self.models = models
        else:
            if df is None:
                raise ValueError("Either provide trained models or provide df to train.")
            self.models = self.train_models(df)

    def _build_filter(self, result):
        """Turn one prediction dict (from ``predict``) into a ``where`` filter.

        Returns ``None`` -- meaning "no filter, do a full scan" -- when the
        type confidence is below ``filter_threshold["type"]``. Otherwise the
        filter always has the shape::

            type AND intent AND (category OR topic)

        A missing intent falls back to ``"detail"``; ``"count"`` also accepts
        ``"detail"`` documents. A category/topic that is missing or below its
        filter threshold is replaced by the literal label ``"general"``.
        """
        # An unreliable type makes the whole filter unreliable.
        if result.get("type_conf", 0) < self.filter_threshold["type"]:
            return None

        # --- Hard AND anchors ---
        intent = result.get("intent") or "detail"
        if intent == "count":
            intent_condition = {"$or": [{"intent": "count"}, {"intent": "detail"}]}
        else:
            intent_condition = {"intent": intent}
        hard_conditions = [{"type": result["type"]}, intent_condition]

        # --- Soft OR hints (category / topic) ---
        # A document only needs to match ONE of these, so docs tagged with a
        # category but not a topic (or vice versa) are not dropped.
        soft_conditions = []
        for field in ("category", "topic"):
            value = result.get(field)
            confident = value and result.get(f"{field}_conf", 0) >= self.filter_threshold[field]
            soft_conditions.append({field: value if confident else "general"})

        return {"$and": hard_conditions + [{"$or": soft_conditions}]}

    def predict_with_filter(self, queries):
        """Classify a query and return the retrieval filter for it.

        Args:
            queries: a single query string or a list of queries. Only the
                first query's prediction is turned into a filter.

        Returns:
            The filter dict from ``_build_filter``, or ``None`` when the type
            prediction is too uncertain to filter on.
        """
        if isinstance(queries, str):
            # A bare string would otherwise be iterated character by character.
            queries = [queries]
        first_result = self.predict(queries)[0]
        return self._build_filter(first_result)

    def expand_abbreviations(self, text):
        """Lower-case and strip *text*, then expand each known abbreviation.

        Matching is whole-word and effectively case-insensitive, because both
        the text and each key are lower-cased. Abbreviations are expanded one
        at a time in dict order, so a later key can match text produced by an
        earlier expansion.
        """
        text = text.lower().strip()
        for abbr, full in self.abbreviations.items():
            pattern = r"\b" + re.escape(abbr.lower()) + r"\b"
            # A callable replacement inserts the expansion literally; a plain
            # string would have backslashes / group references interpreted.
            text = re.sub(pattern, lambda _match, full=full: full, text)
        return text

    def get_features(self, queries):
        """Return the feature matrix for *queries*.

        Each row is the sentence embedding of the abbreviation-expanded query
        followed by its dense TF-IDF vector.
        """
        queries_clean = [self.expand_abbreviations(q) for q in queries]
        embeddings = self.embedding_model.encode(queries_clean, show_progress_bar=False)
        if not hasattr(self.tfidf, "vocabulary_"):
            # Unfitted vectorizer: fit it on this batch (the training path).
            # NOTE(review): at inference time the vectorizer should already be
            # fitted; fitting on a query batch would likely give a feature
            # width that does not match the trained models.
            tfidf_features = self.tfidf.fit_transform(queries_clean).toarray()
        else:
            tfidf_features = self.tfidf.transform(queries_clean).toarray()
        return np.hstack([embeddings, tfidf_features])

    def train_single(self, X, y, field, C=0.01):
        """Fit an L2-regularized, class-balanced logistic regression for *field*.

        Prints train/test accuracy on a stratified 80/20 split plus mean
        5-fold cross-validation score on all data, and returns the model
        fitted on the training split.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=0.2,
            random_state=42,
            stratify=y,
        )
        clf = LogisticRegression(
            C=C,
            penalty="l2",
            solver="lbfgs",
            max_iter=2000,
            class_weight="balanced",
            random_state=42,
        )
        clf.fit(X_train, y_train)
        train_acc = clf.score(X_train, y_train)
        test_acc = clf.score(X_test, y_test)
        cv_scores = cross_val_score(clf, X, y, cv=5)
        print(f"\n{field.upper()}:")
        print(f"Train: {train_acc:.3f} | Test: {test_acc:.3f} | CV: {cv_scores.mean():.3f}")
        return clf

    def train_models(self, df):
        """Train one classifier per field on *df* and return them.

        Returns a dict with keys ``"type"``, ``"category"``, ``"topic"`` and
        ``"intent"``; it is also stored as ``self.models``. An existing
        ``self.models`` dict is updated in place.
        """
        X = self.get_features(df["question"].tolist())
        # When called from __init__, ``self.models`` does not exist yet, so
        # start from a fresh dict instead of failing with AttributeError.
        models = getattr(self, "models", None)
        if models is None:
            models = {}
        for field, C in (("type", 0.01), ("category", 0.005), ("topic", 0.005), ("intent", 0.005)):
            models[field] = self.train_single(X, df[field].values, field, C=C)
        self.models = models
        return models

    @staticmethod
    def _best_label(proba, classes, encoder, allowed=None):
        """Return ``(label, probability)`` of the most probable class.

        When *allowed* is not ``None``, only labels in it are candidates; if
        none of the model's labels are allowed, fall back to the
        unconstrained argmax.
        """
        labels = encoder.inverse_transform(classes)
        if allowed is not None:
            candidates = [(label, prob) for label, prob in zip(labels, proba) if label in allowed]
            if candidates:
                label, prob = max(candidates, key=lambda pair: pair[1])
                return label, float(prob)
        idx = int(np.argmax(proba))
        return labels[idx], float(proba[idx])

    def predict(self, queries: List[str], enforce_constraints=True):
        """Predict type, category, topic and intent for each query.

        Args:
            queries: list of query strings.
            enforce_constraints: when True, category and topic are restricted
                to the labels ``master_index`` lists for the predicted type.

        Returns:
            One dict per query with keys ``question``, ``type``,
            ``type_conf``, ``category``, ``category_conf``, ``topic``,
            ``topic_conf``, ``intent`` and ``intent_conf``. A field whose
            confidence is below ``self.threshold`` is set to ``None`` with
            confidence ``0``.
        """
        fields = ("type", "category", "topic", "intent")
        X = self.get_features(queries)
        # Score the whole batch once per model rather than one row at a time.
        probas = {field: self.models[field].predict_proba(X) for field in fields}

        results = []
        for i, query in enumerate(queries):
            res = {"question": query}

            # ---------- TYPE (never constrained) ----------
            res["type"], res["type_conf"] = self._best_label(
                probas["type"][i], self.models["type"].classes_, self.le_type
            )

            # ---------- CATEGORY / TOPIC ----------
            # Constrained by the *raw* predicted type (before thresholding).
            if enforce_constraints:
                type_entry = self.master_index[res["type"]]
                allowed_categories = set(type_entry["categories"])
                allowed_topics = set(type_entry["topics"])
            else:
                allowed_categories = allowed_topics = None
            res["category"], res["category_conf"] = self._best_label(
                probas["category"][i],
                self.models["category"].classes_,
                self.le_category,
                allowed_categories,
            )
            res["topic"], res["topic_conf"] = self._best_label(
                probas["topic"][i],
                self.models["topic"].classes_,
                self.le_topic,
                allowed_topics,
            )

            # ---------- INTENT (never constrained) ----------
            res["intent"], res["intent_conf"] = self._best_label(
                probas["intent"][i], self.models["intent"].classes_, self.le_intent
            )

            # Drop predictions whose confidence falls below their threshold.
            for field in fields:
                if res[f"{field}_conf"] < self.threshold[field]:
                    res[field] = None
                    res[f"{field}_conf"] = 0

            print("=" * 50)
            print(query)
            print(f"Type: {res['type']}, {res['type_conf']}")
            print(f"Category: {res['category']}, {res['category_conf']}")
            print(f"Topic: {res['topic']}, {res['topic_conf']}")
            print(f"Intent: {res['intent']}, {res['intent_conf']}")
            print("=" * 50)
            results.append(res)
        return results
# Build the shared classifier once, at import time, from the pickled
# training artefacts. Every name bound here is module-level and may be
# imported elsewhere, so keep the names stable.
classifier_path = settings.classifier_path / "chatbot_classifier.pkl"
pipeline = load_pipeline(classifier_path)

# Pull each artefact out of the pipeline dict (a missing key raises KeyError).
(
    models,
    tfidf,
    le_type,
    le_category,
    le_topic,
    le_intent,
    MASTER_INDEX,
    ABBREVIATIONS,
) = (
    pipeline[key]
    for key in (
        "models",
        "tfidf",
        "le_type",
        "le_category",
        "le_topic",
        "le_intent",
        "MASTER_INDEX",
        "ABBREVIATIONS",
    )
)

# Pre-trained models are supplied, so no training happens here.
clf = Classifier(
    tfidf=tfidf,
    abbreviations=ABBREVIATIONS,
    master_index=MASTER_INDEX,
    le_type=le_type,
    le_category=le_category,
    le_topic=le_topic,
    le_intent=le_intent,
    models=models,
)