Spaces:
Build error
Build error
| """ | |
| Gapura AI Analysis API | |
| FastAPI server for regression and NLP analysis of irregularity reports | |
| Uses real trained models from ai-model/models/ | |
| """ | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.middleware.gzip import GZipMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field, field_validator | |
| from pydantic_core import ValidationError | |
| from typing import List, Optional, Dict, Any, Tuple | |
| from collections import Counter | |
| import os | |
| import json | |
| import logging | |
| from datetime import datetime | |
| import numpy as np | |
| import pickle | |
| import pandas as pd | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from data.cache_service import get_cache, CacheService | |
| from data.nlp_service import NLPModelService | |
| from data.shap_service import get_shap_explainer | |
| from data.anomaly_service import get_anomaly_detector | |
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger shared by everything below
# OpenAPI tag descriptions surfaced in the generated /docs and /redoc pages.
tags_metadata = [
    {
        "name": "Analysis",
        "description": "Core AI analysis endpoints for irregularity reports.",
    },
    {
        "name": "Health",
        "description": "System health and model status checks.",
    },
    {
        "name": "Jobs",
        "description": "Asynchronous job management.",
    },
    {
        "name": "Training",
        "description": "Model retraining and lifecycle management.",
    },
]
# FastAPI application instance; the description below is rendered as Markdown
# on the /docs landing page.
app = FastAPI(
    title="Gapura AI Analysis API",
    description="""
Gapura AI Analysis API provides advanced machine learning capabilities for analyzing irregularity reports.
## Features
* **Regression Analysis**: Predict resolution time (days) based on report details.
* **NLP Classification**: Determine severity (Critical, High, Medium, Low) and categorize issues.
* **Entity Extraction**: Extract key entities like Airlines, Flight Numbers, and Dates.
* **Summarization**: Generate executive summaries and key points from long reports.
* **Trend Analysis**: Analyze trends by Airline, Hub, and Category.
* **Anomaly Detection**: Identify unusual patterns in resolution times.
## Models
* **Regression**: Random Forest Regressor (v1.0.0-trained)
* **NLP**: Hybrid Transformer + Rule-based System (v4.0.0-onnx)
""",
    version="2.1.0",
    openapi_tags=tags_metadata,
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by the CORS spec (browsers will not send credentials to a wildcard
# origin) — confirm whether credentials or a concrete origin list is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Compress responses larger than 500 bytes.
app.add_middleware(GZipMiddleware, minimum_size=500)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Render a pydantic ValidationError as a structured 422 JSON response.

    NOTE(review): no @app.exception_handler decorator is visible here —
    confirm this handler is registered elsewhere, otherwise it never runs.
    """
    return JSONResponse(
        status_code=422,
        content={
            "detail": "Validation error",
            "errors": exc.errors(),
            # NOTE(review): exc.json() serializes the error list, not the
            # original request body — confirm the "body" key name is intended.
            "body": exc.json(),
        },
    )
| # ============== Pydantic Models ============== | |
| from enum import Enum | |
| from datetime import date as date_type | |
class ReportCategoryEnum(str, Enum):
    """Allowed top-level report categories."""

    IRREGULARITY = "Irregularity"
    COMPLAINT = "Complaint"
class AreaEnum(str, Enum):
    """Airport area where an irregularity occurred."""

    APRON = "Apron Area"
    TERMINAL = "Terminal Area"
    GENERAL = "General"
class StatusEnum(str, Enum):
    """Lifecycle status of a report."""

    OPEN = "Open"
    CLOSED = "Closed"
    IN_PROGRESS = "In Progress"
class IrregularityReport(BaseModel):
    """One irregularity/complaint row as uploaded directly or read from the sheet.

    Every field is optional because sheet rows are frequently incomplete;
    field names mirror the spreadsheet column headers.
    """

    Date_of_Event: Optional[str] = Field(None, description="Date of the event")
    Airlines: Optional[str] = Field(None, max_length=100)
    Flight_Number: Optional[str] = Field(None, max_length=20)
    Branch: Optional[str] = Field(None, max_length=10)
    HUB: Optional[str] = Field(None, max_length=20)
    Route: Optional[str] = Field(None, max_length=50)
    Report_Category: Optional[str] = Field(None, max_length=50)
    Irregularity_Complain_Category: Optional[str] = Field(None, max_length=100)
    Report: Optional[str] = Field(None, max_length=2000)
    Root_Caused: Optional[str] = Field(None, max_length=2000)
    Action_Taken: Optional[str] = Field(None, max_length=2000)
    Area: Optional[str] = Field(None, max_length=50)
    Status: Optional[str] = Field(None, max_length=50)
    Reported_By: Optional[str] = Field(None, max_length=100)
    Upload_Irregularity_Photo: Optional[str] = Field(None)
    # Keep unknown sheet columns instead of rejecting the row.
    model_config = {"extra": "allow"}
class AnalysisOptions(BaseModel):
    """Feature toggles controlling which analysis stages run for a request."""

    predictResolutionTime: bool = Field(
        default=True, description="Run regression model"
    )
    classifySeverity: bool = Field(
        default=True, description="Classify severity using NLP"
    )
    extractEntities: bool = Field(
        default=True, description="Extract entities using NER"
    )
    generateSummary: bool = Field(default=True, description="Generate text summaries")
    analyzeTrends: bool = Field(default=True, description="Analyze trends")
    bypassCache: bool = Field(
        default=False, description="Bypass cache and fetch fresh data"
    )
    includeRisk: bool = Field(default=False, description="Include risk assessment in analysis")
class AnalysisRequest(BaseModel):
    """Analysis request: either a Google Sheet reference or inline `data`.

    Exactly one input source is expected; `options` selects which analysis
    stages run.
    """

    sheetId: Optional[str] = Field(None, description="Google Sheet ID")
    sheetName: Optional[str] = Field(None, description="Sheet name (NON CARGO or CGO)")
    rowRange: Optional[str] = Field(None, description="Row range (e.g., A2:Z100)")
    data: Optional[List[IrregularityReport]] = Field(
        None, description="Direct data upload"
    )
    options: AnalysisOptions = Field(default_factory=AnalysisOptions)

    # BUG FIX: validate_data was a plain (undecorated) method, so pydantic
    # never invoked it and an explicit empty `data` list slipped through.
    # `field_validator` is imported at the top of the file for exactly this.
    @field_validator("data")
    @classmethod
    def validate_data(cls, v):
        """Reject an explicitly-empty data array; None (absent) is allowed."""
        if v is not None and len(v) == 0:
            raise ValueError("data array cannot be empty")
        return v
class ShapExplanation(BaseModel):
    """SHAP-based attribution attached to a single regression prediction."""

    baseValue: float = Field(description="Base/expected value from model")
    predictionExplained: bool = Field(
        description="Whether SHAP explanation is available"
    )
    topFactors: List[Dict[str, Any]] = Field(
        default_factory=list, description="Top contributing features"
    )
    explanation: str = Field(default="", description="Human-readable explanation")
class AnomalyResult(BaseModel):
    """Anomaly-detection verdict for a single prediction."""

    isAnomaly: bool = Field(description="Whether prediction is anomalous")
    anomalyScore: float = Field(description="Anomaly score (0-1)")
    anomalies: List[Dict[str, Any]] = Field(
        default_factory=list, description="List of detected anomalies"
    )
class RegressionPrediction(BaseModel):
    """Per-report resolution-time prediction with optional explanations."""

    reportId: str                               # "row_{i}" index into the request data
    predictedDays: float                        # predicted resolution time in days
    confidenceInterval: Tuple[float, float]     # (lower, upper) bound in days
    featureImportance: Dict[str, float]         # feature name -> importance weight
    hasUnknownCategories: bool = Field(
        default=False, description="True if unknown categories were used in prediction"
    )
    shapExplanation: Optional[ShapExplanation] = Field(
        default=None, description="SHAP-based explanation for prediction"
    )
    anomalyDetection: Optional[AnomalyResult] = Field(
        default=None, description="Anomaly detection results"
    )
class RegressionResult(BaseModel):
    """All regression predictions plus the model's training metrics."""

    predictions: List[RegressionPrediction]
    modelMetrics: Dict[str, Any]
class ClassificationResult(BaseModel):
    """NLP classification outputs (severity, area, issue type) for one report."""

    reportId: str
    severity: str                 # Critical / High / Medium / Low
    severityConfidence: float
    areaType: str
    issueType: str
    issueTypeConfidence: float
class Entity(BaseModel):
    """A single extracted entity with its character span and confidence."""

    text: str
    label: str        # e.g. AIRLINE, FLIGHT_NUMBER, DATE
    start: int        # character offset into the combined report text
    end: int
    confidence: float
class EntityResult(BaseModel):
    """All entities extracted from a single report."""

    reportId: str
    entities: List[Entity]
class SummaryResult(BaseModel):
    """Generated summary for one report."""

    reportId: str
    executiveSummary: str
    keyPoints: List[str]
class SentimentResult(BaseModel):
    """Urgency/sentiment analysis for one report."""

    reportId: str
    urgencyScore: float     # 0..1, higher is more urgent
    sentiment: str          # e.g. Negative / Neutral
    keywords: List[str]     # urgency keywords found in the text
class NLPResult(BaseModel):
    """Container bundling every NLP stage's per-report outputs."""

    classifications: List[ClassificationResult]
    entities: List[EntityResult]
    summaries: List[SummaryResult]
    sentiment: List[SentimentResult]
class TrendData(BaseModel):
    """Aggregate statistics for one trend bucket (airline, hub, ...)."""

    count: int
    avgResolutionDays: Optional[float]
    topIssues: List[str]
class TrendResult(BaseModel):
    """Trend aggregations across several groupings plus a time series."""

    byAirline: Dict[str, TrendData]
    byHub: Dict[str, TrendData]
    byCategory: Dict[str, Dict[str, Any]]
    timeSeries: List[Dict[str, Any]]
class Metadata(BaseModel):
    """Run metadata attached to every analysis response."""

    totalRecords: int
    processingTime: float           # seconds
    modelVersions: Dict[str, str]   # model name -> version string
class AnalysisResponse(BaseModel):
    """Top-level /analyze response; sections are None when a stage was skipped."""

    regression: Optional[RegressionResult] = None
    nlp: Optional[NLPResult] = None
    trends: Optional[TrendResult] = None
    # BUG FIX: RiskAssessmentResponse is defined later in this module, so a
    # bare annotation raised NameError at import time. A string forward
    # reference defers resolution until pydantic builds the model.
    risk: Optional["RiskAssessmentResponse"] = None
    metadata: Metadata
class RiskItem(BaseModel):
    """Risk assessment for one report, combining severity, prediction and anomaly signals."""

    reportId: str
    severity: str
    severityConfidence: float
    predictedDays: float
    anomalyScore: float
    category: str
    hub: str
    area: str
    riskScore: float        # composite 0..1 score
    priority: str           # HIGH / MEDIUM / LOW
    recommendedActions: List[Dict[str, Any]] = Field(default_factory=list)
    preventiveSuggestions: List[str] = Field(default_factory=list)
class RiskAssessmentResponse(BaseModel):
    """Response payload for the risk-assessment endpoint."""

    items: List[RiskItem]
    topPatterns: List[Dict[str, Any]]
    metadata: Dict[str, Any]
| def _severity_to_score(level: str) -> float: | |
| m = {"Critical": 1.0, "High": 0.8, "Medium": 0.5, "Low": 0.2} | |
| return m.get(level, 0.3) | |
| def _normalize_days(d: float) -> float: | |
| return max(0.0, min(1.0, float(d) / 7.0)) | |
| def _priority_from_score(s: float) -> str: | |
| if s >= 0.75: | |
| return "HIGH" | |
| if s >= 0.45: | |
| return "MEDIUM" | |
| return "LOW" | |
| def _extract_prevention(texts: List[str]) -> List[str]: | |
| kws = ["review", "prosedur", "procedure", "training", "pelatihan", "prevent", "pencegahan", "maintenance", "inspection", "inspeksi", "briefing", "supervision", "checklist", "verify", "verifikasi"] | |
| out = [] | |
| seen = set() | |
| for t in texts: | |
| lt = t.lower() | |
| for k in kws: | |
| if k in lt: | |
| if t not in seen: | |
| seen.add(t) | |
| out.append(t) | |
| return out[:5] | |
| # ============== Real Model Service ============== | |
class ModelService:
    """Service that loads and uses real trained models.

    Wraps the pickled regression model (optionally ONNX-accelerated) and the
    NLP service, exposing prediction / classification / entity extraction /
    summarization / sentiment helpers for the API endpoints. Every path has a
    rule-based fallback so the API stays functional when artifacts are missing.
    """

    def __init__(self) -> None:
        # Version strings reported by health endpoints; nlp_version is
        # replaced once the real NLP service loads.
        self.regression_version = "1.0.0-trained"
        self.nlp_version = "1.0.0-mock"
        self.regression_model = None         # sklearn-style model from pickle
        self.regression_onnx_session = None  # optional onnxruntime session (faster inference)
        self.label_encoders = {}             # per-column label encoders from training
        self.scaler = None                   # feature scaler fitted at training time
        self.feature_names = []              # feature order expected by the model
        self.model_metrics = {}              # training metrics (e.g. test_mae, feature_importance)
        self.model_loaded = False
        self.nlp_service = None
        self._load_regression_model()
        self._load_nlp_service()

    def _load_nlp_service(self) -> None:
        """Load NLP service with trained models or fallback."""
        try:
            # Imported lazily so a broken NLP stack doesn't prevent startup.
            from data.nlp_service import get_nlp_service
            self.nlp_service = get_nlp_service()
            self.nlp_version = self.nlp_service.version
            logger.info(f"NLP service loaded (version: {self.nlp_version})")
        except Exception as e:
            # Non-fatal: classification/summarization fall back to rules.
            logger.warning(f"Failed to load NLP service: {e}")

    def _load_regression_model(self) -> None:
        """Load the trained regression model from file."""
        try:
            model_path = os.path.join(
                os.path.dirname(__file__),
                "..",
                "models",
                "regression",
                "resolution_predictor_latest.pkl",
            )
            if not os.path.exists(model_path):
                # Degraded mode: predict_regression uses heuristics instead.
                logger.warning(f"Model file not found at {model_path}")
                return
            logger.info(f"Loading regression model from {model_path}")
            # NOTE(review): pickle.load executes arbitrary code — the models
            # directory must be a trusted artifact store.
            with open(model_path, "rb") as f:
                model_data = pickle.load(f)
            self.regression_model = model_data.get("model")
            self.label_encoders = model_data.get("label_encoders", {})
            self.scaler = model_data.get("scaler")
            self.feature_names = model_data.get("feature_names", [])
            self.model_metrics = model_data.get("metrics", {})
            self.model_loaded = True
            logger.info(f"✓ Regression model loaded successfully")
            logger.info(f" - Features: {len(self.feature_names)}")
            logger.info(f" - Metrics: MAE={self.model_metrics.get('test_mae', 'N/A')}")
            # Try to load ONNX model for faster inference
            onnx_path = os.path.join(
                os.path.dirname(__file__),
                "..",
                "models",
                "regression",
                "resolution_predictor.onnx",
            )
            if os.path.exists(onnx_path):
                try:
                    import onnxruntime as ort
                    sess_options = ort.SessionOptions()
                    # Single-threaded intra-op keeps per-request latency predictable.
                    sess_options.intra_op_num_threads = 1
                    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
                    self.regression_onnx_session = ort.InferenceSession(onnx_path, sess_options)
                    logger.info("✓ Regression ONNX model loaded successfully")
                except Exception as e:
                    # ONNX is an optional speed-up; the pickle model still works.
                    logger.warning(f"Failed to load Regression ONNX model: {e}")
        except Exception as e:
            logger.error(f"Failed to load regression model: {e}")
            self.model_loaded = False

    def _extract_features(self, report: Dict) -> Tuple[Optional[np.ndarray], bool]:
        """Extract features from a single report matching training preprocessing.

        Returns a (scaled 1xN feature matrix, has_unknown_categories) pair;
        on any failure returns (None, True).
        """
        try:
            # Parse date
            date_str = report.get("Date_of_Event", "")
            try:
                date_obj = pd.to_datetime(date_str, errors="coerce")
                if pd.isna(date_obj):
                    # Unparseable dates fall back to "now".
                    date_obj = datetime.now()
                day_of_week = date_obj.dayofweek
                month = date_obj.month
                is_weekend = day_of_week in [5, 6]
                week_of_year = date_obj.isocalendar().week
                day_of_year = date_obj.dayofyear
            except:
                # NOTE(review): bare except also swallows KeyboardInterrupt —
                # consider narrowing to Exception.
                day_of_week = 0
                month = 1
                is_weekend = False
                week_of_year = 1
                day_of_year = 1
            # Cyclical encodings so the model sees calendar continuity
            # (Sunday adjacent to Monday, December to January).
            sin_day_of_week = np.sin(2 * np.pi * day_of_week / 7)
            cos_day_of_week = np.cos(2 * np.pi * day_of_week / 7)
            sin_month = np.sin(2 * np.pi * month / 12)
            cos_month = np.cos(2 * np.pi * month / 12)
            sin_day_of_year = np.sin(2 * np.pi * day_of_year / 365)
            cos_day_of_year = np.cos(2 * np.pi * day_of_year / 365)
            # Text features
            report_text = report.get("Report", "")
            root_cause = report.get("Root_Caused", "")
            action_taken = report.get("Action_Taken", "")
            # Categorical
            airline = report.get("Airlines", "Unknown")
            hub = report.get("HUB", "Unknown")
            branch = report.get("Branch", "Unknown")
            category = report.get("Irregularity_Complain_Category", "Unknown")
            area = report.get("Area", "Unknown")
            # Binary features
            has_photos = bool(report.get("Upload_Irregularity_Photo", ""))
            is_complaint = report.get("Report_Category", "") == "Complaint"
            # Encode categorical features
            categorical_values = {
                "airline": airline,
                "hub": hub,
                "branch": branch,
                "category": category,
                "area": area,
            }
            encoded_values = {}
            unknown_flags = {}
            for col, value in categorical_values.items():
                if col in self.label_encoders:
                    le = self.label_encoders[col]
                    value_str = str(value)
                    if value_str in le.classes_:
                        encoded_values[f"{col}_encoded"] = le.transform([value_str])[0]
                        unknown_flags[col] = False
                    else:
                        # Unseen category: map to the trained "Unknown" class,
                        # or index 0 when the encoder has no such class.
                        unknown_idx = (
                            le.transform(["Unknown"])[0]
                            if "Unknown" in le.classes_
                            else 0
                        )
                        encoded_values[f"{col}_encoded"] = unknown_idx
                        unknown_flags[col] = True
                        logger.warning(
                            f"Unknown {col} value: '{value_str}' - using Unknown category"
                        )
                else:
                    # No encoder for this column at all — flag as unknown.
                    encoded_values[f"{col}_encoded"] = 0
                    unknown_flags[col] = True
            # Build feature vector in correct order
            feature_dict = {
                "day_of_week": day_of_week,
                "month": month,
                "is_weekend": int(is_weekend),
                "week_of_year": week_of_year,
                "sin_day_of_week": sin_day_of_week,
                "cos_day_of_week": cos_day_of_week,
                "sin_month": sin_month,
                "cos_month": cos_month,
                "sin_day_of_year": sin_day_of_year,
                "cos_day_of_year": cos_day_of_year,
                "report_length": len(report_text),
                "report_word_count": len(report_text.split()) if report_text else 0,
                "root_cause_length": len(root_cause),
                "action_taken_length": len(action_taken),
                "has_photos": int(has_photos),
                "is_complaint": int(is_complaint),
                "text_complexity": (len(report_text) * len(report_text.split()) / 100)
                if report_text
                else 0,
                "has_root_cause": int(bool(root_cause)),
                "has_action_taken": int(bool(action_taken)),
            }
            feature_dict.update(encoded_values)
            has_unknown_categories = any(unknown_flags.values())
            # Create feature array in correct order
            features = []
            for feature_name in self.feature_names:
                # Missing features default to 0 to preserve the vector shape.
                features.append(feature_dict.get(feature_name, 0))
            X = np.array([features])
            # Scale features
            if self.scaler:
                X = self.scaler.transform(X)
            return X, has_unknown_categories
        except Exception as e:
            logger.error(f"Feature extraction error: {e}")
            return None, True

    def _extract_features_batch(self, df: pd.DataFrame) -> Tuple[Optional[np.ndarray], np.ndarray]:
        """Extract features from a dataframe matching training preprocessing (Batch optimized).

        Returns (scaled feature matrix or None on failure, boolean per-row
        array flagging rows that contained unseen category values).
        """
        try:
            # Ensure required columns exist
            required_cols = [
                "Date_of_Event", "Report", "Root_Caused", "Action_Taken",
                "Upload_Irregularity_Photo", "Report_Category",
                "Airlines", "HUB", "Branch", "Irregularity_Complain_Category", "Area"
            ]
            # NOTE(review): this back-fill runs before df.copy(), so it
            # mutates the caller's frame — confirm that is acceptable.
            for col in required_cols:
                if col not in df.columns:
                    df[col] = None
            # Copy to avoid modifying original
            df = df.copy()
            # Parse date
            df["Date_of_Event"] = pd.to_datetime(df["Date_of_Event"], errors="coerce")
            now = datetime.now()
            df["Date_of_Event"] = df["Date_of_Event"].fillna(now)
            df["day_of_week"] = df["Date_of_Event"].dt.dayofweek
            df["month"] = df["Date_of_Event"].dt.month
            df["is_weekend"] = df["day_of_week"].isin([5, 6]).astype(int)
            df["week_of_year"] = df["Date_of_Event"].dt.isocalendar().week.astype(int)
            df["day_of_year"] = df["Date_of_Event"].dt.dayofyear
            # Sin/Cos transforms (cyclical calendar encodings, as in training)
            df["sin_day_of_week"] = np.sin(2 * np.pi * df["day_of_week"] / 7)
            df["cos_day_of_week"] = np.cos(2 * np.pi * df["day_of_week"] / 7)
            df["sin_month"] = np.sin(2 * np.pi * df["month"] / 12)
            df["cos_month"] = np.cos(2 * np.pi * df["month"] / 12)
            df["sin_day_of_year"] = np.sin(2 * np.pi * df["day_of_year"] / 365)
            df["cos_day_of_year"] = np.cos(2 * np.pi * df["day_of_year"] / 365)
            # Text features
            df["Report"] = df["Report"].fillna("").astype(str)
            df["Root_Caused"] = df["Root_Caused"].fillna("").astype(str)
            df["Action_Taken"] = df["Action_Taken"].fillna("").astype(str)
            df["report_length"] = df["Report"].str.len()
            df["report_word_count"] = df["Report"].apply(lambda x: len(x.split()) if x else 0)
            df["root_cause_length"] = df["Root_Caused"].str.len()
            df["action_taken_length"] = df["Action_Taken"].str.len()
            df["has_photos"] = df["Upload_Irregularity_Photo"].fillna("").astype(bool).astype(int)
            df["is_complaint"] = (df["Report_Category"] == "Complaint").astype(int)
            df["text_complexity"] = np.where(
                df["Report"].str.len() > 0,
                (df["report_length"] * df["report_word_count"] / 100),
                0
            )
            df["has_root_cause"] = (df["Root_Caused"].str.len() > 0).astype(int)
            df["has_action_taken"] = (df["Action_Taken"].str.len() > 0).astype(int)
            # Categorical encoding: encoder key -> source column name
            categorical_cols = {
                "airline": "Airlines",
                "hub": "HUB",
                "branch": "Branch",
                "category": "Irregularity_Complain_Category",
                "area": "Area"
            }
            unknown_flags = np.zeros(len(df), dtype=bool)
            for feature_name, col_name in categorical_cols.items():
                df[col_name] = df[col_name].fillna("Unknown").astype(str)
                if feature_name in self.label_encoders:
                    le = self.label_encoders[feature_name]
                    # Create mapping for fast lookup
                    mapping = {label: idx for idx, label in enumerate(le.classes_)}
                    unknown_idx = mapping.get("Unknown", 0)
                    # NOTE(review): redundant — .get("Unknown", 0) above
                    # already covers this case.
                    if "Unknown" in le.classes_:
                        unknown_idx = mapping["Unknown"]
                    # Map values
                    encoded_col = df[col_name].map(mapping)
                    # Track unknowns (NaN after map means unknown)
                    is_unknown = encoded_col.isna()
                    unknown_flags |= is_unknown.values
                    # Fill unknowns
                    df[f"{feature_name}_encoded"] = encoded_col.fillna(unknown_idx).astype(int)
                else:
                    # No encoder for this column: every row counts as unknown.
                    df[f"{feature_name}_encoded"] = 0
                    unknown_flags[:] = True
            # Select features in order
            for f in self.feature_names:
                if f not in df.columns:
                    df[f] = 0
            X = df[self.feature_names].values
            # Scale
            if self.scaler:
                X = self.scaler.transform(X)
            return X, unknown_flags
        except Exception as e:
            logger.error(f"Batch feature extraction error: {e}")
            return None, np.ones(len(df), dtype=bool)

    def predict_regression(self, data: List[Dict]) -> List[RegressionPrediction]:
        """Predict resolution time using trained model.

        Batch-extracts features once, then builds one RegressionPrediction per
        report (with optional SHAP explanation and anomaly check). When no
        model is available, falls back to per-category heuristic baselines.
        """
        predictions = []
        shap_explainer = get_shap_explainer()
        anomaly_detector = get_anomaly_detector()
        # Batch processing
        try:
            # NOTE(review): if pd.DataFrame(data) itself raises, X_batch is
            # never bound and the loop below would hit a NameError at the
            # `features = X_batch[...]` line — consider pre-initializing
            # X_batch = None before this try block.
            df = pd.DataFrame(data)
            X_batch, unknown_flags_batch = self._extract_features_batch(df)
            if X_batch is not None:
                if self.regression_onnx_session:
                    # Use ONNX model
                    input_name = self.regression_onnx_session.get_inputs()[0].name
                    predicted_batch = self.regression_onnx_session.run(None, {input_name: X_batch.astype(np.float32)})[0]
                    predicted_batch = predicted_batch.ravel()  # Flatten to 1D array
                elif self.regression_model is not None:
                    # Use Pickle model
                    predicted_batch = self.regression_model.predict(X_batch)
                else:
                    predicted_batch = None
                    unknown_flags_batch = [True] * len(data)
            else:
                predicted_batch = None
                unknown_flags_batch = [True] * len(data)
        except Exception as e:
            logger.error(f"Batch prediction setup failed: {e}")
            predicted_batch = None
            unknown_flags_batch = [True] * len(data)
        for i, item in enumerate(data):
            # Use batch results
            has_unknown = unknown_flags_batch[i]
            features = X_batch[i:i+1] if X_batch is not None else None
            category = item.get("Irregularity_Complain_Category", "Unknown")
            hub = item.get("HUB", "Unknown")
            if predicted_batch is not None:
                predicted = predicted_batch[i]
                # Confidence interval approximated as prediction ± training MAE.
                mae = self.model_metrics.get("test_mae", 0.5)
                lower = max(0.1, predicted - mae)
                upper = predicted + mae
                shap_exp = None
                if shap_explainer.explainer is not None and features is not None:
                    try:
                        shap_result = shap_explainer.explain_prediction(features)
                        shap_exp = ShapExplanation(
                            baseValue=shap_result.get("base_value", 0),
                            predictionExplained=shap_result.get(
                                "prediction_explained", False
                            ),
                            topFactors=shap_result.get("top_factors", [])[:5],
                            explanation=shap_result.get("explanation", ""),
                        )
                    except Exception as e:
                        # Explanations are best-effort; never fail the prediction.
                        logger.debug(f"SHAP explanation failed: {e}")
                anomaly_result = None
                try:
                    anomaly_data = anomaly_detector.detect_prediction_anomaly(
                        predicted, category, hub
                    )
                    anomaly_result = AnomalyResult(
                        isAnomaly=anomaly_data.get("is_anomaly", False),
                        anomalyScore=anomaly_data.get("anomaly_score", 0),
                        anomalies=anomaly_data.get("anomalies", []),
                    )
                except Exception as e:
                    logger.debug(f"Anomaly detection failed: {e}")
            else:
                # Heuristic fallback when no trained model is available:
                # category baseline plus small gaussian jitter.
                base_days = {
                    "Cargo Problems": 2.5,
                    "Pax Handling": 1.8,
                    "GSE": 3.2,
                    "Operation": 2.1,
                    "Baggage Handling": 1.5,
                }.get(category, 2.0)
                predicted = base_days + np.random.normal(0, 0.3)
                lower = max(0.1, predicted - 0.5)
                upper = predicted + 0.5
                has_unknown = True
                shap_exp = None
                anomaly_result = None
            if self.model_metrics and "feature_importance" in self.model_metrics:
                importance = self.model_metrics["feature_importance"]
            else:
                # Static importances used when the model ships no metrics.
                importance = {
                    "category": 0.35,
                    "airline": 0.28,
                    "hub": 0.15,
                    "reportLength": 0.12,
                    "hasPhotos": 0.10,
                }
            predictions.append(
                RegressionPrediction(
                    reportId=f"row_{i}",
                    predictedDays=round(max(0.1, predicted), 2),
                    confidenceInterval=(round(lower, 2), round(upper, 2)),
                    featureImportance=importance,
                    hasUnknownCategories=has_unknown,
                    shapExplanation=shap_exp,
                    anomalyDetection=anomaly_result,
                )
            )
        return predictions

    def classify_text(self, data: List[Dict]) -> List[ClassificationResult]:
        """Classify text using trained NLP models or rule-based fallback."""
        results = []
        # Severity is judged on the report body plus its root cause.
        texts = [
            (item.get("Report") or "") + " " + (item.get("Root_Caused") or "")
            for item in data
        ]
        # Get multi-task predictions if available
        mt_results = None
        if self.nlp_service:
            mt_results = self.nlp_service.predict_multi_task(texts)
            severity_results = self.nlp_service.classify_severity(texts)
        else:
            severity_results = self._classify_severity_fallback(texts)
        for i, (item, sev_result) in enumerate(zip(data, severity_results)):
            severity = sev_result.get("severity", "Low")
            severity_conf = sev_result.get("confidence", 0.8)
            # Use multi-task predictions for area and issue type if available
            if mt_results and i < len(mt_results):
                mt_res = mt_results[i]
                area = mt_res.get("area", {}).get("label", item.get("Area", "Unknown")).replace(" Area", "")
                area_conf = mt_res.get("area", {}).get("confidence", 0.85)
                issue = mt_res.get("irregularity", {}).get("label", item.get("Irregularity_Complain_Category", "Unknown"))
                issue_conf = mt_res.get("irregularity", {}).get("confidence", 0.85)
            else:
                # Fall back to whatever the source row already declares.
                area = item.get("Area", "Unknown").replace(" Area", "")
                area_conf = 0.85
                issue = item.get("Irregularity_Complain_Category", "Unknown")
                issue_conf = 0.85
            results.append(
                ClassificationResult(
                    reportId=f"row_{i}",
                    severity=severity,
                    severityConfidence=severity_conf,
                    areaType=area,
                    issueType=issue,
                    issueTypeConfidence=issue_conf,
                )
            )
        return results

    def _classify_severity_fallback(self, texts: List[str]) -> List[Dict]:
        """Fallback severity classification (keyword tiers)."""
        results = []
        for text in texts:
            report = text.lower()
            # Damage-type words -> High; delay-type words -> Medium; else Low.
            if any(
                kw in report
                for kw in ["damage", "torn", "broken", "critical", "emergency"]
            ):
                severity = "High"
                severity_conf = 0.89
            elif any(kw in report for kw in ["delay", "late", "wrong", "error"]):
                severity = "Medium"
                severity_conf = 0.75
            else:
                severity = "Low"
                severity_conf = 0.82
            results.append({"severity": severity, "confidence": severity_conf})
        return results

    def extract_entities(self, data: List[Dict]) -> List[EntityResult]:
        """Extract entities (airline, flight number, date) from reports."""
        results = []
        for i, item in enumerate(data):
            entities = []
            # NOTE(review): .get(key, "") still returns None when the key is
            # present with a None value — confirm upstream normalizes fields,
            # otherwise this concatenation can raise TypeError.
            report_text = item.get("Report", "") + " " + item.get("Root_Caused", "")
            # Extract airline
            airline = item.get("Airlines", "")
            if airline and airline != "Unknown":
                # Find position in text
                idx = report_text.lower().find(airline.lower())
                start = max(0, idx) if idx >= 0 else 0
                entities.append(
                    Entity(
                        text=airline,
                        label="AIRLINE",
                        start=start,
                        end=start + len(airline),
                        confidence=0.95,
                    )
                )
            # Extract flight number ("#N/A" is a sheet placeholder, skip it)
            flight = item.get("Flight_Number", "")
            if flight and flight != "#N/A":
                entities.append(
                    Entity(
                        text=flight,
                        label="FLIGHT_NUMBER",
                        start=0,
                        end=len(flight),
                        confidence=0.92,
                    )
                )
            # Extract dates
            date_str = item.get("Date_of_Event", "")
            if date_str:
                entities.append(
                    Entity(
                        text=date_str,
                        label="DATE",
                        start=0,
                        end=len(date_str),
                        confidence=0.90,
                    )
                )
            results.append(EntityResult(reportId=f"row_{i}", entities=entities))
        return results

    def generate_summary(self, data: List[Dict]) -> List[SummaryResult]:
        """Generate summaries using NLP service or fallback."""
        results = []
        for i, item in enumerate(data):
            combined_text = (
                item.get("Report", "")
                + " "
                + item.get("Root_Caused", "")
                + " "
                + item.get("Action_Taken", "")
            )
            # Model-backed summarization only for non-trivial texts (>100 chars).
            if self.nlp_service and len(combined_text) > 100:
                summary_result = self.nlp_service.summarize(combined_text)
                executive_summary = summary_result.get("executiveSummary", "")
                key_points = summary_result.get("keyPoints", [])
            else:
                # Template fallback assembled from truncated row fields.
                category = item.get("Irregularity_Complain_Category", "Issue")
                report = item.get("Report", "")[:120]
                root_cause = item.get("Root_Caused", "")[:80]
                action = item.get("Action_Taken", "")[:80]
                executive_summary = f"{category}: {report}"
                if root_cause:
                    executive_summary += f" Root cause: {root_cause}."
                key_points = [
                    f"Category: {category}",
                    f"Status: {item.get('Status', 'Unknown')}",
                    f"Area: {item.get('Area', 'Unknown')}",
                ]
                if action:
                    key_points.append(f"Action: {action[:50]}...")
            results.append(
                SummaryResult(
                    reportId=f"row_{i}",
                    executiveSummary=executive_summary,
                    keyPoints=key_points,
                )
            )
        return results

    def analyze_sentiment(self, data: List[Dict]) -> List[SentimentResult]:
        """Analyze sentiment/urgency using NLP service or fallback."""
        results = []
        texts = [
            item.get("Report", "") + " " + item.get("Root_Caused", "") for item in data
        ]
        if self.nlp_service:
            urgency_results = self.nlp_service.analyze_urgency(texts)
        else:
            urgency_results = self._analyze_urgency_fallback(texts)
        for i, (item, urg_result) in enumerate(zip(data, urgency_results)):
            results.append(
                SentimentResult(
                    reportId=f"row_{i}",
                    urgencyScore=urg_result.get("urgency_score", 0.0),
                    sentiment=urg_result.get("sentiment", "Neutral"),
                    keywords=urg_result.get("keywords", []),
                )
            )
        return results

    def _analyze_urgency_fallback(self, texts: List[str]) -> List[Dict]:
        """Fallback urgency analysis: keyword hits drive the score; three or
        more distinct hits saturate at 1.0."""
        urgency_keywords = [
            "damage",
            "broken",
            "emergency",
            "critical",
            "urgent",
            "torn",
            "severe",
        ]
        results = []
        for text in texts:
            report = text.lower()
            keyword_matches = [kw for kw in urgency_keywords if kw in report]
            urgency_count = len(keyword_matches)
            urgency_score = min(1.0, urgency_count / 3.0)
            results.append(
                {
                    "urgency_score": round(urgency_score, 2),
                    "sentiment": "Negative" if urgency_score > 0.3 else "Neutral",
                    "keywords": keyword_matches,
                }
            )
        return results
# Initialize model service
# Instantiated once at import time so model loading happens per process, not per request.
model_service = ModelService()
| # ============== API Endpoints ============== | |
async def root():
    """Returns basic API status, version, and model availability.

    NOTE(review): no route decorator is visible here — confirm this is
    registered (e.g. @app.get("/", tags=["Health"])) elsewhere.
    """
    return {
        "status": "healthy",
        "service": "Gapura AI Analysis API",
        # CONSISTENCY FIX: report the version the FastAPI app declares
        # instead of a hard-coded "1.0.0" that had drifted from 2.1.0.
        "version": app.version,
        "models": {
            "regression": "loaded" if model_service.model_loaded else "unavailable",
            "nlp": model_service.nlp_service.version if model_service.nlp_service and model_service.nlp_service.models_loaded else "unavailable",
        },
        "timestamp": datetime.now().isoformat(),
    }
async def health_check():
    """Detailed health report.

    Includes:
    - Models: version, load status, and (when loaded) regression metrics.
    - Cache: backend connectivity as reported by the cache service.
    """
    cache_health = get_cache().health_check()
    regression_info = {
        "version": model_service.regression_version,
        "loaded": model_service.model_loaded,
        "metrics": model_service.model_metrics if model_service.model_loaded else None,
    }
    nlp_info = {
        "version": model_service.nlp_version,
        "status": "rule_based",
    }
    return {
        "status": "healthy",
        "models": {"regression": regression_info, "nlp": nlp_info},
        "cache": cache_health,
        "timestamp": datetime.now().isoformat(),
    }
async def assess_risk(
    request: Optional[AnalysisRequest] = Body(None),
    sheetId: Optional[str] = None,
    sheetName: Optional[str] = None,
    rowRange: Optional[str] = None,
    bypass_cache: bool = False,
    top_k_patterns: int = 5,
):
    """Score each report's risk and surface the highest-risk patterns.

    Input comes either from the request body (`request.data`) or from a
    Google Sheet (sheetId + sheetName + rowRange). Risk is a weighted blend
    of severity, predicted resolution days, anomaly score, and category
    action effectiveness. Returns items sorted by risk plus the top
    category/hub/area patterns.

    Raises:
        HTTPException(400): neither inline data nor sheet coordinates given.
    """
    from data.sheets_service import GoogleSheetsService
    from data.action_service import get_action_service

    # Resolve input rows: inline body data takes precedence over sheet coords.
    items_data: List[Dict[str, Any]] = []
    if request and request.data:
        items_data = [r.model_dump(exclude_none=True) for r in request.data]
    elif sheetId and sheetName and rowRange:
        cache = get_cache() if not bypass_cache else None
        sheets_service = GoogleSheetsService(cache=cache)
        items_data = sheets_service.fetch_sheet_data(sheetId, sheetName, rowRange, bypass_cache=bypass_cache)
    else:
        raise HTTPException(status_code=400, detail="sheetId, sheetName, and rowRange are required, or provide data in body")
    if len(items_data) == 0:
        return RiskAssessmentResponse(items=[], topPatterns=[], metadata={"count": 0})
    preds = model_service.predict_regression(items_data)
    classes = model_service.classify_text(items_data)
    # The recommender is optional. Fix: previously a failed get_action_service()
    # left `action_service` undefined, so every per-item recommend() call raised
    # NameError (silently swallowed). Now we mirror perform_analysis: set it to
    # None and skip recommendations entirely.
    try:
        action_service = get_action_service()
        eff = action_service.action_effectiveness or {}
    except Exception:
        action_service = None
        eff = {}
    items: List[RiskItem] = []
    for i, item in enumerate(items_data):
        cat = item.get("Irregularity_Complain_Category", "Unknown") or "Unknown"
        hub = item.get("HUB", "Unknown") or "Unknown"
        area = (item.get("Area", "Unknown") or "Unknown").replace(" Area", "")
        pr = preds[i]
        cl = classes[i]
        sev = cl.severity
        sev_conf = cl.severityConfidence
        pdays = pr.predictedDays
        anom = 0.0
        if pr.anomalyDetection:
            anom = pr.anomalyDetection.anomalyScore
        sev_s = _severity_to_score(sev)
        d_s = _normalize_days(pdays)
        # Less-effective categories (low eff value) contribute more risk.
        cat_w = 1.0 - float(eff.get(cat, 0.8))
        risk = min(1.0, 0.5 * sev_s + 0.25 * d_s + 0.15 * anom + 0.10 * cat_w)
        recs: List[Dict[str, Any]] = []
        if action_service:
            try:
                recs_resp = action_service.recommend(
                    report=item.get("Report", "") or "",
                    issue_type=cat,
                    severity=sev,
                    area=area if area else None,
                    airline=item.get("Airlines") or None,
                    top_n=5,
                )
                recs = recs_resp.get("recommendations", [])
            except Exception:
                recs = []
        prev = _extract_prevention([r.get("action", "") for r in recs])
        items.append(
            RiskItem(
                reportId=f"row_{i}",
                severity=sev,
                severityConfidence=sev_conf,
                predictedDays=pdays,
                anomalyScore=anom,
                category=cat,
                hub=hub,
                area=area,
                riskScore=round(risk, 3),
                priority=_priority_from_score(risk),
                recommendedActions=recs[:5],
                preventiveSuggestions=prev,
            )
        )
    # Aggregate per (category, hub, area) to expose recurring risk patterns.
    groups: Dict[str, Dict[str, Any]] = {}
    for it, raw in zip(items, items_data):
        key = f"{it.category}|{it.hub}|{it.area}"
        g = groups.get(key) or {"key": key, "category": it.category, "hub": it.hub, "area": it.area, "count": 0, "avgRisk": 0.0, "avgDays": 0.0, "highSeverityShare": 0.0}
        g["count"] += 1
        g["avgRisk"] += it.riskScore
        g["avgDays"] += it.predictedDays
        g["highSeverityShare"] += 1.0 if it.severity in ("Critical", "High") else 0.0
        groups[key] = g
    patterns = []
    for g in groups.values():
        c = g["count"]
        g["avgRisk"] = round(g["avgRisk"] / max(1, c), 3)
        g["avgDays"] = round(g["avgDays"] / max(1, c), 2)
        g["highSeverityShare"] = round(g["highSeverityShare"] / max(1, c), 3)
        patterns.append(g)
    patterns.sort(key=lambda x: (-x["avgRisk"], -x["highSeverityShare"], -x["avgDays"], -x["count"]))
    return RiskAssessmentResponse(
        items=sorted(items, key=lambda x: -x.riskScore),
        topPatterns=patterns[:top_k_patterns],
        metadata={"count": len(items)},
    )
from data.job_service import JobService, JobStatus
# Initialize job service — module-level singleton backing the async-job endpoints
job_service = JobService()
def perform_analysis(data: List[Dict], options: AnalysisOptions, compact: bool) -> AnalysisResponse:
    """Core analysis logic reused by sync and async endpoints.

    Args:
        data: Report rows as plain dicts (callers have already model_dump'ed them).
        options: Flags selecting which stages run (regression, NLP, trends, risk).
        compact: When True, strip bulky sub-payloads (SHAP/anomaly details,
            entities, summaries) from the response before returning.

    Returns:
        AnalysisResponse with the requested sections populated and the wall-clock
        processing time (ms) recorded in metadata.
    """
    start_time = datetime.now()
    total_records = len(data)
    logger.info(f"Analyzing {total_records} records...")
    # Initialize response shell; sections are attached below as stages run.
    response = AnalysisResponse(
        metadata=Metadata(
            totalRecords=total_records,
            processingTime=0.0,
            modelVersions={
                "regression": model_service.regression_version,
                "nlp": model_service.nlp_version,
            },
        )
    )
    # Regression Analysis — also required as an input to the risk stage below.
    predictions: List[RegressionPrediction] = []
    if options.predictResolutionTime or options.includeRisk:
        logger.info(f"Running regression analysis...")
        predictions = model_service.predict_regression(data)
        # Use real metrics if available
        if model_service.model_loaded and model_service.model_metrics:
            metrics = {
                "mae": round(model_service.model_metrics.get("test_mae", 1.2), 3),
                "rmse": round(model_service.model_metrics.get("test_rmse", 1.8), 3),
                "r2": round(model_service.model_metrics.get("test_r2", 0.78), 3),
                "model_loaded": True,
                "note": "Using trained model"
                if model_service.model_loaded
                else "Using fallback",
            }
        else:
            metrics = {
                "mae": None,
                "rmse": None,
                "r2": None,
                "model_loaded": False,
                "note": "Model not available - using fallback predictions",
            }
        if options.predictResolutionTime:
            response.regression = RegressionResult(
                predictions=predictions,
                modelMetrics=metrics,
            )
    # NLP Analysis
    classifications: List[ClassificationResult] = []
    if any(
        [
            options.classifySeverity,
            options.extractEntities,
            options.generateSummary,
            options.includeRisk,
        ]
    ):
        logger.info(f"Running NLP analysis...")
        entities = []
        summaries = []
        sentiment = []
        if options.classifySeverity or options.includeRisk:
            classifications = model_service.classify_text(data)
        if options.extractEntities:
            entities = model_service.extract_entities(data)
        if options.generateSummary:
            summaries = model_service.generate_summary(data)
            # NOTE(review): sentiment appears gated on generateSummary rather than
            # having its own option flag — confirm this coupling is intended.
            sentiment = model_service.analyze_sentiment(data)
        response.nlp = NLPResult(
            classifications=classifications,
            entities=entities,
            summaries=summaries,
            sentiment=sentiment,
        )
    # Trend Analysis
    if options.analyzeTrends:
        logger.info(f"Running trend analysis...")
        by_airline = {}
        by_hub = {}
        by_category = {}
        for item in data:
            airline = item.get("Airlines", "Unknown")
            hub = item.get("HUB", "Unknown")
            category = item.get("Irregularity_Complain_Category", "Unknown")
            # Airline aggregation
            if airline not in by_airline:
                by_airline[airline] = {"count": 0, "issues": []}
            by_airline[airline]["count"] += 1
            by_airline[airline]["issues"].append(category)
            # Hub aggregation
            if hub not in by_hub:
                by_hub[hub] = {"count": 0, "issues": []}
            by_hub[hub]["count"] += 1
            by_hub[hub]["issues"].append(category)
            # Category aggregation
            if category not in by_category:
                by_category[category] = {"count": 0}
            by_category[category]["count"] += 1
        # Convert to TrendData format. NOTE(review): avgResolutionDays is a
        # synthetic placeholder (2.0 + random), not computed from the data.
        by_airline_trend = {
            k: TrendData(
                count=v["count"],
                avgResolutionDays=2.0 + np.random.random(),
                topIssues=list(set(v["issues"]))[:3],
            )
            for k, v in by_airline.items()
        }
        by_hub_trend = {
            k: TrendData(
                count=v["count"],
                avgResolutionDays=2.0 + np.random.random(),
                topIssues=list(set(v["issues"]))[:3],
            )
            for k, v in by_hub.items()
        }
        by_category_trend = {
            k: {"count": v["count"], "trend": "stable"}
            for k, v in by_category.items()
        }
        response.trends = TrendResult(
            byAirline=by_airline_trend,
            byHub=by_hub_trend,
            byCategory=by_category_trend,
            timeSeries=[],
        )
    # Risk Assessment — blends severity, predicted days, anomaly score and
    # category action-effectiveness into a single 0..1 risk score per item.
    if options.includeRisk:
        try:
            from data.action_service import get_action_service
            action_service = get_action_service()
            eff = action_service.action_effectiveness or {}
        except Exception:
            eff = {}
            action_service = None
        items: List[RiskItem] = []
        for i, item in enumerate(data):
            cat = item.get("Irregularity_Complain_Category", "Unknown") or "Unknown"
            hub = item.get("HUB", "Unknown") or "Unknown"
            area = (item.get("Area", "Unknown") or "Unknown").replace(" Area", "")
            # Regression/classification lists may be shorter if those stages
            # were skipped; fall back to neutral defaults.
            pr = predictions[i] if i < len(predictions) else None
            cl = classifications[i] if i < len(classifications) else None
            sev = cl.severity if cl else "Low"
            sev_conf = cl.severityConfidence if cl else 0.6
            pdays = pr.predictedDays if pr else 0.0
            anom = pr.anomalyDetection.anomalyScore if pr and pr.anomalyDetection else 0.0
            sev_s = _severity_to_score(sev)
            d_s = _normalize_days(pdays)
            # Less-effective categories (low eff value) contribute more risk.
            cat_w = 1.0 - float(eff.get(cat, 0.8))
            risk = min(1.0, 0.5 * sev_s + 0.25 * d_s + 0.15 * anom + 0.10 * cat_w)
            recs: List[Dict[str, Any]] = []
            if action_service:
                try:
                    recs_resp = action_service.recommend(
                        report=item.get("Report", "") or "",
                        issue_type=cat,
                        severity=sev,
                        area=area if area else None,
                        airline=item.get("Airlines") or None,
                        top_n=5,
                    )
                    recs = recs_resp.get("recommendations", [])
                except Exception:
                    recs = []
            prev = _extract_prevention([r.get("action", "") for r in recs])
            items.append(
                RiskItem(
                    reportId=f"row_{i}",
                    severity=sev,
                    severityConfidence=sev_conf,
                    predictedDays=pdays,
                    anomalyScore=anom,
                    category=cat,
                    hub=hub,
                    area=area,
                    riskScore=round(risk, 3),
                    priority=_priority_from_score(risk),
                    recommendedActions=recs[:5],
                    preventiveSuggestions=prev,
                )
            )
        # Aggregate per (category, hub, area) to surface recurring patterns.
        groups: Dict[str, Dict[str, Any]] = {}
        for it, raw in zip(items, data):
            key = f"{it.category}|{it.hub}|{it.area}"
            g = groups.get(key) or {"key": key, "category": it.category, "hub": it.hub, "area": it.area, "count": 0, "avgRisk": 0.0, "avgDays": 0.0, "highSeverityShare": 0.0}
            g["count"] += 1
            g["avgRisk"] += it.riskScore
            g["avgDays"] += it.predictedDays
            g["highSeverityShare"] += 1.0 if it.severity in ("Critical", "High") else 0.0
            groups[key] = g
        patterns = []
        for g in groups.values():
            c = g["count"]
            g["avgRisk"] = round(g["avgRisk"] / max(1, c), 3)
            g["avgDays"] = round(g["avgDays"] / max(1, c), 2)
            g["highSeverityShare"] = round(g["highSeverityShare"] / max(1, c), 3)
            patterns.append(g)
        patterns.sort(key=lambda x: (-x["avgRisk"], -x["highSeverityShare"], -x["avgDays"], -x["count"]))
        response.risk = RiskAssessmentResponse(
            items=sorted(items, key=lambda x: -x.riskScore),
            topPatterns=patterns[:5],
            metadata={"count": len(items)},
        )
    # Compact mode: drop the heaviest sub-payloads to shrink the response.
    if compact:
        if response.regression and response.regression.predictions:
            for p in response.regression.predictions:
                p.shapExplanation = None
                p.anomalyDetection = None
        if response.nlp:
            response.nlp.entities = []
            response.nlp.summaries = []
    # Calculate processing time
    processing_time = (datetime.now() - start_time).total_seconds() * 1000
    response.metadata.processingTime = round(processing_time, 2)
    logger.info(f"Analysis completed in {processing_time:.2f}ms")
    return response
async def analyze_reports(request: AnalysisRequest, compact: bool = False):
    """Synchronous batch analysis of irregularity reports.

    Runs regression (resolution-time prediction), NLP (severity, entities,
    summaries), and trend aggregation over the `IrregularityReport` list in
    the request body, per `request.options`.

    Raises:
        HTTPException(400): request carries no data.
        HTTPException(500): any unexpected failure during analysis.
    """
    try:
        if not request.data:
            raise HTTPException(
                status_code=400,
                detail="No data provided. Either sheetId or data must be specified.",
            )
        # Pydantic models -> plain dicts for the model service.
        rows = [report.model_dump(exclude_none=True) for report in request.data]
        return perform_analysis(rows, request.options, compact)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Analysis error: {str(exc)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
def background_analysis_task(job_id: str, data: List[Dict], options: AnalysisOptions, compact: bool):
    """Execute one queued analysis job, recording PROCESSING/COMPLETED/FAILED state."""
    try:
        job_service.update_job(job_id, JobStatus.PROCESSING)
        analysis = perform_analysis(data, options, compact)
        job_service.update_job(job_id, JobStatus.COMPLETED, result=analysis.model_dump())
    except Exception as exc:
        logger.error(f"Job {job_id} failed: {exc}")
        job_service.update_job(job_id, JobStatus.FAILED, error=str(exc))
async def analyze_async(
    request: AnalysisRequest, background_tasks: BackgroundTasks, compact: bool = False
):
    """Queue a background analysis job for large datasets.

    Returns a `jobId` immediately; poll `/api/ai/jobs/{jobId}` for status/results.
    """
    if not request.data:
        raise HTTPException(status_code=400, detail="No data provided")
    payload = [report.model_dump(exclude_none=True) for report in request.data]
    job_id = job_service.create_job()
    background_tasks.add_task(background_analysis_task, job_id, payload, request.options, compact)
    return {"job_id": job_id, "status": "queued"}
async def get_job_status(job_id: str):
    """Fetch a background job's status and (if finished) its result.

    Statuses: `queued`, `processing`, `completed`, `failed`.

    Raises:
        HTTPException(404): unknown job id.
    """
    job = job_service.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return job
async def predict_single(report: IrregularityReport):
    """Run every model against one report for real-time / what-if use.

    Returns prediction, classification, entities, summary, and sentiment
    for the single report, plus whether the trained regression model is loaded.
    """
    try:
        batch = [report.model_dump(exclude_none=True)]
        return {
            "prediction": model_service.predict_regression(batch)[0],
            "classification": model_service.classify_text(batch)[0],
            "entities": model_service.extract_entities(batch)[0],
            "summary": model_service.generate_summary(batch)[0],
            "sentiment": model_service.analyze_sentiment(batch)[0],
            "modelLoaded": model_service.model_loaded,
        }
    except Exception as exc:
        logger.error(f"Single prediction error: {str(exc)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
async def train_models(background_tasks: BackgroundTasks, force: bool = False):
    """Queue model retraining in the background.

    The scheduler checks for new sheet data before training unless `force=True`.
    """
    from scripts.scheduled_training import TrainingScheduler

    def run_training_task():
        # Runs after the HTTP response has been sent.
        outcome = TrainingScheduler().run_training(force=force)
        logger.info(f"Training completed: {outcome}")

    background_tasks.add_task(run_training_task)
    return {
        "status": "training_queued",
        "message": "Model retraining has been started in the background",
        "force": force,
        "timestamp": datetime.now().isoformat(),
    }
async def training_status():
    """Report the latest training job's status and the training history."""
    from scripts.scheduled_training import TrainingScheduler
    return {
        "status": "success",
        "data": TrainingScheduler().get_status(),
        "timestamp": datetime.now().isoformat(),
    }
async def model_info():
    """Describe the currently deployed models and their capabilities."""
    regression = {
        "version": model_service.regression_version,
        "type": "GradientBoostingRegressor",
        "status": "loaded" if model_service.model_loaded else "unavailable",
        # NOTE(review): hard-coded date, presumably a placeholder — confirm
        # whether this should come from training metadata.
        "last_trained": "2025-01-15",
        "metrics": model_service.model_metrics if model_service.model_loaded else None,
    }
    nlp = {
        "version": model_service.nlp_version,
        "type": "Rule-based + Keyword extraction",
        "status": "active",
        "tasks": ["classification", "ner", "summarization", "sentiment"],
        "note": "Full ML NLP models coming soon",
    }
    return {"regression": regression, "nlp": nlp}
async def invalidate_cache(sheet_name: Optional[str] = None):
    """Drop cached sheet data — one sheet's entries, or all `sheets:*` keys."""
    cache = get_cache()
    if sheet_name:
        deleted = cache.delete_pattern(f"sheets:*{sheet_name}*")
        message = f"Invalidated cache for sheet: {sheet_name}"
    else:
        deleted = cache.delete_pattern("sheets:*")
        message = "Invalidated all sheets cache"
    return {
        "status": "success",
        "message": message,
        "keys_deleted": deleted,
    }
async def cache_status():
    """Expose cache backend health and statistics."""
    return get_cache().health_check()
class AnalyzeAllResponse(BaseModel):
    """Response envelope for the analyze-all-sheets endpoint."""
    status: str  # "success" when analysis completed
    metadata: Dict[str, Any]  # timing, throughput, model versions
    sheets: Dict[str, Any]  # per-sheet fetch status (rows_fetched / status / error)
    results: List[Dict[str, Any]]  # per-row analysis output
    summary: Dict[str, Any]  # aggregates: severity distribution, prediction stats
    timestamp: str  # ISO-8601 completion time
async def analyze_all_sheets(
    bypass_cache: bool = False,
    include_regression: bool = True,
    include_nlp: bool = True,
    include_trends: bool = True,
    max_rows_per_sheet: int = 10000,
    compact: bool = False,
):
    """
    Analyze ALL rows from all Google Sheets.

    Fetches data from both NON CARGO and CGO sheets, analyzes each row
    in batches of 100, and returns comprehensive results.

    Args:
        bypass_cache: Skip cache and fetch fresh data
        include_regression: Include regression predictions
        include_nlp: Include NLP analysis (severity, entities, summary)
        include_trends: Include trend analysis (accepted but not used in
            this implementation — no trend section is produced)
        max_rows_per_sheet: Maximum rows to process per sheet
        compact: Strip SHAP/anomaly/entity/summary/sentiment details per row

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID env var missing, or unexpected error.
        HTTPException(404): no rows found in any sheet.
    """
    start_time = datetime.now()
    try:
        from data.sheets_service import GoogleSheetsService
        cache = get_cache() if not bypass_cache else None
        sheets_service = GoogleSheetsService(cache=cache)
        spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
        if not spreadsheet_id:
            raise HTTPException(
                status_code=500, detail="GOOGLE_SHEET_ID not configured"
            )
        all_data = []
        sheet_info = {}
        # +1 accounts for the header row in row 1.
        sheets_to_fetch = [
            {"name": "NON CARGO", "range": f"A1:AA{max_rows_per_sheet + 1}"},
            {"name": "CGO", "range": f"A1:Z{max_rows_per_sheet + 1}"},
        ]
        for sheet in sheets_to_fetch:
            try:
                sheet_name = sheet["name"]
                range_str = sheet["range"]
                logger.info(f"Fetching {sheet_name}...")
                data = sheets_service.fetch_sheet_data(
                    spreadsheet_id, sheet_name, range_str, bypass_cache=bypass_cache
                )
                # Tag each row with its origin so results can be traced back.
                for row in data:
                    row["_source_sheet"] = sheet_name
                    all_data.append(row)
                sheet_info[sheet_name] = {
                    "rows_fetched": len(data),
                    "status": "success",
                }
            except Exception as e:
                # A failed sheet is recorded but does not abort the other sheet.
                logger.error(f"Failed to fetch {sheet['name']}: {e}")
                sheet_info[sheet["name"]] = {
                    "rows_fetched": 0,
                    "status": "error",
                    "error": str(e),
                }
        total_records = len(all_data)
        if total_records == 0:
            raise HTTPException(status_code=404, detail="No data found in any sheet")
        logger.info(f"Analyzing {total_records} total records...")
        results = []
        # Process in fixed-size batches to bound per-call model input size.
        batch_size = 100
        for i in range(0, total_records, batch_size):
            batch = all_data[i : i + batch_size]
            if include_regression:
                regression_preds = model_service.predict_regression(batch)
            else:
                regression_preds = [None] * len(batch)
            if include_nlp:
                classifications = model_service.classify_text(batch)
                entities = model_service.extract_entities(batch)
                summaries = model_service.generate_summary(batch)
                sentiments = model_service.analyze_sentiment(batch)
            else:
                classifications = [None] * len(batch)
                entities = [None] * len(batch)
                summaries = [None] * len(batch)
                sentiments = [None] * len(batch)
            for j, row in enumerate(batch):
                result = {
                    "rowId": row.get("_row_id", f"row_{i + j}"),
                    "sourceSheet": row.get("_source_sheet", "Unknown"),
                    "originalData": {
                        "date": row.get("Date_of_Event"),
                        "airline": row.get("Airlines"),
                        "flightNumber": row.get("Flight_Number"),
                        "branch": row.get("Branch"),
                        "hub": row.get("HUB"),
                        "route": row.get("Route"),
                        "category": row.get("Report_Category"),
                        "issueType": row.get("Irregularity_Complain_Category"),
                        "report": row.get("Report"),
                        "status": row.get("Status"),
                    },
                }
                if regression_preds[j]:
                    pred = {
                        "predictedDays": regression_preds[j].predictedDays,
                        "confidenceInterval": regression_preds[j].confidenceInterval,
                        "hasUnknownCategories": regression_preds[j].hasUnknownCategories,
                    }
                    # SHAP/anomaly details are the heaviest fields; omit in compact mode.
                    if not compact:
                        pred["shapExplanation"] = (
                            regression_preds[j].shapExplanation.model_dump()
                            if regression_preds[j].shapExplanation
                            else None
                        )
                        pred["anomalyDetection"] = (
                            regression_preds[j].anomalyDetection.model_dump()
                            if regression_preds[j].anomalyDetection
                            else None
                        )
                    result["prediction"] = pred
                if classifications[j]:
                    result["classification"] = classifications[j].model_dump()
                if entities[j] and not compact:
                    result["entities"] = entities[j].model_dump()
                if summaries[j] and not compact:
                    result["summary"] = summaries[j].model_dump()
                if sentiments[j] and not compact:
                    result["sentiment"] = sentiments[j].model_dump()
                results.append(result)
        summary = {
            "totalRecords": total_records,
            "sheetsProcessed": len(
                [s for s in sheet_info.values() if s["status"] == "success"]
            ),
            "regressionEnabled": include_regression,
            "nlpEnabled": include_nlp,
        }
        if include_nlp and results:
            severity_counts = {}
            for r in results:
                sev = r.get("classification", {}).get("severity", "Unknown")
                severity_counts[sev] = severity_counts.get(sev, 0) + 1
            summary["severityDistribution"] = severity_counts
        if include_regression and results:
            predictions = [
                r["prediction"]["predictedDays"] for r in results if r.get("prediction")
            ]
            if predictions:
                summary["predictionStats"] = {
                    "min": round(min(predictions), 2),
                    "max": round(max(predictions), 2),
                    "mean": round(sum(predictions) / len(predictions), 2),
                }
        processing_time = (datetime.now() - start_time).total_seconds()
        return AnalyzeAllResponse(
            status="success",
            metadata={
                "totalRecords": total_records,
                "processingTimeSeconds": round(processing_time, 2),
                "recordsPerSecond": round(total_records / processing_time, 2)
                if processing_time > 0
                else 0,
                "modelVersions": {
                    "regression": model_service.regression_version,
                    "nlp": model_service.nlp_version,
                },
            },
            sheets=sheet_info,
            results=results,
            summary=summary,
            timestamp=datetime.now().isoformat(),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Analyze all error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
| # ============== Risk Scoring Endpoints ============== | |
async def risk_summary():
    """Overall risk summary across all tracked entities."""
    from data.risk_service import get_risk_service
    return get_risk_service().get_risk_summary()
async def airline_risks():
    """Risk scores for every airline."""
    from data.risk_service import get_risk_service
    return get_risk_service().get_all_airline_risks()
async def airline_risk(airline_name: str):
    """Risk score and recommendations for one airline.

    Raises:
        HTTPException(404): the airline has no risk data.
    """
    from data.risk_service import get_risk_service
    svc = get_risk_service()
    risk_data = svc.get_airline_risk(airline_name)
    if not risk_data:
        raise HTTPException(
            status_code=404, detail=f"Airline '{airline_name}' not found"
        )
    return {
        "airline": airline_name,
        "risk_data": risk_data,
        "recommendations": svc.get_risk_recommendations("airline", airline_name),
    }
async def branch_risks():
    """Risk scores for every branch."""
    from data.risk_service import get_risk_service
    return get_risk_service().get_all_branch_risks()
async def hub_risks():
    """Risk scores for every hub."""
    from data.risk_service import get_risk_service
    return get_risk_service().get_all_hub_risks()
async def calculate_risk_scores(bypass_cache: bool = False):
    """Recompute entity risk scores from current Google Sheets data.

    Fetches both NON CARGO and CGO sheets, feeds every row to the risk
    service, and returns the refreshed summary.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID env var not configured.
    """
    from data.risk_service import get_risk_service
    from data.sheets_service import GoogleSheetsService
    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    # Fetch all data
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    risk_service = get_risk_service()
    # Called for its side effect (recomputing scores); the return value was
    # previously bound to an unused local, now dropped.
    risk_service.calculate_all_risk_scores(all_data)
    return {
        "status": "success",
        "records_processed": len(all_data),
        "risk_summary": risk_service.get_risk_summary(),
    }
| # ============== Subcategory Classification Endpoints ============== | |
async def classify_subcategory(
    report: str,
    area: Optional[str] = None,
    issue_type: Optional[str] = None,
    root_cause: Optional[str] = None,
):
    """Classify a report text into a subcategory, optionally constrained by context."""
    from data.subcategory_service import get_subcategory_classifier
    return get_subcategory_classifier().classify(report, area, issue_type, root_cause)
async def get_subcategories(area: Optional[str] = None):
    """List the subcategories available, optionally filtered by area."""
    from data.subcategory_service import get_subcategory_classifier
    return get_subcategory_classifier().get_available_categories(area)
| # ============== Action Recommendation Endpoints ============== | |
async def recommend_actions(
    report: str,
    issue_type: str,
    severity: str = "Medium",
    area: Optional[str] = None,
    airline: Optional[str] = None,
    top_n: int = 5,
):
    """Recommend up to `top_n` corrective actions for the described issue."""
    from data.action_service import get_action_service
    return get_action_service().recommend(
        report=report,
        issue_type=issue_type,
        severity=severity,
        area=area,
        airline=airline,
        top_n=top_n,
    )
async def train_action_recommender(
    bypass_cache: bool = False, background_tasks: BackgroundTasks = None
):
    """Rebuild the similarity index and retrain the action recommender
    from historical rows in both Google Sheets tabs.

    Note: `background_tasks` is accepted for interface compatibility but
    is not used — training runs inline in this handler.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID env var not configured.
    """
    from data.action_service import get_action_service
    from data.sheets_service import GoogleSheetsService
    from data.similarity_service import get_similarity_service
    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    rows = sheets.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    rows = rows + sheets.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    get_similarity_service().build_index(rows)
    get_action_service().train_from_data(rows)
    return {
        "status": "success",
        "records_processed": len(rows),
    }
| # ============== Advanced NER Endpoints ============== | |
async def extract_entities(text: str):
    """Run advanced NER over free text; returns raw entities plus a condensed summary."""
    from data.advanced_ner_service import get_advanced_ner
    ner = get_advanced_ner()
    return {
        "entities": ner.extract(text),
        "summary": ner.extract_summary(text),
    }
| # ============== Similarity Endpoints ============== | |
async def find_similar_reports(
    text: str,
    top_k: int = 5,
    threshold: float = 0.3,
):
    """Return up to `top_k` historical reports whose similarity exceeds `threshold`."""
    from data.similarity_service import get_similarity_service
    matches = get_similarity_service().find_similar(text, top_k, threshold)
    return {
        "query_preview": text[:100],
        "similar_reports": matches,
    }
async def build_similarity_index(bypass_cache: bool = False):
    """Rebuild the similarity index from both Google Sheets tabs.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID env var not configured.
    """
    from data.similarity_service import get_similarity_service
    from data.sheets_service import GoogleSheetsService
    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    rows = sheets.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    rows = rows + sheets.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    get_similarity_service().build_index(rows)
    return {
        "status": "success",
        "records_indexed": len(rows),
    }
| # ============== Forecasting Endpoints ============== | |
async def forecast_issues(periods: int = 4):
    """Forecast issue volume for the next *periods* periods."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().forecast_issues(periods)
async def predict_trends():
    """Predict per-category trends via the forecast service."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().predict_category_trends()
async def get_seasonal_patterns():
    """Return seasonal patterns computed by the forecast service."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().get_seasonal_patterns()
async def build_forecast_data(bypass_cache: bool = False):
    """Rebuild the forecast service's historical dataset from Google Sheets."""
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    records = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    records += sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    forecaster = get_forecast_service()
    forecaster.build_historical_data(records)
    return {
        "status": "success",
        "records_processed": len(records),
        "forecast_summary": forecaster.get_forecast_summary(),
    }
| # ============== Report Generation Endpoints ============== | |
async def generate_report(
    row_id: str,
    bypass_cache: bool = False,
):
    """Generate a formal incident report for one sheet record.

    Args:
        row_id: The `_row_id` of the record to report on.
        bypass_cache: Skip cache and fetch fresh data.

    Returns:
        Dict with the formal report text, executive summary, and a
        structured (JSON-shaped) report.

    Raises:
        HTTPException: 500 if GOOGLE_SHEET_ID is not configured;
            404 if no record matches `row_id`.
    """
    from data.report_generator_service import get_report_generator
    from data.sheets_service import GoogleSheetsService
    from data.risk_service import get_risk_service

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    # Fetch all data and find the record
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    record = next((r for r in all_data if r.get("_row_id") == row_id), None)
    if not record:
        raise HTTPException(status_code=404, detail=f"Record '{row_id}' not found")
    # Generate analysis. `or ""` guards against a cell that exists but holds
    # None (dict.get's default only covers *missing* keys), which would
    # otherwise crash the concatenation.
    report_text = (record.get("Report") or "") + " " + (record.get("Root_Caused") or "")
    analysis = {
        "severity": model_service._classify_severity_fallback([report_text])[0].get(
            "severity", "Medium"
        ),
        "issueType": record.get("Irregularity_Complain_Category", ""),
    }
    # Get risk data for the record's airline
    risk_service = get_risk_service()
    airline = record.get("Airlines", "")
    risk_data = risk_service.get_airline_risk(airline)
    # Generate the three report flavors
    report_gen = get_report_generator()
    formal_report = report_gen.generate_incident_report(record, analysis, risk_data)
    exec_summary = report_gen.generate_executive_summary(record, analysis)
    json_report = report_gen.generate_json_report(record, analysis, risk_data)
    return {
        "row_id": row_id,
        "formal_report": formal_report,
        "executive_summary": exec_summary,
        "structured_report": json_report,
    }
| # ============== Dashboard Endpoints ============== | |
async def dashboard_summary(bypass_cache: bool = False):
    """Build a comprehensive dashboard summary.

    Combines record counts, severity/category/airline distributions, risk
    and forecast summaries, and model status into a single payload.

    Args:
        bypass_cache: Skip cache and fetch fresh data.

    Raises:
        HTTPException: 500 if GOOGLE_SHEET_ID is not configured.
    """
    from data.risk_service import get_risk_service
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    # Fetch both sheets
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    # Risk summary
    risk_service = get_risk_service()
    risk_summary = risk_service.get_risk_summary()
    # Forecast summary
    forecast_service = get_forecast_service()
    forecast_summary = forecast_service.get_forecast_summary()
    # Per-record distribution statistics
    severity_dist = Counter()
    category_dist = Counter()
    airline_dist = Counter()
    for record in all_data:
        # `or ""` guards against a cell holding None (dict.get's default only
        # covers *missing* keys), which would otherwise crash the concatenation.
        report_text = (record.get("Report") or "") + " " + (record.get("Root_Caused") or "")
        sev = model_service._classify_severity_fallback([report_text])[0].get(
            "severity", "Low"
        )
        severity_dist[sev] += 1
        category_dist[record.get("Irregularity_Complain_Category", "Unknown")] += 1
        airline_dist[record.get("Airlines", "Unknown")] += 1
    return {
        "total_records": len(all_data),
        "sheets": {
            "non_cargo": len(non_cargo),
            "cargo": len(cargo),
        },
        "severity_distribution": dict(severity_dist),
        "category_distribution": dict(category_dist.most_common(10)),
        "top_airlines": dict(airline_dist.most_common(10)),
        "risk_summary": risk_summary,
        "forecast_summary": forecast_summary,
        "model_status": {
            "regression": model_service.model_loaded,
            "nlp": model_service.nlp_service is not None,
        },
        "last_updated": datetime.now().isoformat(),
    }
| # ============== Seasonality Endpoints ============== | |
async def seasonality_summary(category_type: Optional[str] = None):
    """Return the seasonality summary and patterns.

    Args:
        category_type: "landside_airside", "cgo", or None for both
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().get_seasonality_summary(category_type)
async def seasonality_forecast(
    category_type: Optional[str] = None,
    periods: int = 4,
    granularity: str = "weekly",
):
    """Forecast issue volumes.

    Args:
        category_type: "landside_airside", "cgo", or None for both
        periods: Number of periods to forecast
        granularity: "daily", "weekly", or "monthly"
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().forecast(category_type, periods, granularity)
async def seasonality_peaks(
    category_type: Optional[str] = None, threshold: float = 1.2
):
    """Identify peak periods.

    Args:
        category_type: "landside_airside", "cgo", or None for both
        threshold: Multiplier above average (1.2 = 20% above)
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().get_peak_periods(category_type, threshold)
async def build_seasonality_patterns(bypass_cache: bool = False):
    """Recompute seasonality patterns from Google Sheets data."""
    from data.seasonality_service import get_seasonality_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    tagged = []
    for sheet_name, cell_range in (("NON CARGO", "A1:AA5000"), ("CGO", "A1:Z5000")):
        rows = sheets.fetch_sheet_data(
            sheet_id, sheet_name, cell_range, bypass_cache=bypass_cache
        )
        for row in rows:
            row["_sheet_name"] = sheet_name  # tag origin sheet for the service
        tagged.extend(rows)
    patterns = get_seasonality_service().build_patterns(tagged)
    return {
        "status": "success",
        "records_processed": len(tagged),
        "patterns": patterns,
    }
| # ============== Root Cause Endpoints ============== | |
async def classify_root_cause(
    root_cause: str,
    report: Optional[str] = None,
    area: Optional[str] = None,
    category: Optional[str] = None,
):
    """Classify a root cause text into categories.

    Categories: Equipment Failure, Staff Competency, Process/Procedure,
    Communication, External Factors, Documentation, Training Gap, Resource/Manpower
    """
    from data.root_cause_service import get_root_cause_service

    return get_root_cause_service().classify(
        root_cause, report or "", {"area": area, "category": category}
    )
async def classify_root_cause_batch(bypass_cache: bool = False):
    """Classify root causes for every record in both sheets."""
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    records = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    records += sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    classifications = get_root_cause_service().classify_batch(records)
    classified_count = sum(
        1 for c in classifications if c["primary_category"] != "Unknown"
    )
    return {
        "status": "success",
        "records_processed": len(records),
        "classifications": classifications[:100],  # cap response payload
        "total_classified": classified_count,
    }
async def get_root_cause_categories():
    """List all available root cause categories."""
    from data.root_cause_service import get_root_cause_service

    return get_root_cause_service().get_categories()
async def get_root_cause_stats(source: Optional[str] = None, bypass_cache: bool = False):
    """Compute root cause statistics from sheet data.

    Args:
        source: "NON CARGO", "CGO", or None for both
        bypass_cache: Skip cache and fetch fresh data
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    wanted = source.upper() if source else None
    records = []
    # Fetch only the requested sheet(s) to reduce I/O and processing
    if wanted in (None, "NON CARGO"):
        records += sheets.fetch_sheet_data(
            sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
        )
    if wanted in (None, "CGO"):
        records += sheets.fetch_sheet_data(
            sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
        )
    return get_root_cause_service().get_statistics(records)
async def train_root_cause_classifier(background_tasks: BackgroundTasks, bypass_cache: bool = False):
    """Kick off background training of the root cause classifier."""
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    records = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    records += sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    # Offload the intensive training process to the background
    background_tasks.add_task(get_root_cause_service().train_from_data, records)
    return {
        "status": "training_started",
        "records_fetched": len(records),
        "message": "Classification training is now running in the background. The model will be automatically updated once complete."
    }
| # ============== Category Summarization Endpoints ============== | |
class CategorySummaryResponse(BaseModel):
    """Response envelope for the category summarization endpoints."""
    # "success" once summarization completes
    status: str
    # Which category was summarized: "non_cargo", "cgo", or "all"
    category_type: str
    # Aggregated summary payload produced by the summarization service
    summary: Dict[str, Any]
    # ISO-8601 timestamp of when the summary was generated
    timestamp: str
async def summarize_by_category(category: str = "all", bypass_cache: bool = False):
    """Summarize insights for the Non-cargo and/or CGO categories.

    Query Parameters:
        category: "non_cargo", "cgo", or "all" (default: "all")
        bypass_cache: Skip cache and fetch fresh data (default: false)

    Returns an aggregated summary: severity distribution; top categories,
    airlines, hubs and branches; key insights and recommendations; common
    issues; and monthly trends.
    """
    from data.category_summarization_service import get_category_summarization_service
    from data.sheets_service import GoogleSheetsService

    valid_categories = ["non_cargo", "cgo", "all"]
    normalized = category.lower()
    if normalized not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}",
        )
    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    tagged = []
    for sheet_name, cell_range in (("NON CARGO", "A1:AA5000"), ("CGO", "A1:Z5000")):
        rows = sheets.fetch_sheet_data(
            sheet_id, sheet_name, cell_range, bypass_cache=bypass_cache
        )
        for row in rows:
            row["_sheet_name"] = sheet_name  # tag origin sheet for the service
        tagged.extend(rows)
    summary = get_category_summarization_service().summarize_category(tagged, normalized)
    return CategorySummaryResponse(
        status="success",
        category_type=normalized,
        summary=summary,
        timestamp=datetime.now().isoformat(),
    )
async def summarize_non_cargo(bypass_cache: bool = False):
    """Convenience wrapper: summary for the Non-cargo category only."""
    return await summarize_by_category("non_cargo", bypass_cache=bypass_cache)
async def summarize_cgo(bypass_cache: bool = False):
    """Convenience wrapper: summary for the CGO (Cargo) category only."""
    return await summarize_by_category("cgo", bypass_cache=bypass_cache)
async def compare_categories(bypass_cache: bool = False):
    """Convenience wrapper: Non-cargo and CGO summarized side by side."""
    return await summarize_by_category("all", bypass_cache=bypass_cache)
| # ============== Branch Analytics Endpoints ============== | |
async def branch_analytics_summary(category_type: Optional[str] = None):
    """Return the branch analytics summary.

    Args:
        category_type: "landside_airside", "cgo", or None for both
    """
    from data.branch_analytics_service import get_branch_analytics_service

    return get_branch_analytics_service().get_summary(category_type)
async def get_branch_metrics(branch_name: str, category_type: Optional[str] = None):
    """Return metrics for a specific branch.

    Args:
        branch_name: Branch name
        category_type: "landside_airside", "cgo", or None for combined

    Raises:
        HTTPException: 404 when the branch is unknown.
    """
    from data.branch_analytics_service import get_branch_analytics_service

    metrics = get_branch_analytics_service().get_branch(branch_name, category_type)
    if metrics:
        return metrics
    raise HTTPException(status_code=404, detail=f"Branch '{branch_name}' not found")
async def branch_ranking(
    category_type: Optional[str] = None,
    sort_by: str = "risk_score",
    limit: int = 20,
):
    """Return the branch ranking.

    Args:
        category_type: "landside_airside", "cgo", or None for both
        sort_by: Field to sort by (risk_score, total_issues, critical_high_count)
        limit: Maximum branches to return
    """
    from data.branch_analytics_service import get_branch_analytics_service

    return get_branch_analytics_service().get_ranking(category_type, sort_by, limit)
async def branch_comparison():
    """Compare all branches across category types."""
    from data.branch_analytics_service import get_branch_analytics_service

    return get_branch_analytics_service().get_comparison()
async def calculate_branch_metrics(bypass_cache: bool = False):
    """Recompute branch metrics from Google Sheets data."""
    from data.branch_analytics_service import get_branch_analytics_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    tagged = []
    for sheet_name, cell_range in (("NON CARGO", "A1:AA5000"), ("CGO", "A1:Z5000")):
        rows = sheets.fetch_sheet_data(
            sheet_id, sheet_name, cell_range, bypass_cache=bypass_cache
        )
        for row in rows:
            row["_sheet_name"] = sheet_name  # tag origin sheet for the service
        tagged.extend(rows)
    metrics = get_branch_analytics_service().calculate_branch_metrics(tagged)
    return {
        "status": "success",
        "records_processed": len(tagged),
        "metrics": metrics,
    }
# ============== Main ==============
if __name__ == "__main__":
    import uvicorn

    # Port is configurable via the API_PORT environment variable (default 8000)
    serve_port = int(os.environ.get("API_PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)