Upload 13 files
- .gitkeep +0 -0
- Dockerfile +32 -0
- Procfile +1 -0
- README.md +12 -13
- app.py +255 -0
- hybrid_interest_classifier.pkl +3 -0
- hybrid_interest_classifier.py +501 -0
- hybrid_model_debugger.py +152 -0
- requirements.txt +10 -0
- runtime.txt +1 -0
- space.yaml +2 -0
- survey_interest_dataset_enhanced.csv +0 -0
- utils.py +67 -0
.gitkeep
ADDED
File without changes
Dockerfile
ADDED
@@ -0,0 +1,32 @@
+# Use official Python image
+FROM python:3.9-slim
+
+# Avoid prompts from pip
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies (for torch, transformers, etc.)
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    libglib2.0-0 \
+    libsm6 \
+    libxext6 \
+    libxrender-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy and install Python dependencies
+COPY requirements.txt .
+RUN pip install --upgrade pip
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy app files
+COPY . .
+
+# Expose port for FastAPI
+EXPOSE 8000
+
+# Default command
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
Procfile
ADDED
@@ -0,0 +1 @@
+web: uvicorn app:app --host=0.0.0.0 --port=${PORT}
README.md
CHANGED
@@ -1,13 +1,12 @@
(removed: the original 13-line README stub; only its last line is rendered in this view)
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Hybrid Interest Classifier API
+
+This Hugging Face Space hosts a FastAPI-based machine learning API that predicts user interests (Music, Food, Travel, etc.) from free-text input. It uses a hybrid model combining TF-IDF + BERT zero-shot classification.
+
+Try it by sending a POST request to `/predict` with:
+```json
+{
+  "text": "I love hiking and coding!",
+  "alpha": 0.6,
+  "threshold": 0.5,
+  "return_scores": true
+}
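For reference, a minimal client sketch for the `/predict` endpoint described in the README. The base URL is an assumption (it matches the local `uvicorn app:app --port 8000` command in the Dockerfile and Procfile; a deployed Space would use its own URL):

```python
# Minimal sketch of calling /predict with the requests library (base URL is assumed).
import requests

BASE_URL = "http://localhost:8000"  # assumed local run of `uvicorn app:app --port 8000`

payload = {
    "text": "I love hiking and coding!",
    "alpha": 0.6,           # weight for the TF-IDF model (1 - alpha for BERT)
    "threshold": 0.5,       # minimum score for a category to be returned
    "return_scores": True,  # also return per-category scores
}

response = requests.post(f"{BASE_URL}/predict", json=payload, timeout=60)
response.raise_for_status()
result = response.json()

print(result["labels"])          # predicted interest categories
print(result.get("scores", {}))  # per-category scores when return_scores is true
```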
app.py
ADDED
@@ -0,0 +1,255 @@
+# app.py
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+import pickle
+import logging
+import numpy as np
+from typing import List, Optional, Dict, Any, Union
+import sys
+import os
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+# Initialize FastAPI app
+app = FastAPI()
+
+# Allow CORS
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Define InterestClassifier class here (or import it if available)
+class InterestClassifier:
+    """
+    Hybrid Interest Classification model that combines TF-IDF with BERT zero-shot classification
+    This is a simplified version for compatibility with the API
+    """
+    def __init__(self, model_path=None, alpha=0.6, threshold=0.5):
+        self.alpha = alpha
+        self.threshold = threshold
+        self.tfidf_pipeline = None
+        self.mlb = None
+        self.bert_classifier = None
+
+        if model_path:
+            self.load_model(model_path)
+
+    def load_model(self, path):
+        """Load a saved model from disk"""
+        try:
+            with open(path, 'rb') as f:
+                components = pickle.load(f)
+
+            self.tfidf_pipeline = components.get('tfidf_pipeline')
+            self.mlb = components.get('mlb')
+            self.alpha = components.get('alpha', 0.6)
+            self.threshold = components.get('threshold', 0.5)
+
+            logger.info(f"Model components loaded from {path}")
+            logger.info(f"Model components: {list(components.keys())}")
+
+        except Exception as e:
+            logger.error(f"Failed to load model: {e}")
+            raise
+
+    def predict(self, texts, alpha=None, threshold=None, return_scores=False):
+        """Predict method adapted for the API"""
+        if not isinstance(texts, list):
+            texts = [texts]
+
+        # Use instance values if not provided
+        alpha = alpha if alpha is not None else self.alpha
+        threshold = threshold if threshold is not None else self.threshold
+
+        if self.tfidf_pipeline is None:
+            raise ValueError("TF-IDF pipeline not loaded. Cannot make predictions.")
+
+        # Get predictions from TF-IDF pipeline
+        text = texts[0]  # Just use the first text for simplicity
+
+        # Get raw prediction probabilities
+        y_proba = self.tfidf_pipeline.predict_proba([text])
+
+        # Convert to dictionary of label -> score
+        scores = {}
+        for i, label in enumerate(self.mlb.classes_):
+            # For MultiOutputClassifier, each element of y_proba is a list of arrays
+            # Each array is for one label and has 2 values: [prob_for_0, prob_for_1]
+            scores[label] = y_proba[i][0][1]  # Get probability of positive class
+
+        # Apply threshold to get labels
+        labels = [label for label, score in scores.items() if score >= threshold]
+
+        if return_scores:
+            # Sort scores for easier interpretation
+            sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
+
+            return {
+                'labels': labels,
+                'scores': scores,
+                'sorted_scores': sorted_scores,
+                'alpha': alpha,
+                'threshold': threshold
+            }
+
+        return labels
+
+# Load the hybrid classifier
+MODEL_PATH = "hybrid_interest_classifier.pkl"
+hybrid_classifier = None
+
+try:
+    logger.info(f"Loading hybrid model from {MODEL_PATH}")
+    # Create an instance of our classifier and load the model
+    hybrid_classifier = InterestClassifier(model_path=MODEL_PATH)
+    logger.info("Hybrid model loaded successfully")
+except Exception as e:
+    logger.error(f"Failed to load hybrid model: {e}")
+
+# Define keyword-based interest detection as fallback
+def keyword_interests(text):
+    """
+    Determine interests using keyword matching as a fallback
+    """
+    text = text.lower()
+    interests = []
+
+    if any(word in text for word in ['music', 'band', 'concert', 'sing', 'guitar', 'song']):
+        interests.append('Music')
+
+    if any(word in text for word in ['food', 'cook', 'recipe', 'restaurant', 'eat', 'cuisine']):
+        interests.append('Food')
+
+    if any(word in text for word in ['sport', 'gym', 'fitness', 'exercise', 'workout', 'run']):
+        interests.append('Sports')
+
+    if any(word in text for word in ['art', 'paint', 'draw', 'gallery', 'museum', 'exhibition']):
+        interests.append('Arts')
+
+    if any(word in text for word in ['tech', 'code', 'software', 'computer', 'programming']):
+        interests.append('Technology')
+
+    if any(word in text for word in ['learn', 'study', 'course', 'book', 'read', 'class']):
+        interests.append('Education')
+
+    if any(word in text for word in ['travel', 'trip', 'journey', 'explore', 'hike', 'tourism']):
+        interests.append('Travel')
+
+    if not interests:
+        interests.append('No specific interests detected')
+
+    return interests
+
+# Pydantic models
+class PredictionRequest(BaseModel):
+    text: str
+    alpha: Optional[float] = None
+    threshold: Optional[float] = None
+    return_scores: Optional[bool] = False
+
+@app.get("/")
+async def root():
+    """Root endpoint to check if API is running"""
+    return {
+        "status": "online",
+        "message": "Hybrid Interest Classifier API is running",
+        "model_loaded": hybrid_classifier is not None
+    }
+
+@app.get("/health")
+async def health():
+    """Health check endpoint"""
+    return {"status": "healthy", "model_loaded": hybrid_classifier is not None}
+
+@app.post("/predict")
+async def predict(request: PredictionRequest):
+    """
+    Predict interests based on text input
+    """
+    text = request.text
+    alpha = request.alpha
+    threshold = request.threshold
+    return_scores = request.return_scores
+
+    logger.info(f"Prediction request: text='{text[:50]}...', alpha={alpha}, threshold={threshold}, return_scores={return_scores}")
+
+    if not text or text.strip() == "":
+        return {"labels": ["No text provided"], "text": text}
+
+    if hybrid_classifier is None:
+        logger.warning("Using fallback keyword matching (model not loaded)")
+        return {"labels": keyword_interests(text), "text": text}
+
+    try:
+        # Prepare prediction parameters
+        kwargs = {}
+        if alpha is not None:
+            kwargs['alpha'] = alpha
+        if threshold is not None:
+            kwargs['threshold'] = threshold
+        if return_scores:
+            kwargs['return_scores'] = True
+
+        # Log the call we're about to make
+        logger.info(f"Calling hybrid_classifier.predict([{text[:20]}...], {kwargs})")
+
+        # Make prediction
+        prediction = None
+        try:
+            # Call predict with the text and kwargs
+            prediction = hybrid_classifier.predict([text], **kwargs)
+        except TypeError as e:
+            # If that fails, try without optional parameters
+            logger.warning(f"TypeError with kwargs: {e}. Trying without kwargs.")
+            prediction = hybrid_classifier.predict([text])
+
+        logger.info(f"Raw prediction: {prediction}")
+
+        # Process the prediction result
+        labels = []
+        scores = {}
+
+        # Handle dictionary return type (likely with return_scores=True)
+        if isinstance(prediction, dict):
+            if 'labels' in prediction:
+                labels = prediction['labels']
+
+            if return_scores and 'sorted_scores' in prediction:
+                scores = dict(prediction['sorted_scores'])
+            elif return_scores and 'scores' in prediction:
+                scores = prediction['scores']
+
+        # Handle list return type
+        elif isinstance(prediction, list):
+            labels = prediction
+
+        # If we still have no labels, use keyword matching
+        if not labels:
+            logger.warning("No labels detected, using fallback")
+            labels = keyword_interests(text)
+
+        # Construct response
+        response = {"labels": labels, "text": text}
+        if return_scores and scores:
+            response["scores"] = scores
+
+        logger.info(f"Final response: {response}")
+        return response
+
+    except Exception as e:
+        logger.error(f"Error during prediction: {e}", exc_info=True)
+        return {"labels": keyword_interests(text), "text": text, "error": str(e)}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
hybrid_interest_classifier.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca06ea5cd77ee26ca71a637e726cc19f1f5ead5593f5e09a192b091de99df95e
+size 296903
hybrid_interest_classifier.py
ADDED
@@ -0,0 +1,501 @@
+import pandas as pd
+import numpy as np
+import re
+import os
+import pickle
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.preprocessing import MultiLabelBinarizer
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import Pipeline
+from sklearn.model_selection import train_test_split
+from transformers import pipeline
+import torch
+import logging
+import time
+from typing import List, Dict, Tuple, Union, Optional
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Define interest categories
+INTEREST_CATEGORIES = ["Music", "Food", "Sports", "Technology", "Arts", "Travel", "Education"]
+
+class InterestClassifier:
+    """
+    Hybrid Interest Classification model that combines TF-IDF with BERT zero-shot classification
+    """
+    def __init__(self,
+                 model_path: Optional[str] = None,
+                 alpha: float = 0.6,
+                 threshold: float = 0.5,
+                 bert_model_name: str = 'facebook/bart-large-mnli',
+                 use_gpu: bool = torch.cuda.is_available()):
+        """
+        Initialize the hybrid classifier
+
+        Args:
+            model_path: Path to a saved model (if None, a new model will be created)
+            alpha: Weight for TF-IDF model (1-alpha for BERT)
+            threshold: Classification threshold for final predictions
+            bert_model_name: Name of the BERT model to use
+            use_gpu: Whether to use GPU for BERT inference
+        """
+        self.alpha = alpha
+        self.threshold = threshold
+        self.bert_model_name = bert_model_name
+        self.use_gpu = use_gpu
+
+        # Initialize models as None
+        self.tfidf_pipeline = None
+        self.mlb = None
+        self.bert_classifier = None
+
+        # Load the model if path is provided
+        if model_path and os.path.exists(model_path):
+            self.load_model(model_path)
+
+        # Initialize BERT model
+        self._init_bert_classifier()
+
+    def _improved_preprocess_text(self, text: str) -> str:
+        """
+        Enhanced text preprocessing that better preserves domain-specific indicators
+
+        Args:
+            text: Input text to preprocess
+
+        Returns:
+            Preprocessed text
+        """
+        # Handle potential NaN values
+        if pd.isna(text):
+            return ""
+
+        # Convert to lowercase
+        text = text.lower()
+
+        # Remove special characters while preserving important separators
+        text = re.sub(r'[^\w\s|-]', ' ', text)
+
+        # Replace multiple spaces with a single space
+        text = re.sub(r'\s+', ' ', text)
+
+        # Define domain terms dictionary
+        domain_terms = {
+            'music': ['music', 'guitar', 'band', 'concert', 'gig', 'sing', 'song', 'play music', 'musician'],
+            'food': ['food', 'cook', 'cuisine', 'recipe', 'restaurant', 'eat', 'culinary', 'bake', 'chef'],
+            'sports': ['sport', 'run', 'gym', 'fitness', 'workout', 'exercise', 'athletic', 'training'],
+            'arts': ['art', 'paint', 'draw', 'museum', 'gallery', 'exhibit', 'creative', 'design'],
+            'technology': ['tech', 'code', 'program', 'software', 'developer', 'computer', 'app', 'digital'],
+            'education': ['education', 'learn', 'course', 'class', 'study', 'book', 'read', 'academic'],
+            'travel': ['travel', 'trip', 'hike', 'explore', 'tour', 'visit', 'journey', 'destination']
+        }
+
+        # Check for domain terms and emphasize them
+        modified_text = text
+        for category, terms in domain_terms.items():
+            for term in terms:
+                if term in text:
+                    # Add the category name explicitly if a related term is found
+                    modified_text += f" {category} {category} {term} {term}"
+
+        # Split on common separators but preserve the important phrases
+        parts = []
+        for part in re.split(r'\s*\|\s*', modified_text):
+            # Remove numbers (but keep words with numbers like "web3")
+            part = re.sub(r'\b\d+\b', '', part)
+            parts.append(part)
+
+        # Define a more focused stopwords list
+        core_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'the', 'a', 'an', 'and', 'but',
+                          'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with',
+                          'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after',
+                          'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over',
+                          'under', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were'}
+
+        # Process each part and filter stopwords
+        processed_parts = []
+        for part in parts:
+            words = part.split()
+            filtered_words = [word for word in words if word not in core_stopwords]
+
+            if filtered_words:
+                processed_parts.append(' '.join(filtered_words))
+
+        # Join the processed parts back
+        processed_text = ' '.join(processed_parts)
+
+        return processed_text.strip()
+
+    def _init_bert_classifier(self):
+        """Initialize the BERT zero-shot classifier"""
+        try:
+            logger.info(f"Initializing BERT zero-shot classifier with model: {self.bert_model_name}")
+            device = 0 if self.use_gpu and torch.cuda.is_available() else -1
+            self.bert_classifier = pipeline('zero-shot-classification',
+                                            model=self.bert_model_name,
+                                            device=device)
+            logger.info("BERT classifier successfully initialized")
+        except Exception as e:
+            logger.error(f"Failed to initialize BERT classifier: {e}")
+            logger.warning("Proceeding without BERT - will use TF-IDF only")
+            self.bert_classifier = None
+
+    def train(self,
+              df: pd.DataFrame,
+              text_column: str = 'survey_answer',
+              labels_column: str = 'labels_list',
+              test_size: float = 0.2):
+        """
+        Train the TF-IDF + Logistic Regression model
+
+        Args:
+            df: DataFrame containing survey responses and labels
+            text_column: Column name containing the survey responses
+            labels_column: Column name containing the labels
+            test_size: Proportion of data to use for testing
+
+        Returns:
+            Evaluation metrics on test set
+        """
+        logger.info("Starting model training...")
+
+        # Prepare labels
+        if isinstance(df[labels_column].iloc[0], str):
+            logger.info("Converting labels from string to list...")
+            # Convert string representation of lists to actual lists
+            df[labels_column] = df[labels_column].str.strip('[]').str.split(',')
+            # Clean up any extra quotes or spaces
+            df[labels_column] = df[labels_column].apply(lambda x: [item.strip().strip("'\"") for item in x])
+
+        # Preprocess text
+        logger.info("Preprocessing text data...")
+        df['processed_text'] = df[text_column].apply(self._improved_preprocess_text)
+
+        # Initialize MultiLabelBinarizer
+        self.mlb = MultiLabelBinarizer(classes=INTEREST_CATEGORIES)
+        y = self.mlb.fit_transform(df[labels_column])
+        logger.info(f"Target shape: {y.shape}")
+
+        # Split data
+        X_train, X_test, y_train, y_test = train_test_split(
+            df['processed_text'], y, test_size=test_size, random_state=42, shuffle=True
+        )
+        logger.info(f"Training set: {X_train.shape[0]} samples, Test set: {X_test.shape[0]} samples")
+
+        # Create TF-IDF pipeline
+        logger.info("Creating and training TF-IDF pipeline...")
+        tfidf_vectorizer = TfidfVectorizer(
+            max_features=3000,
+            min_df=2,
+            max_df=0.9,
+            ngram_range=(1, 3),
+            sublinear_tf=True
+        )
+
+        lr_clf = LogisticRegression(
+            C=0.5,
+            max_iter=1000,
+            class_weight='balanced',
+            solver='liblinear',
+            penalty='l2'
+        )
+
+        multi_lr = MultiOutputClassifier(lr_clf)
+
+        self.tfidf_pipeline = Pipeline([
+            ('tfidf', tfidf_vectorizer),
+            ('classifier', multi_lr)
+        ])
+
+        # Train the pipeline
+        self.tfidf_pipeline.fit(X_train, y_train)
+        logger.info("TF-IDF pipeline trained successfully")
+
+        # Evaluate on test set
+        logger.info("Evaluating model on test set...")
+        y_pred = self.tfidf_pipeline.predict(X_test)
+
+        # Calculate metrics
+        from sklearn.metrics import hamming_loss, f1_score, precision_score, recall_score
+        h_loss = hamming_loss(y_test, y_pred)
+        micro_f1 = f1_score(y_test, y_pred, average='micro')
+        macro_f1 = f1_score(y_test, y_pred, average='macro')
+
+        logger.info(f"Hamming Loss: {h_loss:.4f}")
+        logger.info(f"Micro F1 Score: {micro_f1:.4f}")
+        logger.info(f"Macro F1 Score: {macro_f1:.4f}")
+
+        return {
+            'hamming_loss': h_loss,
+            'micro_f1': micro_f1,
+            'macro_f1': macro_f1
+        }
+
+    def get_tfidf_predictions(self, text: str) -> Dict[str, float]:
+        """
+        Get predictions from TF-IDF model with confidence scores
+
+        Args:
+            text: The input text to classify
+
+        Returns:
+            Dictionary of label -> score
+        """
+        if self.tfidf_pipeline is None:
+            raise ValueError("TF-IDF model is not trained yet. Call train() first.")
+
+        # Preprocess text
+        processed_text = self._improved_preprocess_text(text)
+
+        # Get raw prediction probabilities
+        y_proba = self.tfidf_pipeline.predict_proba([processed_text])
+
+        # Convert to dictionary of label -> score
+        scores = {}
+        for i, label in enumerate(self.mlb.classes_):
+            # For MultiOutputClassifier, each element of y_proba is a list of arrays
+            # Each array is for one label and has 2 values: [prob_for_0, prob_for_1]
+            scores[label] = y_proba[i][0][1]  # Get probability of positive class
+
+        return scores
+
+    def get_bert_predictions(self, text: str) -> Dict[str, float]:
+        """
+        Get predictions from BERT model
+
+        Args:
+            text: The input text to classify
+
+        Returns:
+            Dictionary of label -> score
+        """
+        if self.bert_classifier is None:
+            logger.warning("BERT classifier is not available, returning empty scores")
+            return {label: 0.0 for label in INTEREST_CATEGORIES}
+
+        try:
+            # Use the BERT zero-shot classifier
+            result = self.bert_classifier(text, INTEREST_CATEGORIES, multi_label=True)
+
+            # Convert to dictionary of label -> score
+            scores = dict(zip(result['labels'], result['scores']))
+
+            # Ensure all categories are present (BERT may return in different order)
+            for category in INTEREST_CATEGORIES:
+                if category not in scores:
+                    scores[category] = 0.0
+
+            return scores
+
+        except Exception as e:
+            logger.error(f"Error in BERT prediction: {e}")
+            return {label: 0.0 for label in INTEREST_CATEGORIES}
+
+    def predict(self,
+                text: str,
+                alpha: Optional[float] = None,
+                threshold: Optional[float] = None,
+                return_scores: bool = False) -> Union[List[str], Dict]:
+        """
+        Combine TF-IDF and BERT predictions using weighted average
+
+        Args:
+            text: The input text to classify
+            alpha: Weight for TF-IDF predictions (1-alpha for BERT), uses self.alpha if None
+            threshold: Threshold for classification, uses self.threshold if None
+            return_scores: Whether to return scores along with labels
+
+        Returns:
+            Either a list of predicted labels or a dictionary with labels and scores
+        """
+        if self.tfidf_pipeline is None:
+            raise ValueError("Model is not trained yet. Call train() first.")
+
+        # Use instance values if not provided
+        alpha = alpha if alpha is not None else self.alpha
+        threshold = threshold if threshold is not None else self.threshold
+
+        # Time the predictions
+        start_time = time.time()
+
+        # Get TF-IDF predictions
+        tfidf_scores = self.get_tfidf_predictions(text)
+        tfidf_time = time.time() - start_time
+
+        # Get BERT predictions if available
+        bert_time_start = time.time()
+        if self.bert_classifier is not None:
+            bert_scores = self.get_bert_predictions(text)
+            use_bert = True
+        else:
+            bert_scores = {category: 0.0 for category in INTEREST_CATEGORIES}
+            use_bert = False
+            logger.warning("BERT classifier not available, using TF-IDF only")
+        bert_time = time.time() - bert_time_start
+
+        # Combine predictions
+        combined_scores = {}
+        final_labels = []
+
+        for category in INTEREST_CATEGORIES:
+            # Get scores from both models
+            tfidf_score = tfidf_scores.get(category, 0.0)
+            bert_score = bert_scores.get(category, 0.0)
+
+            # Weighted average (if using BERT)
+            if use_bert:
+                final_score = (alpha * tfidf_score) + ((1 - alpha) * bert_score)
+            else:
+                final_score = tfidf_score
+
+            combined_scores[category] = final_score
+
+            # Apply threshold
+            if final_score >= threshold:
+                final_labels.append(category)
+
+        total_time = time.time() - start_time
+
+        if return_scores:
+            # Sort scores for easier interpretation
+            sorted_scores = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
+
+            return {
+                'labels': final_labels,
+                'scores': combined_scores,
+                'sorted_scores': sorted_scores,
+                'tfidf_scores': tfidf_scores,
+                'bert_scores': bert_scores,
+                'timing': {
+                    'tfidf': tfidf_time,
+                    'bert': bert_time,
+                    'total': total_time
+                },
+                'alpha': alpha,
+                'threshold': threshold,
+                'using_bert': use_bert
+            }
+
+        return final_labels
+
+    def save_model(self, path: str = "hybrid_interest_classifier.pkl"):
+        """
+        Save the model to disk
+
+        Args:
+            path: Path to save the model
+        """
+        if self.tfidf_pipeline is None:
+            raise ValueError("Model is not trained yet. Call train() first.")
+
+        # Note: We only save the TF-IDF pipeline and MLBinarizer
+        # BERT will be re-initialized on load
+        components = {
+            'tfidf_pipeline': self.tfidf_pipeline,
+            'mlb': self.mlb,
+            'alpha': self.alpha,
+            'threshold': self.threshold,
+            'bert_model_name': self.bert_model_name,
+            'interest_categories': INTEREST_CATEGORIES,
+            'version': '1.0'
+        }
+
+        with open(path, 'wb') as f:
+            pickle.dump(components, f)
+
+        logger.info(f"Model saved to {path}")
+
+    def load_model(self, path: str):
+        """
+        Load a saved model from disk
+
+        Args:
+            path: Path to the saved model
+        """
+        try:
+            with open(path, 'rb') as f:
+                components = pickle.load(f)
+
+            self.tfidf_pipeline = components['tfidf_pipeline']
+            self.mlb = components['mlb']
+            self.alpha = components.get('alpha', 0.6)
+            self.threshold = components.get('threshold', 0.5)
+            self.bert_model_name = components.get('bert_model_name', 'facebook/bart-large-mnli')
+
+            logger.info(f"Model loaded from {path}")
+
+            # Re-initialize BERT classifier
+            self._init_bert_classifier()
+
+        except Exception as e:
+            logger.error(f"Failed to load model: {e}")
+            raise
+
+
+# Example usage
+def main():
+    try:
+        # Load dataset
+        logger.info("Loading dataset: survey_interest_dataset_enhanced.csv")
+        df = pd.read_csv('survey_interest_dataset_enhanced.csv')
+
+        # Convert labels_list if it's a string representation
+        if 'labels_list' in df.columns and isinstance(df['labels_list'].iloc[0], str):
+            logger.info("Converting labels_list from string to list...")
+            df['labels_list'] = df['labels_list'].str.strip('[]').str.split(',')
+            df['labels_list'] = df['labels_list'].apply(lambda x: [item.strip().strip("'\"") for item in x])
+
+        # Initialize classifier
+        logger.info("Initializing classifier with alpha=0.6, threshold=0.5")
+        classifier = InterestClassifier(alpha=0.6, threshold=0.5)
+
+        # Train the model
+        logger.info("Training the model...")
+        metrics = classifier.train(df)
+        logger.info(f"Training metrics: {metrics}")
+
+        # Save the model
+        model_path = "hybrid_interest_classifier.pkl"
+        logger.info(f"Saving model to {model_path}")
+        classifier.save_model(model_path)
+
+        # Test on some examples
+        test_examples = [
+            "I love hiking in the mountains and trying local foods wherever I travel.",
+            "I'm a software developer who plays guitar in a band on weekends.",
+            "I spend most of my time reading books and attending online courses.",
+            "I enjoy painting landscapes and visiting art museums when I travel."
+        ]
+
+        logger.info("Testing model on example inputs...")
+        for example in test_examples:
+            result = classifier.predict(example, return_scores=True)
+            logger.info(f"\nExample: '{example}'")
+            logger.info(f"Predicted interests: {result['labels']}")
+            logger.info("Top interests by score:")
+            for category, score in result['sorted_scores'][:3]:
+                logger.info(f"  {category}: {score:.4f}")
+
+        # Fine-tuning alpha parameter demo
+        logger.info("\nFine-tuning alpha parameter:")
+        example = "I work as a software developer and enjoy hiking on weekends"
+        for alpha in [0.3, 0.5, 0.7, 0.9]:
+            result = classifier.predict(example, alpha=alpha, return_scores=True)
+            logger.info(f"\nAlpha = {alpha} (TF-IDF weight: {alpha}, BERT weight: {1-alpha})")
+            logger.info(f"Predicted interests: {result['labels']}")
+            logger.info("Top 3 scores:")
+            for category, score in result['sorted_scores'][:3]:
+                logger.info(f"  {category}: {score:.4f}")
+
+        logger.info("Model training and evaluation completed successfully")
+
+    except Exception as e:
+        logger.error(f"Error in main function: {e}", exc_info=True)
+        raise
+
+
+if __name__ == "__main__":
+    main()
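The heart of `predict()` above is the per-category blend `alpha * tfidf_score + (1 - alpha) * bert_score` followed by thresholding. A standalone sketch of just that combination step, with made-up scores purely for illustration:

```python
# Illustrative combination of per-category scores (the score values are made up).
alpha, threshold = 0.6, 0.5
tfidf_scores = {"Travel": 0.72, "Technology": 0.31, "Food": 0.55}
bert_scores  = {"Travel": 0.65, "Technology": 0.80, "Food": 0.20}

combined = {
    cat: alpha * tfidf_scores[cat] + (1 - alpha) * bert_scores[cat]
    for cat in tfidf_scores
}
labels = [cat for cat, score in combined.items() if score >= threshold]
# combined -> {'Travel': 0.692, 'Technology': 0.506, 'Food': 0.41}
# labels   -> ['Travel', 'Technology']
```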
hybrid_model_debugger.py
ADDED
@@ -0,0 +1,152 @@
+# hybrid_model_debugger.py
+import pickle
+import numpy as np
+import sys
+import traceback
+
+def debug_model(model_path, test_text):
+    """
+    Debugs the hybrid model by running a detailed test prediction and inspecting the outputs
+    at each stage of the process
+    """
+    print(f"Loading model from {model_path}...")
+
+    try:
+        # Load model
+        with open(model_path, "rb") as f:
+            model_data = pickle.load(f)
+
+        print(f"Model loaded successfully. Type: {type(model_data)}")
+
+        # Determine the type of model
+        if isinstance(model_data, dict):
+            print("\nModel is a dictionary with keys:")
+            for key in model_data:
+                print(f"  - {key} ({type(model_data[key])})")
+
+            # Look for classifier in the dictionary
+            classifier = None
+            if 'model' in model_data:
+                classifier = model_data['model']
+                print("Using 'model' key as classifier")
+            elif 'classifier' in model_data:
+                classifier = model_data['classifier']
+                print("Using 'classifier' key as classifier")
+            else:
+                # Try to find a component with predict method
+                for key, component in model_data.items():
+                    if hasattr(component, 'predict'):
+                        classifier = component
+                        print(f"Using '{key}' as classifier (has predict method)")
+                        break
+        else:
+            # Direct classifier
+            classifier = model_data
+            print("Model is a direct classifier object")
+
+        if not classifier:
+            print("ERROR: Could not identify a classifier component in the model")
+            return
+
+        # Check for mlb
+        mlb = None
+        if hasattr(classifier, 'mlb'):
+            mlb = classifier.mlb
+            print("\nFound MultiLabelBinarizer on classifier")
+            if hasattr(mlb, 'classes_'):
+                print(f"Available classes: {mlb.classes_}")
+            else:
+                print("WARNING: MultiLabelBinarizer has no classes_ attribute")
+        else:
+            print("\nNo MultiLabelBinarizer found on classifier")
+
+            # Check if mlb is in the dictionary
+            if isinstance(model_data, dict) and 'mlb' in model_data:
+                mlb = model_data['mlb']
+                print("Found MultiLabelBinarizer in model dictionary")
+                if hasattr(mlb, 'classes_'):
+                    print(f"Available classes: {mlb.classes_}")
+                else:
+                    print("WARNING: MultiLabelBinarizer has no classes_ attribute")
+
+        # Check for alpha parameter
+        alpha = getattr(classifier, 'alpha', None)
+        print(f"\nAlpha parameter: {alpha}")
+
+        # Check for threshold parameter
+        threshold = getattr(classifier, 'threshold', None)
+        print(f"Threshold parameter: {threshold}")
+
+        # Try making a prediction
+        print(f"\nTesting prediction with text: '{test_text}'")
+
+        # Try different prediction approaches
+        approaches = [
+            ("Standard prediction with text as list", lambda: classifier.predict([test_text])),
+            ("With specific alpha and threshold", lambda: classifier.predict([test_text], alpha=0.6, threshold=0.4)),
+            ("With return_scores=True", lambda: classifier.predict([test_text], return_scores=True)),
+            ("All parameters", lambda: classifier.predict([test_text], alpha=0.6, threshold=0.4, return_scores=True))
+        ]
+
+        for description, predict_func in approaches:
+            print(f"\n--- {description} ---")
+            try:
+                result = predict_func()
+                print(f"Result type: {type(result)}")
+                print(f"Result value: {result}")
+
+                # If it's a numpy array, try to interpret it
+                if isinstance(result, np.ndarray):
+                    print(f"Array shape: {result.shape}")
+                    print(f"Array contents: {result}")
+
+                    if mlb and hasattr(mlb, 'classes_'):
+                        try:
+                            # Check if it's a binary array
+                            if len(result.shape) == 2:  # First dim is samples, second is classes
+                                labels = mlb.classes_[result[0].astype(bool)].tolist()
+                                print(f"Converted to labels: {labels}")
+                        except Exception as e:
+                            print(f"Error converting to labels: {e}")
+
+                # If it's a list, check the first item
+                elif isinstance(result, list) and len(result) > 0:
+                    print(f"First item type: {type(result[0])}")
+                    print(f"First item value: {result[0]}")
+
+                # If it's a dictionary, check its structure
+                elif isinstance(result, dict):
+                    print("Dictionary keys:")
+                    for key in result:
+                        value = result[key]
+                        print(f"  - {key} ({type(value)})")
+
+                        # Show a sample of the value
+                        if isinstance(value, (list, tuple)) and len(value) > 0:
+                            print(f"    Sample: {value[:3]}...")
+                        elif isinstance(value, dict) and len(value) > 0:
+                            sample_keys = list(value.keys())[:3]
+                            print(f"    Sample keys: {sample_keys}...")
+                        else:
+                            print(f"    Value: {value}")
+
+            except Exception as e:
+                print(f"Error during prediction: {e}")
+                print(traceback.format_exc())
+
+        print("\nDebugging complete")
+
+    except Exception as e:
+        print(f"Error loading or processing model: {e}")
+        print(traceback.format_exc())
+
+if __name__ == "__main__":
+    model_path = r"C:\Users\tueyc\CMKL Year 1\nomad_sync_app\backend\hybrid_interest_classifier.pkl"
+    test_text = "I hike mountains and explore cultures while traveling. I also love cooking new recipes."
+
+    if len(sys.argv) > 1:
+        model_path = sys.argv[1]
+    if len(sys.argv) > 2:
+        test_text = sys.argv[2]
+
+    debug_model(model_path, test_text)
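The debugger is meant to be run from the command line, for example `python hybrid_model_debugger.py hybrid_interest_classifier.pkl "I play guitar and cook Thai food"` (both arguments are illustrative; without them it falls back to the hardcoded defaults in the `__main__` block above).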
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+fastapi==0.109.2
+uvicorn==0.24.0
+pydantic==2.5.2
+scikit-learn==1.3.0
+numpy==1.25.2
+scipy==1.11.3
+pandas==2.1.0
+torch==2.0.1
+transformers==4.33.2
+python-multipart==0.0.6
runtime.txt
ADDED
@@ -0,0 +1 @@
+python-3.10.13
space.yaml
ADDED
@@ -0,0 +1,2 @@
+sdk: "fastapi"
+app_file: "app.py"
survey_interest_dataset_enhanced.csv
ADDED
The diff for this file is too large to render. See raw diff.
utils.py
ADDED
@@ -0,0 +1,67 @@
+import re
+import pandas as pd
+
+def preprocess_text(text):
+    """
+    Enhanced text preprocessing that better preserves domain-specific indicators
+    """
+    # Handle potential NaN values
+    if text is None or isinstance(text, float) and pd.isna(text):
+        return ""
+
+    # Convert to lowercase
+    text = text.lower()
+
+    # Remove special characters while preserving important separators
+    text = re.sub(r'[^\w\s|-]', ' ', text)
+
+    # Replace multiple spaces with a single space
+    text = re.sub(r'\s+', ' ', text)
+
+    # Explicitly preserve key domain terms by adding them multiple times
+    # This increases their weight in the vectorization
+    domain_terms = {
+        'music': ['music', 'guitar', 'band', 'concert', 'gig', 'sing', 'song', 'play music', 'musician'],
+        'food': ['food', 'cook', 'cuisine', 'recipe', 'restaurant', 'eat', 'culinary', 'bake', 'chef'],
+        'sports': ['sport', 'run', 'gym', 'fitness', 'workout', 'exercise', 'athletic', 'training'],
+        'arts': ['art', 'paint', 'draw', 'museum', 'gallery', 'exhibit', 'creative', 'design'],
+        'technology': ['tech', 'code', 'program', 'software', 'developer', 'computer', 'app', 'digital'],
+        'education': ['education', 'learn', 'course', 'class', 'study', 'book', 'read', 'academic'],
+        'travel': ['travel', 'trip', 'hike', 'explore', 'tour', 'visit', 'journey', 'destination']
+    }
+
+    # Check for domain terms and emphasize them
+    modified_text = text
+    for category, terms in domain_terms.items():
+        for term in terms:
+            if term in text:
+                # Add the category name explicitly if a related term is found
+                modified_text += f" {category} {category} {term} {term}"
+
+    # Split on common separators but preserve the important phrases
+    parts = []
+    for part in re.split(r'\s*\|\s*', modified_text):
+        # Remove numbers (but keep words with numbers like "web3")
+        part = re.sub(r'\b\d+\b', '', part)
+        parts.append(part)
+
+    # Define a more focused stopwords list (smaller to keep more domain indicators)
+    core_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'the', 'a', 'an', 'and', 'but',
+                      'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with',
+                      'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after',
+                      'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over',
+                      'under', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were'}
+
+    # Process each part and filter stopwords
+    processed_parts = []
+    for part in parts:
+        words = part.split()
+        filtered_words = [word for word in words if word not in core_stopwords]
+
+        if filtered_words:
+            processed_parts.append(' '.join(filtered_words))
+
+    # Join the processed parts back
+    processed_text = ' '.join(processed_parts)
+
+    return processed_text.strip()