BlakeL commited on
Commit
b91cdea
·
verified ·
1 Parent(s): 46ba869

Upload 11 files

Browse files
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/social_sphere_llm/__init__.py ADDED
File without changes
src/social_sphere_llm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (176 Bytes). View file
 
src/social_sphere_llm/__pycache__/api_service.cpython-312.pyc ADDED
Binary file (13.5 kB). View file
 
src/social_sphere_llm/__pycache__/prediction_service.cpython-312.pyc ADDED
Binary file (12.1 kB). View file
 
src/social_sphere_llm/__pycache__/unified_api_service.cpython-312.pyc ADDED
Binary file (17.9 kB). View file
 
src/social_sphere_llm/__pycache__/unified_prediction_service.cpython-312.pyc ADDED
Binary file (21.8 kB). View file
 
src/social_sphere_llm/api_service.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Social Media Analysis API Service
3
+
4
+ A FastAPI web service for serving MLflow-trained social media analysis models.
5
+ """
6
+
7
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ from typing import List, Dict, Optional, Any
11
+ import uvicorn
12
+ import json
13
+ import logging
14
+ from datetime import datetime
15
+ import pandas as pd
16
+
17
+ from .prediction_service import SocialMediaPredictionService
18
+
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Initialize FastAPI app
24
+ app = FastAPI(
25
+ title="Social Media Analysis API",
26
+ description="API for predicting social media addiction using MLflow models",
27
+ version="1.0.0",
28
+ docs_url="/docs",
29
+ redoc_url="/redoc"
30
+ )
31
+
32
+ # Add CORS middleware
33
+ app.add_middleware(
34
+ CORSMiddleware,
35
+ allow_origins=["*"],
36
+ allow_credentials=True,
37
+ allow_methods=["*"],
38
+ allow_headers=["*"],
39
+ )
40
+
41
+ # Global prediction service
42
+ prediction_service = None
43
+
44
+
45
# Request schema for POST /predict. Declarative pydantic model: field names
# and the Config example are part of the public API schema.
class PredictionRequest(BaseModel):
    """Request model for single prediction."""
    # Free-form feature mapping; keys must match the model's feature columns.
    data: Dict[str, Any] = Field(..., description="Input features for prediction")

    class Config:
        # Example shown in the OpenAPI docs at /docs.
        schema_extra = {
            "example": {
                "data": {
                    "feature1": 0.5,
                    "feature2": -0.2,
                    "feature3": 1.0
                }
            }
        }
59
+
60
+
61
# Request schema for POST /predict/batch: a list of feature mappings, one per record.
class BatchPredictionRequest(BaseModel):
    """Request model for batch predictions."""
    data: List[Dict[str, Any]] = Field(..., description="List of input features for predictions")

    class Config:
        # Example shown in the OpenAPI docs at /docs.
        schema_extra = {
            "example": {
                "data": [
                    {"feature1": 0.5, "feature2": -0.2, "feature3": 1.0},
                    {"feature1": -0.1, "feature2": 0.8, "feature3": -0.5}
                ]
            }
        }
74
+
75
+
76
# Response schema for POST /predict (single record).
class PredictionResponse(BaseModel):
    """Response model for predictions."""
    prediction: int = Field(..., description="Predicted class (0: Low Risk, 1: High Risk)")
    probability: List[float] = Field(..., description="Class probabilities")
    confidence: float = Field(..., description="Confidence score")
    prediction_class: str = Field(..., description="Human-readable prediction class")
    model_name: str = Field(..., description="Name of the model used")
    model_version: str = Field(..., description="Version of the model used")
    # ISO-8601 string set by the endpoint at prediction time.
    timestamp: str = Field(..., description="Prediction timestamp")
85
+
86
+
87
# Response schema for POST /predict/batch. Parallel lists: index i of each
# list refers to input record i.
class BatchPredictionResponse(BaseModel):
    """Response model for batch predictions."""
    predictions: List[int] = Field(..., description="List of predicted classes")
    probabilities: List[List[float]] = Field(..., description="List of class probabilities")
    confidence_scores: List[float] = Field(..., description="List of confidence scores")
    prediction_classes: List[str] = Field(..., description="List of human-readable prediction classes")
    model_name: str = Field(..., description="Name of the model used")
    model_version: str = Field(..., description="Version of the model used")
    timestamp: str = Field(..., description="Prediction timestamp")
    total_predictions: int = Field(..., description="Total number of predictions made")
97
+
98
+
99
# Response schema for GET /model/info; mirrors the dict returned by
# SocialMediaPredictionService.get_model_info().
class ModelInfoResponse(BaseModel):
    """Response model for model information."""
    model_name: str = Field(..., description="Name of the model")
    model_version: str = Field(..., description="Version of the model")
    model_loaded: bool = Field(..., description="Whether the model is loaded")
    # Optional: only available when model metadata could be loaded from MLflow.
    feature_columns: Optional[List[str]] = Field(None, description="Required feature columns")
    model_type: Optional[str] = Field(None, description="Type of the model")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Model metadata")
107
+
108
+
109
# Response schema for GET /health.
class HealthResponse(BaseModel):
    """Response model for health check."""
    # "healthy" when a model is loaded, otherwise "unhealthy".
    status: str = Field(..., description="Service status")
    timestamp: str = Field(..., description="Current timestamp")
    model_loaded: bool = Field(..., description="Whether the model is loaded")
    # Currently always "running"; no real uptime tracking is implemented.
    uptime: str = Field(..., description="Service uptime")
115
+
116
+
117
# Startup and shutdown events
@app.on_event("startup")
async def startup_event():
    """Construct the global prediction service when the application boots.

    On failure the service is left as None so endpoints can answer 503
    instead of crashing the whole app.
    """
    global prediction_service
    try:
        prediction_service = SocialMediaPredictionService()
        logger.info("✅ Prediction service initialized successfully")
    except Exception as e:
        logger.error(f"❌ Failed to initialize prediction service: {e}")
        prediction_service = None
128
+
129
+
130
+ @app.on_event("shutdown")
131
+ async def shutdown_event():
132
+ """Cleanup on shutdown."""
133
+ logger.info("🔄 Shutting down Social Media Analysis API")
134
+
135
+
136
# Health check endpoint
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Report whether the service and its underlying model are usable."""
    model_ready = prediction_service is not None and prediction_service.model is not None
    return HealthResponse(
        status="healthy" if prediction_service and prediction_service.model else "unhealthy",
        timestamp=datetime.now().isoformat(),
        model_loaded=model_ready,
        uptime="running",
    )
146
+
147
+
148
# Model information endpoint
@app.get("/model/info", response_model=ModelInfoResponse, tags=["Model"])
async def get_model_info():
    """Expose metadata about the currently loaded model.

    Returns 503 when the service never initialized, 500 on introspection errors.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        return ModelInfoResponse(**prediction_service.get_model_info())
    except Exception as e:
        logger.error(f"❌ Failed to get model info: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
161
+
162
+
163
# Single prediction endpoint
@app.post("/predict", response_model=PredictionResponse, tags=["Prediction"])
async def predict_single(request: PredictionRequest):
    """Score one record and return its prediction, probabilities, and confidence.

    Returns 503 when the service never initialized, 500 on any prediction error.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        payload = prediction_service.predict_single(request.data)
        # Stamp the moment the prediction was produced.
        payload['timestamp'] = datetime.now().isoformat()
        return PredictionResponse(**payload)
    except Exception as e:
        logger.error(f"❌ Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
182
+
183
+
184
# Batch prediction endpoint
@app.post("/predict/batch", response_model=BatchPredictionResponse, tags=["Prediction"])
async def predict_batch(request: BatchPredictionRequest):
    """Score a list of records in one call.

    Returns 503 when the service never initialized, 500 on any prediction error.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        payload = prediction_service.predict(request.data)
        # Stamp the batch and record its size for the response schema.
        payload['timestamp'] = datetime.now().isoformat()
        payload['total_predictions'] = len(payload['predictions'])
        return BatchPredictionResponse(**payload)
    except Exception as e:
        logger.error(f"❌ Batch prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}")
204
+
205
+
206
# Model reload endpoint
@app.post("/model/reload", tags=["Model"])
async def reload_model(background_tasks: BackgroundTasks):
    """Schedule a background re-initialization of the prediction service.

    The HTTP response returns immediately; the actual reload happens after
    the response is sent, via FastAPI's BackgroundTasks.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    def _rebuild_service():
        """Swap the global service for a freshly constructed one."""
        global prediction_service
        try:
            prediction_service = SocialMediaPredictionService()
            logger.info("✅ Model reloaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to reload model: {e}")

    background_tasks.add_task(_rebuild_service)

    return {
        "message": "Model reload initiated",
        "timestamp": datetime.now().isoformat(),
    }
228
+
229
+
230
# Root endpoint
@app.get("/", tags=["Root"])
async def root():
    """Describe the API and point clients at its main routes."""
    overview = {
        "message": "Social Media Analysis API",
        "version": "1.0.0",
        "docs": "/docs",
        "health": "/health",
        "model_info": "/model/info",
        "predict": "/predict",
        "batch_predict": "/predict/batch",
    }
    return overview
243
+
244
+
245
# Error handlers
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Return a structured JSON body for 404 errors.

    BUGFIX: FastAPI/Starlette exception handlers must return a Response
    object; the previous implementation returned a bare dict, which fails
    when the handler is actually invoked.
    """
    # Deferred import so the module-level import block stays unchanged.
    from fastapi.responses import JSONResponse

    return JSONResponse(
        status_code=404,
        content={
            "error": "Not found",
            "message": "The requested resource was not found",
            "timestamp": datetime.now().isoformat()
        },
    )
254
+
255
+
256
@app.exception_handler(500)
async def internal_error_handler(request, exc):
    """Return a structured JSON body for 500 errors.

    BUGFIX: exception handlers must return a Response object, not a bare
    dict (see not_found_handler).
    """
    # Deferred import so the module-level import block stays unchanged.
    from fastapi.responses import JSONResponse

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "An internal server error occurred",
            "timestamp": datetime.now().isoformat()
        },
    )
264
+
265
+
266
def start_api_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """Launch the FastAPI application under uvicorn (blocking call).

    Args:
        host: Interface to bind the server to.
        port: TCP port to listen on.
        reload: Enable uvicorn auto-reload (development only).
    """
    # The app is referenced by import string so uvicorn's reloader can
    # re-import it.
    uvicorn.run(
        "social_sphere_llm.api_service:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
    )
282
+
283
+
284
+ if __name__ == "__main__":
285
+ # Start the API server
286
+ print("🚀 Starting Social Media Analysis API...")
287
+ start_api_server(host="0.0.0.0", port=8000, reload=True)
src/social_sphere_llm/prediction_service.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Social Media Analysis Prediction Service
3
+
4
+ This module provides a production-ready service for making predictions
5
+ using MLflow-trained models for social media addiction analysis.
6
+ """
7
+
8
+ import mlflow
9
+ import pandas as pd
10
+ import numpy as np
11
+ import json
12
+ import logging
13
+ from typing import Dict, List, Union, Optional
14
+ from pathlib import Path
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class SocialMediaPredictionService:
    """
    A service class for making predictions on social media data using MLflow models.

    Loads a registered scikit-learn model from the local MLflow file store
    ("./mlruns") and exposes preprocessing, single/batch prediction, and
    model-introspection helpers.
    """

    def __init__(self, model_name: str = "social_media_best_model", model_version: str = "latest"):
        """
        Initialize the prediction service and eagerly load the model.

        Args:
            model_name: Name of the registered MLflow model
            model_version: Version of the model to load (default: "latest")

        Raises:
            Exception: Whatever MLflow raises if the model cannot be loaded.
        """
        self.model_name = model_name
        self.model_version = model_version
        self.model = None
        self.model_metadata = None
        self.feature_columns = None

        # Local file-based tracking store, resolved relative to the CWD.
        mlflow.set_tracking_uri("file:./mlruns")

        # Load the model (raises on failure).
        self._load_model()

    def _load_model(self):
        """Load the MLflow model; metadata loading is best-effort."""
        try:
            model_uri = f"models:/{self.model_name}/{self.model_version}"
            self.model = mlflow.sklearn.load_model(model_uri)
            logger.info(f"✅ Model loaded successfully: {model_uri}")

            # Metadata failures are logged as warnings, never fatal.
            self._load_metadata()

        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise

    def _load_metadata(self):
        """Load model metadata (feature columns etc.) from run artifacts, if present."""
        try:
            client = mlflow.tracking.MlflowClient()
            model_versions = client.search_model_versions(f"name='{self.model_name}'")

            if model_versions:
                # BUGFIX: `version` is a numeric string, so a plain max()
                # compared lexicographically ("9" > "10"); compare as ints.
                latest_version = max(model_versions, key=lambda v: int(v.version))
                run_id = latest_version.run_id
                run = client.get_run(run_id)

                # BUGFIX: MLflow RunData has no `artifacts` attribute — the
                # original `run.data.artifacts` always raised, so metadata
                # could never load. List artifacts through the client instead.
                for artifact in client.list_artifacts(run_id):
                    if artifact.path.endswith('model_metadata.json'):
                        metadata_path = f"mlruns/{run.info.experiment_id}/{run_id}/artifacts/{artifact.path}"
                        if Path(metadata_path).exists():
                            with open(metadata_path, 'r') as f:
                                self.model_metadata = json.load(f)
                            self.feature_columns = self.model_metadata.get('feature_columns', [])
                            logger.info("✅ Model metadata loaded successfully")
                            break

        except Exception as e:
            logger.warning(f"⚠️ Could not load model metadata: {e}")

    def preprocess_data(self, data: Union[pd.DataFrame, Dict, List[Dict]]) -> pd.DataFrame:
        """
        Preprocess input data to match the model's expected format.

        Args:
            data: A DataFrame, a single record dict, or a list of record dicts.

        Returns:
            DataFrame restricted to (and ordered by) the model's feature
            columns when known, with missing columns filled with 0 and
            object/categorical columns integer-coded.

        Raises:
            ValueError: If `data` is not one of the supported types.
        """
        # Normalize input into a DataFrame.
        if isinstance(data, dict):
            data = pd.DataFrame([data])
        elif isinstance(data, list):
            data = pd.DataFrame(data)
        elif not isinstance(data, pd.DataFrame):
            raise ValueError("Data must be a DataFrame, dict, or list of dicts")

        # Work on a copy so the caller's frame is never mutated.
        df = data.copy()

        if self.feature_columns:
            # Add any missing expected columns (default 0), then keep only —
            # and order by — the model's training features.
            missing_cols = set(self.feature_columns) - set(df.columns)
            if missing_cols:
                logger.warning(f"⚠️ Missing columns: {missing_cols}")
                for col in missing_cols:
                    df[col] = 0
            df = df[[col for col in self.feature_columns if col in df.columns]]

        # Basic integer coding for object/categorical columns.
        # NOTE(review): cat.codes depend on the values present in THIS batch,
        # not on the training-time encoding — confirm the model pipeline
        # handles categorical encoding itself, or codes may be inconsistent
        # across requests.
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns
        for col in categorical_cols:
            df[col] = df[col].astype(str).astype('category').cat.codes

        # Fill remaining missing values.
        df = df.fillna(0)

        logger.info(f"✅ Data preprocessed: {df.shape}")
        return df

    def predict(self, data: Union[pd.DataFrame, Dict, List[Dict]]) -> Dict:
        """
        Make predictions on the input data.

        Args:
            data: Input data to predict on

        Returns:
            Dictionary with keys: predictions, probabilities, model_name,
            model_version, confidence_scores, prediction_classes, data_shape,
            and (when available) model_metadata.

        Raises:
            ValueError: If the model is not loaded.
            Exception: Propagates preprocessing/model errors after logging.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Please initialize the service properly.")

        try:
            processed_data = self.preprocess_data(data)

            # NOTE(review): assumes the underlying estimator supports
            # predict_proba (i.e. a probabilistic classifier).
            predictions = self.model.predict(processed_data)
            probabilities = self.model.predict_proba(processed_data)

            results = {
                'predictions': predictions.tolist(),
                'probabilities': probabilities.tolist(),
                'model_name': self.model_name,
                'model_version': self.model_version,
                # Confidence = probability of the most likely class.
                'confidence_scores': np.max(probabilities, axis=1).tolist(),
                'prediction_classes': ['Low Risk' if p == 0 else 'High Risk' for p in predictions],
                'data_shape': processed_data.shape
            }

            # Attach training metadata when it was loaded.
            if self.model_metadata:
                results['model_metadata'] = {
                    'training_date': self.model_metadata.get('training_date'),
                    'model_type': self.model_metadata.get('model_type'),
                    'performance_metrics': self.model_metadata.get('performance_metrics', {})
                }

            logger.info(f"✅ Predictions completed for {len(predictions)} samples")
            return results

        except Exception as e:
            logger.error(f"❌ Prediction failed: {e}")
            raise

    def predict_single(self, data: Dict) -> Dict:
        """
        Make a prediction for a single data point.

        Args:
            data: Single data point as a dictionary

        Returns:
            Dictionary with prediction, probability, confidence,
            prediction_class, model_name, and model_version for the one record.
        """
        results = self.predict(data)

        # Unwrap the first (only) entry of each batch-shaped field.
        return {
            'prediction': results['predictions'][0],
            'probability': results['probabilities'][0],
            'confidence': results['confidence_scores'][0],
            'prediction_class': results['prediction_classes'][0],
            'model_name': results['model_name'],
            'model_version': results['model_version']
        }

    def get_model_info(self) -> Dict:
        """
        Get information about the loaded model.

        Returns:
            Dictionary containing model name/version, load state, feature
            columns, the estimator's class name, and metadata when available.
        """
        # BUGFIX: the original unconditionally read
        # self.model.named_steps['classifier'], which raises when the model
        # is not a Pipeline or uses a different step name. Fall back to the
        # model's own class name.
        if self.model is None:
            model_type = None
        else:
            steps = getattr(self.model, 'named_steps', None)
            if steps and 'classifier' in steps:
                model_type = type(steps['classifier']).__name__
            else:
                model_type = type(self.model).__name__

        info = {
            'model_name': self.model_name,
            'model_version': self.model_version,
            'model_loaded': self.model is not None,
            'feature_columns': self.feature_columns,
            'model_type': model_type
        }

        if self.model_metadata:
            info['metadata'] = self.model_metadata

        return info
224
+
225
+
226
def create_prediction_service(model_name: str = "social_media_best_model") -> SocialMediaPredictionService:
    """
    Factory function to create a prediction service.

    Args:
        model_name: Name of the MLflow model to load

    Returns:
        Initialized prediction service
    """
    service = SocialMediaPredictionService(model_name=model_name)
    return service
237
+
238
+
239
# Example usage and testing functions
def test_prediction_service():
    """Smoke-test the prediction service with a hand-made sample record.

    Returns True on success, False when anything raises.
    """
    try:
        service = create_prediction_service()

        model_info = service.get_model_info()
        print("📊 Model Information:")
        print(json.dumps(model_info, indent=2))

        # Sample record; adjust the keys to your real feature set.
        sample_data = {
            'feature1': 0.5,
            'feature2': -0.2,
            'feature3': 1.0
        }

        result = service.predict_single(sample_data)
        print("\n🎯 Prediction Result:")
        print(json.dumps(result, indent=2))

    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False

    return True
268
+
269
+
270
+ if __name__ == "__main__":
271
+ # Run test if script is executed directly
272
+ print("🧪 Testing Social Media Prediction Service...")
273
+ success = test_prediction_service()
274
+
275
+ if success:
276
+ print("✅ Prediction service test completed successfully!")
277
+ else:
278
+ print("❌ Prediction service test failed!")
src/social_sphere_llm/unified_api_service.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Social Media Analysis API Service
3
+
4
+ A FastAPI web service for serving all three MLflow-trained social media analysis models:
5
+ 1. Conflicts Prediction (Notebook 07)
6
+ 2. Addicted Score Regression (Notebook 08)
7
+ 3. Clustering Analysis (Notebook 09)
8
+ """
9
+
10
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from pydantic import BaseModel, Field
13
+ from typing import List, Dict, Optional, Any
14
+ import uvicorn
15
+ import json
16
+ import logging
17
+ from datetime import datetime
18
+ import pandas as pd
19
+
20
+ from .unified_prediction_service import UnifiedSocialMediaPredictionService
21
+
22
+ # Configure logging
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Initialize FastAPI app
27
+ app = FastAPI(
28
+ title="Unified Social Media Analysis API",
29
+ description="API for predicting social media addiction, conflicts, and clustering using MLflow models",
30
+ version="2.0.0",
31
+ docs_url="/docs",
32
+ redoc_url="/redoc"
33
+ )
34
+
35
+ # Add CORS middleware
36
+ app.add_middleware(
37
+ CORSMiddleware,
38
+ allow_origins=["*"],
39
+ allow_credentials=True,
40
+ allow_methods=["*"],
41
+ allow_headers=["*"],
42
+ )
43
+
44
+ # Global prediction service
45
+ prediction_service = None
46
+
47
+
48
# Shared request schema for all /predict/* endpoints: one student record.
# Field names and validation bounds are part of the public API contract.
class StudentDataRequest(BaseModel):
    """Request model for student data."""
    age: int = Field(..., ge=10, le=100, description="Student age")
    gender: str = Field(..., description="Student gender (Male/Female)")
    academic_level: str = Field(..., description="Academic level (High School/Undergraduate/Graduate)")
    avg_daily_usage_hours: float = Field(..., ge=0, le=24, description="Average daily social media usage hours")
    sleep_hours_per_night: float = Field(..., ge=0, le=24, description="Sleep hours per night")
    mental_health_score: int = Field(..., ge=1, le=10, description="Mental health score (1-10)")
    conflicts_over_social_media: int = Field(..., ge=0, le=10, description="Number of conflicts over social media")
    addicted_score: int = Field(..., ge=1, le=10, description="Addiction score (1-10)")
    relationship_status: str = Field(..., description="Relationship status")
    affects_academic_performance: str = Field(..., description="Whether social media affects academic performance")
    most_used_platform: str = Field(..., description="Most used social media platform")

    class Config:
        # Example shown in the OpenAPI docs at /docs.
        schema_extra = {
            "example": {
                "age": 20,
                "gender": "Female",
                "academic_level": "Undergraduate",
                "avg_daily_usage_hours": 6.5,
                "sleep_hours_per_night": 7.0,
                "mental_health_score": 7,
                "conflicts_over_social_media": 2,
                "addicted_score": 6,
                "relationship_status": "Single",
                "affects_academic_performance": "Yes",
                "most_used_platform": "Instagram"
            }
        }
78
+
79
+
80
# Response schema for POST /predict/conflicts.
class ConflictsPredictionResponse(BaseModel):
    """Response model for conflicts predictions."""
    predicted_conflicts: int = Field(..., description="Predicted conflicts (0: Low, 1: High)")
    conflict_level: str = Field(..., description="Conflict risk level")
    recommendation: str = Field(..., description="Intervention recommendation")
    confidence: float = Field(..., description="Prediction confidence")
    timestamp: str = Field(..., description="Prediction timestamp")
    model_type: str = Field(..., description="Model type")
89
+
90
# Response schema for POST /predict/addicted-score.
class AddictedScoreResponse(BaseModel):
    """Response model for addicted score predictions."""
    predicted_score: float = Field(..., description="Predicted addiction score")
    addiction_level: str = Field(..., description="Addiction level category")
    confidence: float = Field(..., description="Prediction confidence")
    timestamp: str = Field(..., description="Prediction timestamp")
    model_type: str = Field(..., description="Model type")
97
+
98
+
99
# Response schema for POST /predict/clustering.
class ClusteringResponse(BaseModel):
    """Response model for clustering predictions."""
    cluster_id: int = Field(..., description="Assigned cluster ID")
    cluster_label: str = Field(..., description="Cluster label")
    risk_level: str = Field(..., description="Risk level")
    recommendation: str = Field(..., description="Intervention recommendation")
    confidence: float = Field(..., description="Prediction confidence")
    timestamp: str = Field(..., description="Prediction timestamp")
    model_type: str = Field(..., description="Model type")
109
+
110
# Response schema for POST /predict/all: aggregates the three model outputs
# plus the echoed input record.
class UnifiedPredictionResponse(BaseModel):
    """Response model for unified predictions."""
    conflicts_prediction: ConflictsPredictionResponse = Field(..., description="Conflicts prediction results")
    addicted_score_prediction: AddictedScoreResponse = Field(..., description="Addicted score prediction results")
    clustering_prediction: ClusteringResponse = Field(..., description="Clustering prediction results")
    timestamp: str = Field(..., description="Prediction timestamp")
    student_data: Dict[str, Any] = Field(..., description="Input student data")
117
+
118
+
119
# Response schema for GET /models/status; mirrors the dict returned by
# UnifiedSocialMediaPredictionService.get_model_status().
class ModelStatusResponse(BaseModel):
    """Response model for model status."""
    conflicts_model_loaded: bool = Field(..., description="Whether conflicts model is loaded")
    addicted_model_loaded: bool = Field(..., description="Whether addicted model is loaded")
    clustering_model_loaded: bool = Field(..., description="Whether clustering model is loaded")
    conflicts_scaler_loaded: bool = Field(..., description="Whether conflicts scaler is loaded")
    addicted_scaler_loaded: bool = Field(..., description="Whether addicted scaler is loaded")
    clustering_scaler_loaded: bool = Field(..., description="Whether clustering scaler is loaded")
    cluster_labels_loaded: bool = Field(..., description="Whether cluster labels are loaded")
    feature_names_loaded: bool = Field(..., description="Whether feature names are loaded")
    timestamp: str = Field(..., description="Status timestamp")
130
+
131
+
132
# Response schema for GET /health (unified service variant).
class HealthResponse(BaseModel):
    """Response model for health check."""
    # "healthy" only when all three models are loaded.
    status: str = Field(..., description="Service status")
    timestamp: str = Field(..., description="Current timestamp")
    models_loaded: bool = Field(..., description="Whether all models are loaded")
    # Currently always "running"; no real uptime tracking is implemented.
    uptime: str = Field(..., description="Service uptime")
138
+
139
+
140
# Startup and shutdown events
@app.on_event("startup")
async def startup_event():
    """Construct the global unified prediction service when the app boots.

    On failure the service is left as None so endpoints can answer 503.
    """
    global prediction_service
    try:
        prediction_service = UnifiedSocialMediaPredictionService()
        logger.info("✅ Unified prediction service initialized successfully")
    except Exception as e:
        logger.error(f"❌ Failed to initialize unified prediction service: {e}")
        prediction_service = None
151
+
152
+
153
+ @app.on_event("shutdown")
154
+ async def shutdown_event():
155
+ """Cleanup on shutdown."""
156
+ logger.info("🔄 Shutting down Unified Social Media Analysis API")
157
+
158
+
159
# Health check endpoint
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Check the health status of the API service.

    Healthy only when the conflicts, addicted-score, and clustering models
    are all loaded.
    """
    # BUGFIX: coerce the `and` chain to bool. The raw expression evaluates
    # to None (service missing) or to the last model object (all loaded),
    # but HealthResponse declares `models_loaded: bool`, so response
    # validation would reject either value.
    models_loaded = bool(
        prediction_service and
        prediction_service.conflicts_model and
        prediction_service.addicted_model and
        prediction_service.clustering_model
    )

    return HealthResponse(
        status="healthy" if models_loaded else "unhealthy",
        timestamp=datetime.now().isoformat(),
        models_loaded=models_loaded,
        uptime="running"
    )
176
+
177
+
178
# Model status endpoint
@app.get("/models/status", response_model=ModelStatusResponse, tags=["Models"])
async def get_model_status():
    """Report which models, scalers, and lookup tables are currently loaded."""
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        return ModelStatusResponse(**prediction_service.get_model_status())
    except Exception as e:
        logger.error(f"❌ Failed to get model status: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get model status: {str(e)}")
191
+
192
+
193
# Conflicts prediction endpoint
@app.post("/predict/conflicts", response_model=ConflictsPredictionResponse, tags=["Predictions"])
async def predict_conflicts(request: StudentDataRequest):
    """Predict social-media conflict risk for one student record.

    Returns 503 when the service never initialized, 500 on prediction errors.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        # Convert the validated request into a plain dict for the service.
        data = request.dict()

        result = prediction_service.predict_conflicts(data)

        # The service reports failures via an 'error' key rather than raising.
        if 'error' in result:
            raise HTTPException(status_code=500, detail=result['error'])

        return ConflictsPredictionResponse(**result)

    except HTTPException:
        # BUGFIX: without this clause the HTTPException raised above was
        # caught by the generic handler below and re-wrapped, losing its
        # specific detail message.
        raise
    except Exception as e:
        logger.error(f"❌ Conflicts prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Conflicts prediction failed: {str(e)}")
215
+
216
+
217
# Addicted score prediction endpoint
@app.post("/predict/addicted-score", response_model=AddictedScoreResponse, tags=["Predictions"])
async def predict_addicted_score(request: StudentDataRequest):
    """Predict the addiction score for one student record.

    Returns 503 when the service never initialized, 500 on prediction errors.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        # Convert the validated request into a plain dict for the service.
        data = request.dict()

        result = prediction_service.predict_addicted_score(data)

        # The service reports failures via an 'error' key rather than raising.
        if 'error' in result:
            raise HTTPException(status_code=500, detail=result['error'])

        return AddictedScoreResponse(**result)

    except HTTPException:
        # BUGFIX: re-raise intentional HTTP errors instead of letting the
        # generic handler below re-wrap them and lose the detail message.
        raise
    except Exception as e:
        logger.error(f"❌ Addicted score prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Addicted score prediction failed: {str(e)}")
239
+
240
+
241
# Clustering prediction endpoint
@app.post("/predict/clustering", response_model=ClusteringResponse, tags=["Predictions"])
async def predict_clustering(request: StudentDataRequest):
    """Make a clustering prediction for student data.

    Raises 503 when the prediction service is unavailable and 500 when the
    underlying model fails or reports an error.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        # Convert request to dictionary
        data = request.dict()

        # Make prediction
        result = prediction_service.predict_cluster(data)

        if 'error' in result:
            raise HTTPException(status_code=500, detail=result['error'])

        return ClusteringResponse(**result)

    except HTTPException:
        # Bug fix: re-raise as-is; previously the generic handler below
        # caught this and re-wrapped the detail a second time.
        raise
    except Exception as e:
        logger.error(f"❌ Clustering prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clustering prediction failed: {str(e)}")
263
+
264
+
265
# Unified prediction endpoint
@app.post("/predict/all", response_model=UnifiedPredictionResponse, tags=["Predictions"])
async def predict_all(request: StudentDataRequest):
    """Make predictions using all three models.

    Raises 503 when the prediction service is unavailable and 500 when any
    of the three underlying models reports an error.
    """
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    try:
        # Convert request to dictionary
        data = request.dict()

        # Make all predictions
        results = prediction_service.predict_all(data)

        # Check for errors in any prediction
        for key, result in results.items():
            if isinstance(result, dict) and 'error' in result:
                raise HTTPException(status_code=500, detail=f"{key} failed: {result['error']}")

        return UnifiedPredictionResponse(**results)

    except HTTPException:
        # Bug fix: re-raise as-is; previously the generic handler below
        # caught this and re-wrapped the detail a second time.
        raise
    except Exception as e:
        logger.error(f"❌ Unified prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Unified prediction failed: {str(e)}")
289
+
290
+
291
# Model reload endpoint
@app.post("/models/reload", tags=["Models"])
async def reload_models(background_tasks: BackgroundTasks):
    """Schedule an asynchronous reload of every model and return immediately."""
    if not prediction_service:
        raise HTTPException(status_code=503, detail="Prediction service not available")

    def _reload_task():
        """Background task to reload all models."""
        global prediction_service
        try:
            prediction_service = UnifiedSocialMediaPredictionService()
            logger.info("✅ All models reloaded successfully")
        except Exception as exc:
            logger.error(f"❌ Failed to reload models: {exc}")

    background_tasks.add_task(_reload_task)

    return {
        "message": "Model reload initiated",
        "timestamp": datetime.now().isoformat(),
    }
313
+
314
+
315
# Root endpoint
@app.get("/", tags=["Root"])
async def root():
    """Root endpoint with API information."""
    prediction_endpoints = {
        "conflicts_prediction": "/predict/conflicts",
        "addicted_score_prediction": "/predict/addicted-score",
        "clustering_prediction": "/predict/clustering",
        "unified_prediction": "/predict/all",
    }
    return {
        "message": "Unified Social Media Analysis API",
        "version": "2.0.0",
        "description": "API for predicting social media addiction, conflicts, and clustering",
        "docs": "/docs",
        "health": "/health",
        "model_status": "/models/status",
        "endpoints": prediction_endpoints,
    }
333
+
334
+
335
# Error handlers
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Handle 404 errors.

    Bug fix: Starlette requires exception handlers to return a Response
    object; returning a bare dict itself raised at request time.
    """
    from fastapi.responses import JSONResponse  # local import: fastapi is already a file dependency

    return JSONResponse(
        status_code=404,
        content={
            "error": "Not found",
            "message": "The requested endpoint does not exist",
            "timestamp": datetime.now().isoformat(),
        },
    )
344
+
345
+
346
@app.exception_handler(500)
async def internal_error_handler(request, exc):
    """Handle 500 errors.

    Bug fix: Starlette requires exception handlers to return a Response
    object; returning a bare dict itself raised at request time.
    """
    from fastapi.responses import JSONResponse  # local import: fastapi is already a file dependency

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "An unexpected error occurred",
            "timestamp": datetime.now().isoformat(),
        },
    )
354
+
355
+
356
def start_unified_api_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """
    Start the unified API server.

    Args:
        host: Host to bind to
        port: Port to bind to
        reload: Whether to enable auto-reload
    """
    server_options = {
        "host": host,
        "port": port,
        "reload": reload,
        "log_level": "info",
    }
    uvicorn.run("src.social_sphere_llm.unified_api_service:app", **server_options)
372
+
373
+
374
+ if __name__ == "__main__":
375
+ start_unified_api_server()
src/social_sphere_llm/unified_prediction_service.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Social Media Analysis Prediction Service
3
+
4
+ This module provides a production-ready service for making predictions
5
+ using all three MLflow-trained models:
6
+ 1. Conflicts Prediction (Notebook 07)
7
+ 2. Addicted Score Regression (Notebook 08)
8
+ 3. Clustering Analysis (Notebook 09)
9
+ """
10
+
11
+ import mlflow
12
+ import pandas as pd
13
+ import numpy as np
14
+ import json
15
+ import logging
16
+ import joblib
17
+ from typing import Dict, List, Union, Optional
18
+ from pathlib import Path
19
+ from datetime import datetime
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class UnifiedSocialMediaPredictionService:
    """
    A unified service class for making predictions on social media data using all three models.

    Wraps three separately trained artifacts:
      1. Conflicts prediction classifier (Notebook 07)
      2. Addicted-score regressor (Notebook 08)
      3. Clustering model (Notebook 09)

    Loading is best-effort: a missing artifact downgrades the matching
    predict_* method to returning an {'error': ...} payload rather than
    raising at construction time.
    """

    # Candidate directories searched, in order, for persisted artifacts.
    # Covers running from the repo root, from a sibling directory, and
    # from inside the notebooks tree.
    _ARTIFACT_DIRS = ('models', '../models', 'notebooks/models')

    def __init__(self):
        """
        Initialize the unified prediction service with all three models.
        """
        self.conflicts_model = None
        self.addicted_model = None
        self.clustering_model = None
        self.conflicts_scaler = None
        self.addicted_scaler = None
        self.clustering_scaler = None
        self.cluster_labels = None
        self.feature_names = {}

        # Set MLflow tracking URI
        mlflow.set_tracking_uri("file:./mlruns")

        # Load all models
        self._load_all_models()

    def _load_artifact(self, filename: str, description: str):
        """
        Try to joblib-load *filename* from each candidate directory.

        Returns the first artifact that loads successfully, or None when no
        candidate path works. Logs the path that succeeded.
        """
        for directory in self._ARTIFACT_DIRS:
            path = f"{directory}/{filename}"
            try:
                artifact = joblib.load(path)
                logger.info(f"✅ Loaded {description} from: {path}")
                return artifact
            except Exception:
                # Any load failure (missing file, bad pickle) just means
                # "try the next location". Was a bare `except:` before,
                # which also swallowed KeyboardInterrupt/SystemExit.
                continue
        return None

    def _load_all_models(self):
        """Load all three models and their associated files."""
        try:
            # Load Conflicts Prediction Model (Notebook 07)
            self._load_conflicts_model()

            # Load Addicted Score Model (Notebook 08)
            self._load_addicted_model()

            # Load Clustering Model (Notebook 09)
            self._load_clustering_model()

            logger.info("✅ All models loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Failed to load models: {e}")
            raise

    def _load_conflicts_model(self):
        """Load the conflicts prediction model from Notebook 07."""
        try:
            self.conflicts_model = self._load_artifact('conflicts_classifier_rf.joblib', 'conflicts model')
            self.conflicts_scaler = self._load_artifact('conflicts_scaler.joblib', 'conflicts scaler')
            names = self._load_artifact('conflicts_feature_names.joblib', 'conflicts feature names')
            if names is not None:
                self.feature_names['conflicts'] = names
        except Exception as e:
            logger.warning(f"⚠️ Could not load conflicts model: {e}")

    def _load_addicted_model(self):
        """Load the addicted score regression model from Notebook 08."""
        try:
            # Prefer the MLflow model registry; fall back to local joblib files.
            try:
                model_uri = "models:/addicted_score_regressor/latest"
                self.addicted_model = mlflow.sklearn.load_model(model_uri)
                logger.info(f"✅ Loaded addicted model from MLflow: {model_uri}")
            except Exception:
                self.addicted_model = self._load_artifact('addicted_score_model.joblib', 'addicted model')

            self.addicted_scaler = self._load_artifact('addicted_score_scaler.joblib', 'addicted scaler')
        except Exception as e:
            logger.warning(f"⚠️ Could not load addicted model: {e}")

    def _load_clustering_model(self):
        """Load the clustering model from Notebook 09."""
        try:
            self.clustering_model = self._load_artifact('clustering_model.joblib', 'clustering model')
            self.clustering_scaler = self._load_artifact('clustering_scaler.joblib', 'clustering scaler')
            self.cluster_labels = self._load_artifact('cluster_labels.joblib', 'cluster labels')
            names = self._load_artifact('clustering_feature_names.joblib', 'clustering feature names')
            if names is not None:
                self.feature_names['clustering'] = names
        except Exception as e:
            logger.warning(f"⚠️ Could not load clustering model: {e}")

    def predict_conflicts(self, data: Dict) -> Dict:
        """
        Predict conflicts over social media using Notebook 07 model.

        Args:
            data: Dictionary containing student data

        Returns:
            Dictionary containing conflicts prediction results, or a dict
            with an 'error' key when the model is unavailable or fails.
        """
        if self.conflicts_model is None or self.conflicts_scaler is None:
            return {
                "error": "Conflicts model not loaded. Please run notebook 07 first.",
                "timestamp": datetime.now().isoformat()
            }

        try:
            # Prepare features for conflicts model (only 4 features needed)
            features = {}

            if 'Mental_Health_Score' in data:
                features['Mental_Health_Score'] = float(data['Mental_Health_Score'])
            if 'Age' in data:
                features['Age'] = float(data['Age'])

            # One-hot encode gender; unrecognized values map to all-zeros.
            if 'Gender' in data:
                gender = data['Gender'].lower()
                if gender in ['male', 'm']:
                    features['Gender_Male'] = 1
                    features['Gender_Female'] = 0
                elif gender in ['female', 'f']:
                    features['Gender_Male'] = 0
                    features['Gender_Female'] = 1
                else:
                    features['Gender_Male'] = 0
                    features['Gender_Female'] = 0

            # The scaler was fit on the two numeric features only.
            scaler_features = ['Mental_Health_Score', 'Age']
            feature_vector = [features.get(f, 0) for f in scaler_features]
            feature_vector_scaled = self.conflicts_scaler.transform([feature_vector])

            # Full model input: 2 scaled numeric features + 2 raw gender flags.
            model_features = ['Mental_Health_Score', 'Age', 'Gender_Female', 'Gender_Male']
            full_feature_vector = [features.get(f, 0) for f in model_features]
            final_vector = list(feature_vector_scaled[0]) + full_feature_vector[2:]

            # Make prediction
            prediction = self.conflicts_model.predict([final_vector])[0]
            probability = self.conflicts_model.predict_proba([final_vector])[0]

            # Determine conflict level
            if prediction == 1:
                conflict_level = 'High Risk'
                recommendation = 'Immediate intervention needed: Conflict resolution training, communication skills'
            else:
                conflict_level = 'Low Risk'
                recommendation = 'Monitor and provide resources: Healthy communication guidelines'

            # Confidence = probability of the winning class.
            confidence = max(probability)

            return {
                'predicted_conflicts': int(prediction),
                'conflict_level': conflict_level,
                'recommendation': recommendation,
                'confidence': float(confidence),
                'timestamp': datetime.now().isoformat(),
                'model_type': 'conflicts_prediction'
            }

        except Exception as e:
            return {
                'error': str(e),
                'timestamp': datetime.now().isoformat()
            }

    def predict_addicted_score(self, data: Dict) -> Dict:
        """
        Predict addicted score using Notebook 08 model.

        Args:
            data: Dictionary containing student data

        Returns:
            Dictionary containing addicted score prediction results, or a
            dict with an 'error' key when the model is unavailable or fails.
        """
        if self.addicted_model is None or self.addicted_scaler is None:
            return {
                "error": "Addicted score model not loaded. Please run notebook 08 first.",
                "timestamp": datetime.now().isoformat()
            }

        try:
            # Prepare features for addicted score model
            features = {}

            if 'Age' in data:
                features['Age'] = float(data['Age'])
            if 'Mental_Health_Score' in data:
                features['Mental_Health_Score'] = float(data['Mental_Health_Score'])
                # Engineered quadratic term used during training.
                features['mental_health_squared'] = features['Mental_Health_Score'] ** 2
            if 'Conflicts_Over_Social_Media' in data:
                features['Conflicts_Over_Social_Media'] = float(data['Conflicts_Over_Social_Media'])

            # One-hot encode gender; unrecognized values map to all-zeros.
            if 'Gender' in data:
                gender = data['Gender'].lower()
                if gender in ['male', 'm']:
                    features['Gender_Male'] = 1
                    features['Gender_Female'] = 0
                elif gender in ['female', 'f']:
                    features['Gender_Male'] = 0
                    features['Gender_Female'] = 1
                else:
                    features['Gender_Male'] = 0
                    features['Gender_Female'] = 0

            # The scaler was fit on the three raw numeric features only.
            scaler_features = ['Mental_Health_Score', 'Age', 'Conflicts_Over_Social_Media']
            feature_vector = [features.get(f, 0) for f in scaler_features]
            feature_vector_scaled = self.addicted_scaler.transform([feature_vector])

            # Full model input: 3 scaled numerics + squared term + gender flags.
            model_features = ['Mental_Health_Score', 'Age', 'Conflicts_Over_Social_Media',
                              'mental_health_squared', 'Gender_Female', 'Gender_Male']
            full_feature_vector = [features.get(f, 0) for f in model_features]
            final_vector = list(feature_vector_scaled[0]) + full_feature_vector[3:]

            # Make prediction
            prediction = self.addicted_model.predict([final_vector])[0]

            # Map the continuous score to a coarse addiction level.
            if prediction >= 8:
                addiction_level = 'Very High'
            elif prediction >= 6:
                addiction_level = 'High'
            elif prediction >= 4:
                addiction_level = 'Moderate'
            else:
                addiction_level = 'Low'

            # Regression model exposes no probability; report a fixed confidence.
            confidence = 0.8

            return {
                'predicted_score': float(prediction),
                'addiction_level': addiction_level,
                'confidence': float(confidence),
                'timestamp': datetime.now().isoformat(),
                'model_type': 'addicted_score_regression'
            }

        except Exception as e:
            return {
                'error': str(e),
                'timestamp': datetime.now().isoformat()
            }

    def predict_cluster(self, data: Dict) -> Dict:
        """
        Predict cluster assignment using Notebook 09 model.

        Args:
            data: Dictionary containing student data

        Returns:
            Dictionary containing cluster prediction results, or a dict
            with an 'error' key when the model is unavailable or fails.
        """
        if self.clustering_model is None or self.clustering_scaler is None:
            return {
                "error": "Clustering model not loaded. Please run notebook 09 first.",
                "timestamp": datetime.now().isoformat()
            }

        try:
            # Prepare features
            features = {}

            # Extract numeric features
            for column in ('Age', 'Avg_Daily_Usage_Hours', 'Sleep_Hours_Per_Night',
                           'Mental_Health_Score', 'Conflicts_Over_Social_Media', 'Addicted_Score'):
                if column in data:
                    features[column] = float(data[column])

            # Handle categorical features
            if 'Gender' in data:
                features['Is_Female'] = 1 if data['Gender'].lower() in ['female', 'f'] else 0

            if 'Academic_Level' in data:
                level = data['Academic_Level'].lower()
                # NOTE: 'undergraduate' is checked first because it also
                # contains the substring 'graduate'.
                features['Is_Undergraduate'] = 1 if 'undergraduate' in level else 0
                features['Is_Graduate'] = 1 if ('graduate' in level and 'undergraduate' not in level) else 0
                features['Is_High_School'] = 1 if ('high school' in level and 'graduate' not in level) else 0

            # Create behavioral threshold flags
            if 'Avg_Daily_Usage_Hours' in features:
                features['High_Usage'] = 1 if features['Avg_Daily_Usage_Hours'] >= 6 else 0
            if 'Sleep_Hours_Per_Night' in features:
                features['Low_Sleep'] = 1 if features['Sleep_Hours_Per_Night'] <= 6 else 0
            if 'Mental_Health_Score' in features:
                features['Poor_Mental_Health'] = 1 if features['Mental_Health_Score'] <= 5 else 0
            if 'Conflicts_Over_Social_Media' in features:
                features['High_Conflict'] = 1 if features['Conflicts_Over_Social_Media'] >= 3 else 0
            if 'Addicted_Score' in features:
                features['High_Addiction'] = 1 if features['Addicted_Score'] >= 7 else 0

            # Create interaction features (guard against division by zero,
            # which previously aborted the whole prediction with an error).
            if 'Avg_Daily_Usage_Hours' in features and 'Sleep_Hours_Per_Night' in features:
                sleep = features['Sleep_Hours_Per_Night']
                features['Usage_Sleep_Ratio'] = features['Avg_Daily_Usage_Hours'] / sleep if sleep else 0.0
            if 'Mental_Health_Score' in features and 'Avg_Daily_Usage_Hours' in features:
                usage = features['Avg_Daily_Usage_Hours']
                features['Mental_Health_Usage_Ratio'] = features['Mental_Health_Score'] / usage if usage else 0.0

            # Create feature vector in the order the model was trained on.
            ordered_names = self.feature_names.get('clustering', [])
            feature_vector = [features.get(f, 0) for f in ordered_names]

            # Scale the features
            feature_vector_scaled = self.clustering_scaler.transform([feature_vector])

            # Make prediction
            cluster_prediction = self.clustering_model.predict(feature_vector_scaled)[0]

            # Get cluster label (fall back to a generic name).
            if self.cluster_labels:
                cluster_label = self.cluster_labels.get(cluster_prediction, f'Cluster_{cluster_prediction}')
            else:
                cluster_label = f'Cluster_{cluster_prediction}'

            # Determine risk level based on cluster characteristics
            if 'High-Usage' in cluster_label and 'High-Addiction' in cluster_label:
                risk_level = 'High Risk'
                recommendation = 'Intensive intervention needed: Digital detox programs, counseling, parental monitoring'
            elif 'High-Usage' in cluster_label or 'Poor-Health' in cluster_label:
                risk_level = 'Moderate Risk'
                recommendation = 'Targeted intervention recommended: Screen time limits, mental health support, sleep hygiene'
            else:
                risk_level = 'Low Risk'
                recommendation = 'Monitor and provide resources: Educational materials, healthy usage guidelines'

            # Confidence heuristic: closer to the cluster center = higher.
            try:
                cluster_center = self.clustering_model.cluster_centers_[cluster_prediction]
                distance = np.linalg.norm(feature_vector_scaled[0] - cluster_center)
                confidence = max(0.1, 1 - distance / 10)  # Normalize distance to confidence
            except Exception:
                # Model may not expose cluster_centers_ (e.g. DBSCAN).
                confidence = 0.8

            return {
                'cluster_id': int(cluster_prediction),
                'cluster_label': cluster_label,
                'risk_level': risk_level,
                'recommendation': recommendation,
                'confidence': float(confidence),
                'timestamp': datetime.now().isoformat(),
                'model_type': 'clustering_analysis'
            }

        except Exception as e:
            return {
                'error': str(e),
                'timestamp': datetime.now().isoformat()
            }

    def predict_all(self, data: Dict) -> Dict:
        """
        Make predictions using all three models.

        Args:
            data: Dictionary containing student data

        Returns:
            Dictionary containing all prediction results
        """
        return {
            'conflicts_prediction': self.predict_conflicts(data),
            'addicted_score_prediction': self.predict_addicted_score(data),
            'clustering_prediction': self.predict_cluster(data),
            'timestamp': datetime.now().isoformat(),
            'student_data': data
        }

    def get_model_status(self) -> Dict:
        """
        Get status of all models.

        Returns:
            Dictionary containing model status information
        """
        return {
            'conflicts_model_loaded': self.conflicts_model is not None,
            'addicted_model_loaded': self.addicted_model is not None,
            'clustering_model_loaded': self.clustering_model is not None,
            'conflicts_scaler_loaded': self.conflicts_scaler is not None,
            'addicted_scaler_loaded': self.addicted_scaler is not None,
            'clustering_scaler_loaded': self.clustering_scaler is not None,
            'cluster_labels_loaded': self.cluster_labels is not None,
            'feature_names_loaded': len(self.feature_names) > 0,
            'timestamp': datetime.now().isoformat()
        }
588
+
589
+
590
def create_unified_prediction_service() -> UnifiedSocialMediaPredictionService:
    """
    Factory function to create a unified prediction service.

    Returns:
        Initialized unified prediction service
    """
    service = UnifiedSocialMediaPredictionService()
    return service
598
+
599
+
600
# Example usage and testing functions
def test_unified_prediction_service():
    """Exercise the unified prediction service end-to-end with sample data."""
    try:
        service = create_unified_prediction_service()

        # Report which artifacts actually loaded.
        print("📊 Model Status:")
        print(json.dumps(service.get_model_status(), indent=2))

        sample_data = {
            'Age': 20,
            'Gender': 'Female',
            'Academic_Level': 'Undergraduate',
            'Avg_Daily_Usage_Hours': 6.5,
            'Sleep_Hours_Per_Night': 7.0,
            'Mental_Health_Score': 7,
            'Conflicts_Over_Social_Media': 2,
            'Addicted_Score': 6,
            'Relationship_Status': 'Single',
            'Affects_Academic_Performance': 'Yes',
            'Most_Used_Platform': 'Instagram',
        }

        results = service.predict_all(sample_data)

        print("\n📊 Unified Prediction Results:")
        print(json.dumps(results, indent=2))

        return results

    except Exception as e:
        print(f"❌ Test failed: {e}")
        return None
638
+
639
+
640
+ if __name__ == "__main__":
641
+ test_unified_prediction_service()