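# gapura-ai / api/main.py
# Muhammad Ridzki Nugraha
# Deploy Gapura AI update (exclude models; models pulled at runtime)
# d27bf31 verified
"""
Gapura AI Analysis API
FastAPI server for regression and NLP analysis of irregularity reports.
Loads the trained resolution-time regression model from models/regression/
(ONNX preferred, pickle fallback; files can also be pulled from a Hugging Face
repo at runtime) and falls back to heuristic predictions when no model is usable.
"""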
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
from pydantic_core import ValidationError
from typing import List, Optional, Dict, Any, Tuple
from collections import Counter
import os
import json
import logging
import hashlib
from datetime import datetime
import numpy as np
import pickle
import pandas as pd
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.cache_service import get_cache, CacheService
from data.nlp_service import NLPModelService
from data.shap_service import get_shap_explainer
from data.anomaly_service import get_anomaly_detector
# Root logging at INFO plus a logger scoped to this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
app = FastAPI(
    title="Gapura AI Analysis API",
    description="AI-powered analysis for irregularity reports",
    version="1.0.0",
)

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS env var; the
# default "*" keeps the previous allow-every-origin behaviour.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI/Starlette CORS docs; set CORS_ALLOW_ORIGINS to the
# real frontend origin(s) in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Turn a pydantic ValidationError raised while handling a request into a 422.

    ``exc.errors()`` may contain objects that ``json.dumps`` cannot encode
    (e.g. the exception instance stored under ``ctx["error"]`` when a
    validator raises ``ValueError``), which would make ``JSONResponse`` itself
    fail. ``exc.json()`` is pydantic's JSON-safe rendering of the same list,
    so the errors are decoded from it instead.

    Response keys:
        detail: fixed message "Validation error".
        errors: the error list as JSON objects.
        body:   the same error list as a JSON string (kept for compatibility).
    """
    errors_json = exc.json()
    return JSONResponse(
        status_code=422,
        content={
            "detail": "Validation error",
            "errors": json.loads(errors_json),
            "body": errors_json,
        },
    )
# ============== Pydantic Models ==============
from enum import Enum
from datetime import date as date_type
class ReportCategoryEnum(str, Enum):
    """String-valued set of report categories.

    NOTE(review): not referenced by the request models in the visible part of
    this module (IrregularityReport.Report_Category is a plain Optional[str]).
    """
    IRREGULARITY = "Irregularity"
    COMPLAINT = "Complaint"
class AreaEnum(str, Enum):
    """String-valued set of report areas.

    NOTE(review): not referenced by the request models in the visible part of
    this module (IrregularityReport.Area is a plain Optional[str]).
    """
    APRON = "Apron Area"
    TERMINAL = "Terminal Area"
    GENERAL = "General"
class StatusEnum(str, Enum):
    """String-valued set of report statuses.

    NOTE(review): not referenced by the request models in the visible part of
    this module (IrregularityReport.Status is a plain Optional[str]).
    """
    OPEN = "Open"
    CLOSED = "Closed"
    IN_PROGRESS = "In Progress"
class IrregularityReport(BaseModel):
    # One irregularity/complaint row as submitted by the client. Every field is
    # optional, and columns not declared here are kept (extra="allow").
    Date_of_Event: Optional[str] = Field(None, description="Date of the event")
    Airlines: Optional[str] = Field(None, max_length=100)
    Flight_Number: Optional[str] = Field(None, max_length=20)
    Branch: Optional[str] = Field(None, max_length=10)
    HUB: Optional[str] = Field(None, max_length=20)
    Route: Optional[str] = Field(None, max_length=50)
    Report_Category: Optional[str] = Field(None, max_length=50)
    Irregularity_Complain_Category: Optional[str] = Field(None, max_length=100)
    Report: Optional[str] = Field(None, max_length=2000)
    Root_Caused: Optional[str] = Field(None, max_length=2000)
    Action_Taken: Optional[str] = Field(None, max_length=2000)
    Area: Optional[str] = Field(None, max_length=50)
    Status: Optional[str] = Field(None, max_length=50)
    Reported_By: Optional[str] = Field(None, max_length=100)
    Upload_Irregularity_Photo: Optional[str] = Field(None)
    # coerce_numbers_to_str: spreadsheet exports often deliver flight numbers,
    # codes or dates as JSON numbers. Pydantic v2 rejects numbers for str
    # fields by default, which turned such rows into a 422 for the whole
    # request; accept them as their string form instead (pydantic >= 2.4).
    model_config = {"extra": "allow", "coerce_numbers_to_str": True}
class AnalysisOptions(BaseModel):
    # Per-request switches for /api/ai/analyze; all analysis steps default to on.
    predictResolutionTime: bool = Field(
        default=True, description="Run regression model"
    )
    classifySeverity: bool = Field(
        default=True, description="Classify severity using NLP"
    )
    extractEntities: bool = Field(
        default=True, description="Extract entities using NER"
    )
    generateSummary: bool = Field(default=True, description="Generate text summaries")
    analyzeTrends: bool = Field(default=True, description="Analyze trends")
    # NOTE(review): no code visible in this module reads bypassCache; confirm
    # whether it is honoured anywhere.
    bypassCache: bool = Field(
        default=False, description="Bypass cache and fetch fresh data"
    )
class AnalysisRequest(BaseModel):
    # Request body for /api/ai/analyze.
    # NOTE(review): analyze_reports only reads `data` and `options`; sheetId,
    # sheetName and rowRange are accepted but not used by that handler.
    sheetId: Optional[str] = Field(None, description="Google Sheet ID")
    sheetName: Optional[str] = Field(None, description="Sheet name (NON CARGO or CGO)")
    rowRange: Optional[str] = Field(None, description="Row range (e.g., A2:Z100)")
    data: Optional[List[IrregularityReport]] = Field(
        None, description="Direct data upload"
    )
    options: AnalysisOptions = Field(default_factory=AnalysisOptions)

    @field_validator("data")
    @classmethod
    def validate_data(cls, v):
        """Reject an explicitly empty ``data`` list; an omitted (None) list passes.

        Raises:
            ValueError: when ``data`` is ``[]``; pydantic reports it as a
                validation error on the ``data`` field.
        """
        if v is not None and len(v) == 0:
            raise ValueError("data array cannot be empty")
        return v
class ShapExplanation(BaseModel):
    # SHAP explanation attached to a regression prediction, built from the
    # shap explainer's result dict by ModelService.
    baseValue: float = Field(description="Base/expected value from model")
    predictionExplained: bool = Field(
        description="Whether SHAP explanation is available"
    )
    # ModelService keeps at most the first five factors.
    topFactors: List[Dict[str, Any]] = Field(
        default_factory=list, description="Top contributing features"
    )
    explanation: str = Field(default="", description="Human-readable explanation")
class AnomalyResult(BaseModel):
    # Anomaly check on a predicted resolution time, built from the anomaly
    # detector's result dict (is_anomaly / anomaly_score / anomalies).
    isAnomaly: bool = Field(description="Whether prediction is anomalous")
    anomalyScore: float = Field(description="Anomaly score (0-1)")
    anomalies: List[Dict[str, Any]] = Field(
        default_factory=list, description="List of detected anomalies"
    )
class RegressionPrediction(BaseModel):
    # Resolution-time prediction for one input row.
    reportId: str  # "row_<index>" of the row within the request
    predictedDays: float  # predicted days to resolve; producer clamps to >= 0.1
    confidenceInterval: Tuple[float, float]  # (lower, upper) bounds in days
    featureImportance: Dict[str, float]  # feature name -> importance weight
    hasUnknownCategories: bool = Field(
        default=False, description="True if unknown categories were used in prediction"
    )
    shapExplanation: Optional[ShapExplanation] = Field(
        default=None, description="SHAP-based explanation for prediction"
    )
    anomalyDetection: Optional[AnomalyResult] = Field(
        default=None, description="Anomaly detection results"
    )
class RegressionResult(BaseModel):
    # Regression section of AnalysisResponse: one prediction per row plus the
    # model's headline metrics (mae / rmse / r2 / model_loaded / note).
    predictions: List[RegressionPrediction]
    modelMetrics: Dict[str, Any]
class ClassificationResult(BaseModel):
    # NLP classification for one input row.
    reportId: str  # "row_<index>" of the row within the request
    severity: str  # keyword fallback yields "High" / "Medium" / "Low"
    severityConfidence: float
    areaType: str  # area label with any " Area" suffix stripped
    issueType: str  # irregularity / complaint category label
    issueTypeConfidence: float
class Entity(BaseModel):
    # One extracted entity (e.g. AIRLINE, FLIGHT_NUMBER, DATE).
    text: str
    label: str
    start: int  # character span reported by the extractor;
    end: int  # producers in this module set end = start + len(text)
    confidence: float
class EntityResult(BaseModel):
    # Entities extracted from one input row.
    reportId: str  # "row_<index>" of the row within the request
    entities: List[Entity]
class SummaryResult(BaseModel):
    # Short summary of one input row.
    reportId: str  # "row_<index>" of the row within the request
    executiveSummary: str
    keyPoints: List[str]
class SentimentResult(BaseModel):
    # Urgency/sentiment for one input row.
    reportId: str  # "row_<index>" of the row within the request
    urgencyScore: float  # keyword fallback yields a value in [0, 1]
    sentiment: str  # keyword fallback yields "Negative" or "Neutral"
    keywords: List[str]  # urgency keywords found in the text
class NLPResult(BaseModel):
    # NLP section of AnalysisResponse. A list stays empty when its step was
    # disabled in AnalysisOptions.
    classifications: List[ClassificationResult]
    entities: List[EntityResult]
    summaries: List[SummaryResult]
    sentiment: List[SentimentResult]
class TrendData(BaseModel):
    # Aggregate for one airline or hub.
    count: int  # number of rows in the group
    avgResolutionDays: Optional[float]  # required key, but may be null
    topIssues: List[str]  # up to three issue categories
class TrendResult(BaseModel):
    # Trend section of AnalysisResponse.
    byAirline: Dict[str, TrendData]  # keyed by Airlines value
    byHub: Dict[str, TrendData]  # keyed by HUB value
    byCategory: Dict[str, Dict[str, Any]]  # category -> {"count", "trend"}
    timeSeries: List[Dict[str, Any]]  # currently always [] in analyze_reports
class Metadata(BaseModel):
    # Request-level bookkeeping for AnalysisResponse.
    totalRecords: int  # number of rows analysed
    processingTime: float  # wall-clock milliseconds (see analyze_reports)
    modelVersions: Dict[str, str]  # {"regression": ..., "nlp": ...}
class AnalysisResponse(BaseModel):
    # Response of /api/ai/analyze. Each analysis section stays None when the
    # corresponding option was disabled.
    regression: Optional[RegressionResult] = None
    nlp: Optional[NLPResult] = None
    trends: Optional[TrendResult] = None
    metadata: Metadata
# ============== Real Model Service ==============
class ModelService:
    """Load the trained models once and run all inference for the API.

    Regression model resolution (``_load_regression_model``), in order:

    1. ONNX ``resolution_predictor.onnx`` from ``<api>/../models/regression``
       or ``/app/models/regression``; if neither exists, from a Hugging Face
       snapshot of ``REGRESSION_MODEL_REPO_ID`` (or ``MODEL_REPO_ID``). Only
       attempted when ``REGRESSION_USE_ONNX`` is truthy (default ``"1"``).
    2. Pickle ``resolution_predictor_latest.pkl`` (local dir, ``./hf-space``
       copy, or the same snapshot), unless ``REGRESSION_DISABLE_PICKLE`` is set.
    3. ``_HeuristicRegressor`` when no usable model file is found.

    NLP tasks use ``NLPModelService`` when it can be constructed. Otherwise,
    or when one of its calls raises, keyword-based fallbacks are used.
    """

    _ONNX_FILENAME = "resolution_predictor.onnx"
    _PKL_FILENAME = "resolution_predictor_latest.pkl"
    _TRUTHY = frozenset({"1", "true", "yes"})

    # Feature order used when no feature_names.json / pickled feature list exists.
    CANONICAL_FEATURE_ORDER = (
        "day_of_week",
        "month",
        "is_weekend",
        "week_of_year",
        "sin_day_of_week",
        "cos_day_of_week",
        "sin_month",
        "cos_month",
        "sin_day_of_year",
        "cos_day_of_year",
        "report_length",
        "report_word_count",
        "root_cause_length",
        "action_taken_length",
        "has_photos",
        "is_complaint",
        "text_complexity",
        "has_root_cause",
        "has_action_taken",
        "airline_encoded",
        "hub_encoded",
        "branch_encoded",
        "category_encoded",
        "area_encoded",
        "quarter",
    )

    # (encoder column, report field) pairs for the label-encoded categoricals.
    _CATEGORICAL_SOURCES = (
        ("airline", "Airlines"),
        ("hub", "HUB"),
        ("branch", "Branch"),
        ("category", "Irregularity_Complain_Category"),
        ("area", "Area"),
    )

    # Per-category baseline (days) used when no model prediction is possible.
    FALLBACK_BASE_DAYS = {
        "Cargo Problems": 2.5,
        "Pax Handling": 1.8,
        "GSE": 3.2,
        "Operation": 2.1,
        "Baggage Handling": 1.5,
    }

    # Reported when the model metrics carry no feature_importance entry.
    DEFAULT_FEATURE_IMPORTANCE = {
        "category": 0.35,
        "airline": 0.28,
        "hub": 0.15,
        "reportLength": 0.12,
        "hasPhotos": 0.10,
    }

    _HIGH_SEVERITY_KEYWORDS = ("damage", "torn", "broken", "critical", "emergency")
    _MEDIUM_SEVERITY_KEYWORDS = ("delay", "late", "wrong", "error")
    _URGENCY_KEYWORDS = (
        "damage",
        "broken",
        "emergency",
        "critical",
        "urgent",
        "torn",
        "severe",
    )

    class _SimpleLabelEncoder:
        """Minimal LabelEncoder stand-in built from a JSON list of class names.

        ``transform`` maps each value to its index in ``classes_``. Values that
        are not present map to the index of ``"Unknown"`` (or 0 if absent).
        """

        def __init__(self, classes):
            self.classes_ = np.array(list(classes))

        def transform(self, arr):
            try:
                unknown_idx = int(np.where(self.classes_ == "Unknown")[0][0])
            except Exception:
                unknown_idx = 0
            res = []
            for x in arr:
                try:
                    idx = int(np.where(self.classes_ == str(x))[0][0])
                except Exception:
                    idx = unknown_idx
                res.append(idx)
            return np.array(res)

    class _SimpleScaler:
        """StandardScaler stand-in built from JSON ``mean_``/``scale_`` params.

        Scaling is skipped when params are missing or the column count differs.
        A zero scale is treated as 1 so no division by zero occurs.
        """

        def __init__(self, params):
            self.mean_ = params.get("mean_") or params.get("mean")
            self.scale_ = params.get("scale_") or params.get("scale")

        def transform(self, X):
            Xn = np.array(X, dtype=float)
            if self.mean_ is not None and self.scale_ is not None:
                m = np.array(self.mean_, dtype=float)
                s = np.array(self.scale_, dtype=float)
                if Xn.shape[1] == len(m):
                    denom = np.where(s == 0, 1.0, s)
                    Xn = (Xn - m) / denom
            return Xn

    class _HeuristicRegressor:
        """Fixed linear rule used when no trained regressor can be loaded.

        Column indexes follow CANONICAL_FEATURE_ORDER; absent columns count as 0.
        NOTE(review): text_complexity (chars * words / 100) is unbounded, so its
        0.2 weight dominates for long reports. Confirm this is intended.
        """

        def predict(self, X):
            Xn = np.array(X, dtype=float)
            if Xn.ndim == 1:
                Xn = Xn.reshape(1, -1)
            idx_is_weekend = 2
            idx_report_len = 10
            idx_has_photos = 14
            idx_is_complaint = 15
            idx_text_complexity = 16

            def col(idx):
                return Xn[:, idx] if Xn.shape[1] > idx else 0

            w = (
                0.4 * col(idx_is_weekend)
                + 0.2 * col(idx_is_complaint)
                + 0.1 * (col(idx_report_len) / 200.0)
                + 0.2 * col(idx_text_complexity)
                + 0.1 * col(idx_has_photos)
            )
            return np.clip(1.8 + w, 0.1, None)

    def __init__(self):
        # The helper classes above are class attributes, so
        # self._SimpleLabelEncoder / _SimpleScaler / _HeuristicRegressor resolve
        # exactly as the former per-instance assignments did.
        self.regression_version = "1.0.0-trained"
        self.nlp_version = "1.0.0-rule-based"
        self.regression_model = None
        self.label_encoders = {}
        self.scaler = None
        self.feature_names = []
        self.model_metrics = {}
        self.model_loaded = False
        self.nlp_service = None
        self.model_path = None
        self.model_file_exists = False
        self.regression_onnx_session = None
        self._load_regression_model()
        self._load_nlp_service()

    # ------------------------------------------------------------------ loading

    def _load_nlp_service(self):
        """Load the NLP service; on failure keep ``nlp_service`` None (rule-based)."""
        try:
            self.nlp_service = NLPModelService()
            self.nlp_version = self.nlp_service.version
            logger.info(f"NLP service loaded (version: {self.nlp_version})")
        except Exception as e:
            logger.warning(f"Failed to load NLP service: {e}")

    def _load_regression_model(self):
        """Resolve and load the regression model (see the class docstring for order).

        Never raises. An unexpected error is logged and leaves ``model_loaded``
        False, in which case ``predict_regression`` uses ``FALLBACK_BASE_DAYS``.
        """
        try:
            base_dir = os.path.join(os.path.dirname(__file__), "..", "models", "regression")
            app_dir = os.path.join("/app", "models", "regression")
            prefer_onnx = os.getenv("REGRESSION_USE_ONNX", "1").lower() in self._TRUTHY

            if prefer_onnx:
                onnx_path = next(
                    (
                        path
                        for path in (
                            os.path.join(base_dir, self._ONNX_FILENAME),
                            os.path.join(app_dir, self._ONNX_FILENAME),
                        )
                        if os.path.exists(path)
                    ),
                    None,
                )
                if onnx_path is None:
                    onnx_path = self._snapshot_file(self._ONNX_FILENAME, "ONNX")
                if onnx_path is not None:
                    self._try_load_onnx(onnx_path)

            if not self.model_loaded:
                if not self._load_pickle_model(os.path.join(base_dir, self._PKL_FILENAME)):
                    return  # heuristic regressor installed; nothing more to load

            # Metrics may live next to the model file even when the model
            # itself carried none (ONNX, or a dict pickle without "metrics").
            if not self.model_metrics and self.model_path:
                self._load_metrics_near(self.model_path)

            if not self.model_loaded and self.regression_model is None:
                logger.info("Falling back to heuristic regressor after PKL load failure")
                self._use_heuristic()

            logger.info("✓ Regression model loaded successfully")
            logger.info(f" - Features: {len(self.feature_names)}")
            logger.info(f" - Metrics: MAE={self.model_metrics.get('test_mae', 'N/A')}")
        except Exception as e:
            logger.error(f"Failed to load regression model: {e}")
            self.model_loaded = False

    def _use_heuristic(self):
        """Install the heuristic regressor and mark regression as available."""
        self.regression_model = self._HeuristicRegressor()
        self.model_loaded = True

    @staticmethod
    def _create_onnx_session(path):
        """Create an onnxruntime session for *path* (threads from ONNX_THREADS, default 1)."""
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = int(os.getenv("ONNX_THREADS", "1"))
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        return ort.InferenceSession(path, sess_options)

    def _try_load_onnx(self, path):
        """Load the ONNX regressor at *path*; on failure log it and leave model_loaded False."""
        self.model_path = path
        self.model_file_exists = True
        try:
            self.regression_onnx_session = self._create_onnx_session(path)
        except Exception as e:
            logger.warning(f"Failed to load ONNX regression model from {path}: {e}")
            return
        self.model_loaded = True
        self.regression_version = "2.0.0-onnx"
        logger.info(f"✓ Regression ONNX model loaded from {path}")
        # NOTE(review): label_encoders / feature_names / scaler are not loaded on
        # the ONNX path, so categoricals encode to 0 and hasUnknownCategories is
        # always True. Confirm what inputs the exported graph expects.

    @staticmethod
    def _snapshot_file(filename, label):
        """Return *filename* from a Hugging Face snapshot of the model repo, or None.

        The repo id comes from REGRESSION_MODEL_REPO_ID, else MODEL_REPO_ID.
        ``models/regression/<file>`` is checked before ``regression/<file>``.
        Download errors are logged, not raised.
        """
        try:
            from huggingface_hub import snapshot_download

            repo_id = os.getenv("REGRESSION_MODEL_REPO_ID") or os.getenv("MODEL_REPO_ID")
            if not repo_id:
                return None
            cache_dir = snapshot_download(repo_id=repo_id)
            for candidate in (
                os.path.join(cache_dir, "models", "regression", filename),
                os.path.join(cache_dir, "regression", filename),
            ):
                if os.path.exists(candidate):
                    logger.info(f"Downloaded regression {label} from {repo_id} to {candidate}")
                    return candidate
        except Exception as e:
            logger.warning(f"Failed to snapshot regression {label}: {e}")
        return None

    def _load_pickle_model(self, pkl_path):
        """Load the pickled regressor, or install the heuristic if none may be used.

        Returns:
            False when the heuristic was installed because no pickle exists or
            pickle loading is disabled. True once a pickle has been read;
            ``model_loaded`` then says whether it contained a model.
        """
        alt_path = os.path.join(os.getcwd(), "hf-space", "models", "regression", self._PKL_FILENAME)
        chosen_path = pkl_path if os.path.exists(pkl_path) else alt_path
        if not os.path.exists(chosen_path):
            downloaded = self._snapshot_file(self._PKL_FILENAME, "PKL")
            if downloaded is not None:
                chosen_path = downloaded

        self.model_path = os.path.abspath(chosen_path)
        self.model_file_exists = os.path.exists(chosen_path)
        if not self.model_file_exists:
            logger.warning(
                f"Model file not found for pickle fallback at {chosen_path}; using heuristic regressor"
            )
            self._use_heuristic()
            return False
        if os.getenv("REGRESSION_DISABLE_PICKLE", "").lower() in self._TRUTHY:
            logger.info("Pickle loading disabled via REGRESSION_DISABLE_PICKLE; using heuristic regressor")
            self._use_heuristic()
            return False

        logger.info(f"Loading regression model from {self.model_path}")
        # SECURITY: pickle.load executes arbitrary code from the file. The file
        # is local or from the repo named by REGRESSION_MODEL_REPO_ID /
        # MODEL_REPO_ID, which must therefore only point at trusted artifacts.
        # REGRESSION_DISABLE_PICKLE=1 disables this path.
        with open(self.model_path, "rb") as f:
            model_data = pickle.load(f)

        if isinstance(model_data, dict):
            self.regression_model = model_data.get("model")
            self.label_encoders = model_data.get("label_encoders") or {}
            self.scaler = model_data.get("scaler")
            feature_names = model_data.get("feature_names")
            # list() also normalises numpy arrays, whose truth value is ambiguous.
            self.feature_names = list(feature_names) if feature_names is not None else []
            self.model_metrics = model_data.get("metrics") or {}
        else:
            self.regression_model = model_data
            self._load_auxiliary_artifacts(os.path.dirname(self.model_path))
        self.model_loaded = self.regression_model is not None
        return True

    def _load_auxiliary_artifacts(self, base):
        """Load the JSON side artifacts stored next to a bare (non-dict) pickled model.

        Reads label_encoders.json, feature_names.json, scaler.json and the first
        ``*_metrics.json`` found in *base*. This is best effort: the first
        failure is logged at debug level and the remaining artifacts are skipped.
        """
        try:
            le_path = os.path.join(base, "label_encoders.json")
            if os.path.exists(le_path):
                with open(le_path, "r", encoding="utf-8") as f:
                    le_map = json.load(f)
                encoders = {}
                for col, val in le_map.items():
                    # Either {"classes": [...]} or a bare list of class names.
                    if isinstance(val, dict) and "classes" in val:
                        classes = val.get("classes") or []
                    else:
                        classes = val
                    encoders[col] = self._SimpleLabelEncoder(classes or ["Unknown"])
                self.label_encoders = encoders

            fn_path = os.path.join(base, "feature_names.json")
            if os.path.exists(fn_path):
                with open(fn_path, "r", encoding="utf-8") as f:
                    self.feature_names = json.load(f) or []

            scaler_path = os.path.join(base, "scaler.json")
            if os.path.exists(scaler_path):
                with open(scaler_path, "r", encoding="utf-8") as f:
                    scaler_params = json.load(f)
                self.scaler = self._SimpleScaler(scaler_params or {})

            metrics_path = self._find_metrics_file(base)
            if metrics_path:
                with open(metrics_path, "r", encoding="utf-8") as f:
                    self.model_metrics = json.load(f) or {}
        except Exception as e:
            logger.debug(f"Auxiliary artifact load failed: {e}")

    @staticmethod
    def _find_metrics_file(base):
        """Return the first ``*_metrics.json`` in *base* by name, or None."""
        try:
            names = sorted(os.listdir(base))  # sorted: deterministic pick
        except OSError:
            return None
        for name in names:
            if name.endswith("_metrics.json"):
                return os.path.join(base, name)
        return None

    def _load_metrics_near(self, model_path):
        """Load ``*_metrics.json`` from the directory of *model_path*, if present."""
        try:
            metrics_path = self._find_metrics_file(os.path.dirname(model_path))
            if metrics_path:
                with open(metrics_path, "r", encoding="utf-8") as f:
                    self.model_metrics = json.load(f) or {}
        except Exception as e:
            logger.debug(f"Metrics load fallback failed: {e}")

    # --------------------------------------------------------------- features

    @staticmethod
    def _joined_text(item: Dict, *keys: str) -> str:
        """Join ``str(item[key])`` for *keys* with single spaces; missing/None/"" -> ""."""
        return " ".join(str(item.get(key) or "") for key in keys)

    @staticmethod
    def _date_parts(value) -> Tuple[int, int, bool, int, int]:
        """Return ``(day_of_week, month, is_weekend, week_of_year, day_of_year)``.

        Missing or unparseable dates use the current date. ``pd.Timestamp`` is
        used because ``datetime.datetime`` lacks ``.dayofweek``/``.dayofyear``.
        If even that fails, a fixed Monday/January placeholder is returned.
        """
        try:
            ts = pd.to_datetime(value, errors="coerce")
            if pd.isna(ts):
                ts = pd.Timestamp.now()
            day_of_week = int(ts.dayofweek)
            return (
                day_of_week,
                int(ts.month),
                day_of_week in (5, 6),
                int(ts.isocalendar()[1]),
                int(ts.dayofyear),
            )
        except Exception:
            return 0, 1, False, 1, 1

    def _encode_categoricals(self, values: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
        """Label-encode each categorical; return ``({"<col>_encoded": idx}, any_unknown)``.

        Columns without an encoder encode to 0 and count as unknown. Values
        outside an encoder's classes map to its "Unknown" class (or 0).
        """
        encoded = {}
        any_unknown = False
        for col, value in values.items():
            if col not in self.label_encoders:
                encoded[f"{col}_encoded"] = 0
                any_unknown = True
                continue
            le = self.label_encoders[col]
            value_str = str(value)
            if value_str in le.classes_:
                encoded[f"{col}_encoded"] = le.transform([value_str])[0]
            else:
                encoded[f"{col}_encoded"] = (
                    le.transform(["Unknown"])[0] if "Unknown" in le.classes_ else 0
                )
                any_unknown = True
                logger.warning(f"Unknown {col} value: '{value_str}' - using Unknown category")
        return encoded, any_unknown

    def _extract_features(self, report: Dict) -> Tuple[Optional[np.ndarray], bool]:
        """Build the model input row for one report, mirroring training preprocessing.

        The row combines cyclical calendar features from ``Date_of_Event``,
        text length/complexity features, binary flags and label-encoded
        categoricals. It is ordered by ``self.feature_names`` (or
        ``CANONICAL_FEATURE_ORDER``) and scaled with ``self.scaler`` when one is set.

        Returns:
            ``(X, has_unknown_categories)`` with ``X`` of shape ``(1, n_features)``,
            or ``(None, True)`` if extraction failed.
        """
        try:
            day_of_week, month, is_weekend, week_of_year, day_of_year = self._date_parts(
                report.get("Date_of_Event", "")
            )
            # Text features (coerce to string to handle numeric/None values)
            report_text = str(report.get("Report", "") or "")
            root_cause = str(report.get("Root_Caused", "") or "")
            action_taken = str(report.get("Action_Taken", "") or "")
            word_count = len(report_text.split()) if report_text else 0

            encoded_values, has_unknown_categories = self._encode_categoricals(
                {col: report.get(field, "Unknown") for col, field in self._CATEGORICAL_SOURCES}
            )

            feature_dict = {
                "day_of_week": day_of_week,
                "month": month,
                "is_weekend": int(is_weekend),
                "week_of_year": week_of_year,
                "sin_day_of_week": np.sin(2 * np.pi * day_of_week / 7),
                "cos_day_of_week": np.cos(2 * np.pi * day_of_week / 7),
                "sin_month": np.sin(2 * np.pi * month / 12),
                "cos_month": np.cos(2 * np.pi * month / 12),
                "sin_day_of_year": np.sin(2 * np.pi * day_of_year / 365),
                "cos_day_of_year": np.cos(2 * np.pi * day_of_year / 365),
                "report_length": len(report_text),
                "report_word_count": word_count,
                "root_cause_length": len(root_cause),
                "action_taken_length": len(action_taken),
                "has_photos": int(bool(report.get("Upload_Irregularity_Photo", ""))),
                "is_complaint": int(report.get("Report_Category", "") == "Complaint"),
                "text_complexity": (len(report_text) * word_count / 100) if report_text else 0,
                "has_root_cause": int(bool(root_cause)),
                "has_action_taken": int(bool(action_taken)),
                "quarter": (month - 1) // 3 + 1,
            }
            feature_dict.update(encoded_values)

            order = self.feature_names if self.feature_names else self.CANONICAL_FEATURE_ORDER
            X = np.array([[feature_dict.get(name, 0) for name in order]])
            if self.scaler is not None:
                X = self.scaler.transform(X)
            return X, has_unknown_categories
        except Exception as e:
            logger.error(f"Feature extraction error: {e}")
            return None, True

    # ------------------------------------------------------------- regression

    def _interval_half_width(self) -> float:
        """Half-width of the confidence interval: test_mae, else mae, else 0.5."""
        metrics = self.model_metrics if isinstance(self.model_metrics, dict) else {}
        value = metrics.get("test_mae")
        if value is None:
            value = metrics.get("mae")
        try:
            return float(value) if value is not None else 0.5
        except (TypeError, ValueError):
            return 0.5

    def _feature_importance(self) -> Dict[str, float]:
        """Feature importance from the model metrics, else DEFAULT_FEATURE_IMPORTANCE."""
        metrics = self.model_metrics if isinstance(self.model_metrics, dict) else {}
        importance = metrics.get("feature_importance")
        return importance if importance is not None else dict(self.DEFAULT_FEATURE_IMPORTANCE)

    def _model_predict(self, features: np.ndarray) -> Optional[float]:
        """Predict with the ONNX session if present, else the in-memory regressor.

        ONNX errors are logged at debug level and fall through to the
        in-memory model. Returns None when no model is available.
        """
        session = getattr(self, "regression_onnx_session", None)
        if session is not None:
            try:
                input_name = session.get_inputs()[0].name
                outputs = session.run(None, {input_name: features.astype(np.float32)})
                return float(np.ravel(outputs[0])[0])
            except Exception as e:
                logger.debug(f"ONNX inference failed: {e}")
        model = getattr(self, "regression_model", None)
        if model is not None:
            return float(model.predict(features)[0])
        return None

    @staticmethod
    def _explain(shap_explainer, features) -> "Optional[ShapExplanation]":
        """Build a ShapExplanation (top 5 factors), or None if unavailable/failing."""
        if getattr(shap_explainer, "explainer", None) is None:
            return None
        try:
            shap_result = shap_explainer.explain_prediction(features)
            return ShapExplanation(
                baseValue=shap_result.get("base_value", 0),
                predictionExplained=shap_result.get("prediction_explained", False),
                topFactors=shap_result.get("top_factors", [])[:5],
                explanation=shap_result.get("explanation", ""),
            )
        except Exception as e:
            logger.debug(f"SHAP explanation failed: {e}")
            return None

    @staticmethod
    def _detect_anomaly(anomaly_detector, predicted, category, hub) -> "Optional[AnomalyResult]":
        """Run the anomaly detector on a prediction; None if it fails."""
        try:
            anomaly_data = anomaly_detector.detect_prediction_anomaly(predicted, category, hub)
            return AnomalyResult(
                isAnomaly=anomaly_data.get("is_anomaly", False),
                anomalyScore=anomaly_data.get("anomaly_score", 0),
                anomalies=anomaly_data.get("anomalies", []),
            )
        except Exception as e:
            logger.debug(f"Anomaly detection failed: {e}")
            return None

    def predict_regression(self, data: List[Dict]) -> "List[RegressionPrediction]":
        """Predict resolution time (days) for each report.

        Rows with a model prediction get an interval of ±MAE (see
        ``_interval_half_width``) plus optional SHAP and anomaly details. Rows
        without one get the deterministic ``FALLBACK_BASE_DAYS`` baseline with
        ±0.5 days. No random jitter is added, so identical requests agree.

        Args:
            data: report dicts (``IrregularityReport.model_dump`` output).

        Returns:
            One RegressionPrediction per row, ``reportId`` ``"row_<i>"``.
        """
        shap_explainer = get_shap_explainer()
        anomaly_detector = get_anomaly_detector()
        mae = self._interval_half_width()
        importance = self._feature_importance()

        predictions = []
        for i, item in enumerate(data):
            features, has_unknown = self._extract_features(item)
            category = item.get("Irregularity_Complain_Category", "Unknown")
            hub = item.get("HUB", "Unknown")
            predicted = self._model_predict(features) if features is not None else None

            if predicted is not None:
                lower = max(0.1, predicted - mae)
                upper = predicted + mae
                shap_exp = self._explain(shap_explainer, features)
                anomaly_result = self._detect_anomaly(anomaly_detector, predicted, category, hub)
            else:
                predicted = self.FALLBACK_BASE_DAYS.get(category, 2.0)
                lower = max(0.1, predicted - 0.5)
                upper = predicted + 0.5
                has_unknown = True
                shap_exp = None
                anomaly_result = None

            predictions.append(
                RegressionPrediction(
                    reportId=f"row_{i}",
                    predictedDays=round(max(0.1, predicted), 2),
                    confidenceInterval=(round(lower, 2), round(upper, 2)),
                    featureImportance=importance,
                    hasUnknownCategories=has_unknown,
                    shapExplanation=shap_exp,
                    anomalyDetection=anomaly_result,
                )
            )
        return predictions

    # -------------------------------------------------------------------- NLP

    @staticmethod
    def _label_and_confidence(prediction, default_label) -> Tuple[str, Any]:
        """Extract ``(label, confidence)`` from a multi-task prediction dict.

        A missing or None prediction or label falls back to *default_label*.
        A missing confidence falls back to 0.85.
        """
        prediction = prediction or {}
        label = prediction.get("label", default_label)
        if label is None:
            label = default_label
        return str(label), prediction.get("confidence", 0.85)

    def classify_text(self, data: List[Dict]) -> "List[ClassificationResult]":
        """Classify severity, area and issue type for each report.

        Uses the NLP service when loaded. If one of its calls raises, that part
        falls back to keyword rules (severity) or the report's own Area /
        category fields (area, issue type).
        """
        texts = [self._joined_text(item, "Report", "Root_Caused") for item in data]

        mt_results = None
        severity_results = None
        if self.nlp_service:
            try:
                mt_results = self.nlp_service.predict_multi_task(texts)
            except Exception as e:
                logger.warning(f"Multi-task NLP prediction failed; using report fields: {e}")
            try:
                severity_results = self.nlp_service.classify_severity(texts)
            except Exception as e:
                logger.warning(f"NLP severity classification failed; using keyword fallback: {e}")
        if severity_results is None:
            severity_results = self._classify_severity_fallback(texts)

        results = []
        for i, (item, sev_result) in enumerate(zip(data, severity_results)):
            default_area = item.get("Area", "Unknown")
            default_issue = item.get("Irregularity_Complain_Category", "Unknown")
            if mt_results and i < len(mt_results):
                mt_res = mt_results[i] or {}
                area, _ = self._label_and_confidence(mt_res.get("area"), default_area)
                issue, issue_conf = self._label_and_confidence(
                    mt_res.get("irregularity"), default_issue
                )
            else:
                area, issue, issue_conf = str(default_area), str(default_issue), 0.85
            results.append(
                ClassificationResult(
                    reportId=f"row_{i}",
                    severity=sev_result.get("severity", "Low"),
                    severityConfidence=sev_result.get("confidence", 0.8),
                    areaType=area.replace(" Area", ""),
                    issueType=issue,
                    issueTypeConfidence=issue_conf,
                )
            )
        return results

    def _classify_severity_fallback(self, texts: List[str]) -> List[Dict]:
        """Keyword severity: High (damage-type words) > Medium (delay-type) > Low."""
        results = []
        for text in texts:
            lowered = text.lower()
            if any(kw in lowered for kw in self._HIGH_SEVERITY_KEYWORDS):
                severity, confidence = "High", 0.89
            elif any(kw in lowered for kw in self._MEDIUM_SEVERITY_KEYWORDS):
                severity, confidence = "Medium", 0.75
            else:
                severity, confidence = "Low", 0.82
            results.append({"severity": severity, "confidence": confidence})
        return results

    @staticmethod
    def _make_entity(text: str, label: str, haystack_lower: str, confidence: float) -> "Entity":
        """Entity for *text*, spanning its first case-insensitive match (or 0 if absent)."""
        idx = haystack_lower.find(text.lower())
        start = idx if idx >= 0 else 0
        return Entity(text=text, label=label, start=start, end=start + len(text), confidence=confidence)

    def extract_entities(self, data: List[Dict]) -> "List[EntityResult]":
        """Emit AIRLINE, FLIGHT_NUMBER and DATE entities taken from the report fields.

        Spans point at the first match in the Report + Root_Caused text, or
        start at 0 when the value does not occur there.
        """
        results = []
        for i, item in enumerate(data):
            haystack = self._joined_text(item, "Report", "Root_Caused").lower()
            airline, flight, date_text = (
                str(value) if value is not None else ""
                for value in (
                    item.get("Airlines", ""),
                    item.get("Flight_Number", ""),
                    item.get("Date_of_Event", ""),
                )
            )
            entities = []
            if airline and airline != "Unknown":
                entities.append(self._make_entity(airline, "AIRLINE", haystack, 0.95))
            if flight and flight != "#N/A":
                entities.append(self._make_entity(flight, "FLIGHT_NUMBER", haystack, 0.92))
            if date_text:
                entities.append(self._make_entity(date_text, "DATE", haystack, 0.90))
            results.append(EntityResult(reportId=f"row_{i}", entities=entities))
        return results

    @staticmethod
    def _rule_based_summary(item: Dict) -> Tuple[str, List[str]]:
        """Template summary: "<category>: <report>[ Root cause: ...]" plus key points."""
        category = item.get("Irregularity_Complain_Category", "Issue")
        report = str(item.get("Report") or "")[:120]
        root_cause = str(item.get("Root_Caused") or "")[:80]
        action = str(item.get("Action_Taken") or "")[:80]

        executive_summary = f"{category}: {report}"
        if root_cause:
            executive_summary += f" Root cause: {root_cause}."
        key_points = [
            f"Category: {category}",
            f"Status: {item.get('Status', 'Unknown')}",
            f"Area: {item.get('Area', 'Unknown')}",
        ]
        if action:
            # Ellipsis only when the action text was actually truncated.
            ellipsis = "..." if len(action) > 50 else ""
            key_points.append(f"Action: {action[:50]}{ellipsis}")
        return executive_summary, key_points

    def generate_summary(self, data: List[Dict]) -> "List[SummaryResult]":
        """Summarise each report with the NLP service (texts > 100 chars) or a template."""
        results = []
        for i, item in enumerate(data):
            combined_text = self._joined_text(item, "Report", "Root_Caused", "Action_Taken")
            summary = None
            if self.nlp_service and len(combined_text) > 100:
                try:
                    summary_result = self.nlp_service.summarize(combined_text)
                    summary = (
                        summary_result.get("executiveSummary", ""),
                        summary_result.get("keyPoints", []),
                    )
                except Exception as e:
                    logger.warning(f"NLP summarization failed; using template summary: {e}")
            if summary is None:
                summary = self._rule_based_summary(item)
            executive_summary, key_points = summary
            results.append(
                SummaryResult(
                    reportId=f"row_{i}",
                    executiveSummary=executive_summary,
                    keyPoints=key_points,
                )
            )
        return results

    def analyze_sentiment(self, data: List[Dict]) -> "List[SentimentResult]":
        """Score urgency/sentiment with the NLP service, or keyword rules if unavailable/failing."""
        texts = [self._joined_text(item, "Report", "Root_Caused") for item in data]
        urgency_results = None
        if self.nlp_service:
            try:
                urgency_results = self.nlp_service.analyze_urgency(texts)
            except Exception as e:
                logger.warning(f"NLP urgency analysis failed; using keyword fallback: {e}")
        if urgency_results is None:
            urgency_results = self._analyze_urgency_fallback(texts)

        return [
            SentimentResult(
                reportId=f"row_{i}",
                urgencyScore=urg_result.get("urgency_score", 0.0),
                sentiment=urg_result.get("sentiment", "Neutral"),
                keywords=urg_result.get("keywords", []),
            )
            for i, (item, urg_result) in enumerate(zip(data, urgency_results))
        ]

    def _analyze_urgency_fallback(self, texts: List[str]) -> List[Dict]:
        """Keyword urgency: score = min(1, matches / 3); "Negative" when score > 0.3."""
        results = []
        for text in texts:
            lowered = text.lower()
            keyword_matches = [kw for kw in self._URGENCY_KEYWORDS if kw in lowered]
            urgency_score = min(1.0, len(keyword_matches) / 3.0)
            results.append(
                {
                    "urgency_score": round(urgency_score, 2),
                    "sentiment": "Negative" if urgency_score > 0.3 else "Neutral",
                    "keywords": keyword_matches,
                }
            )
        return results
# Single ModelService instance shared by every endpoint below. Its constructor
# loads the regression model and NLP service, so that work happens at import.
model_service: ModelService = ModelService()
# ============== API Endpoints ==============
@app.get("/")
async def root():
    """Lightweight liveness check reporting service identity and model state."""
    if model_service.model_loaded:
        regression_state = "loaded"
    elif getattr(model_service, "model_file_exists", False):
        regression_state = "available"
    else:
        regression_state = "unavailable"

    payload = {
        "status": "healthy",
        "service": "Gapura AI Analysis API",
        "version": "1.0.0",
        "models": {
            "regression": regression_state,
            "nlp": model_service.nlp_version,
        },
        "timestamp": datetime.now().isoformat(),
    }
    return payload
@app.get("/health")
async def health_check():
    """Report per-model status, regression metrics and cache health."""
    cache_health = get_cache().health_check()

    # Classify the NLP backend from its version string; first matching rule wins.
    version_text = (model_service.nlp_version or "").lower()
    status_rules = (
        ("onnx", ("onnx",)),
        ("bert", ("hf", "bert", "indobert", "distilbert")),
        ("tfidf", ("tfidf", "tf-idf")),
    )
    nlp_status = next(
        (
            label
            for label, tokens in status_rules
            if any(token in version_text for token in tokens)
        ),
        "rule_based",
    )

    # NOTE(review): "loaded" is also True when only the model file exists
    # (model_file_exists), even if loading it failed. Confirm this is intended.
    regression_probes = (
        getattr(model_service, "model_loaded", False),
        getattr(model_service, "regression_onnx_session", None),
        getattr(model_service, "regression_model", None),
        getattr(model_service, "model_file_exists", False),
    )
    reg_loaded = any(regression_probes)

    regression_info = {
        "version": model_service.regression_version,
        "loaded": reg_loaded,
        "metrics": model_service.model_metrics if reg_loaded else None,
    }
    nlp_info = {
        "version": model_service.nlp_version,
        "status": nlp_status,
    }
    return {
        "status": "healthy",
        "models": {"regression": regression_info, "nlp": nlp_info},
        "cache": cache_health,
        "timestamp": datetime.now().isoformat(),
    }
def _first_present_metric(metrics: Dict[str, Any], *keys: str) -> Optional[float]:
    """Return the first metric among *keys* that is present and not None (0.0 counts)."""
    for key in keys:
        value = metrics.get(key)
        if value is not None:
            return value
    return None


def _group_rows_for_trends(
    data: List[Dict[str, Any]], key: str, predicted_days: Optional[List[float]]
) -> Dict[str, Dict[str, Any]]:
    """Group rows by ``row[key]`` (default "Unknown").

    Each group collects its row count, issue categories and, when available,
    the predicted resolution days of its rows.
    """
    groups: Dict[str, Dict[str, Any]] = {}
    for idx, item in enumerate(data):
        group = groups.setdefault(
            item.get(key, "Unknown"), {"count": 0, "issues": [], "days": []}
        )
        group["count"] += 1
        group["issues"].append(item.get("Irregularity_Complain_Category", "Unknown"))
        if predicted_days is not None and idx < len(predicted_days):
            group["days"].append(predicted_days[idx])
    return groups


def _to_trend_data(groups: Dict[str, Dict[str, Any]]) -> Dict[str, "TrendData"]:
    """Convert grouped rows to TrendData.

    avgResolutionDays is the mean of the group's predicted days, or None when
    no predictions were made. topIssues holds the three most frequent categories.
    """
    return {
        name: TrendData(
            count=group["count"],
            avgResolutionDays=(
                round(sum(group["days"]) / len(group["days"]), 2) if group["days"] else None
            ),
            topIssues=[issue for issue, _ in Counter(group["issues"]).most_common(3)],
        )
        for name, group in groups.items()
    }


def _build_trends(data: List[Dict[str, Any]], predicted_days: Optional[List[float]]) -> "TrendResult":
    """Aggregate rows by airline, hub and category into a TrendResult."""
    category_counts = Counter(
        item.get("Irregularity_Complain_Category", "Unknown") for item in data
    )
    return TrendResult(
        byAirline=_to_trend_data(_group_rows_for_trends(data, "Airlines", predicted_days)),
        byHub=_to_trend_data(_group_rows_for_trends(data, "HUB", predicted_days)),
        # NOTE(review): "trend" is a fixed placeholder; no trend is computed.
        byCategory={
            category: {"count": count, "trend": "stable"}
            for category, count in category_counts.items()
        },
        timeSeries=[],
    )


@app.post("/api/ai/analyze", response_model=AnalysisResponse)
async def analyze_reports(request: AnalysisRequest):
    """
    Analyze irregularity reports using AI models
    Rows are taken from the request's `data` array. This handler does not load
    rows from a Google Sheet, so a request without `data` is rejected with 400.
    Enabled steps (see AnalysisOptions): regression, NLP, and trend aggregation.
    Trend averages use the regression predictions when regression runs, and are
    null otherwise.
    """
    start_time = datetime.now()
    try:
        if not request.data:
            raise HTTPException(
                status_code=400,
                detail="No data provided. Either sheetId or data must be specified.",
            )

        data = [report.model_dump(exclude_none=True) for report in request.data]
        total_records = len(data)
        logger.info(f"Analyzing {total_records} records...")

        response = AnalysisResponse(
            metadata=Metadata(
                totalRecords=total_records,
                processingTime=0.0,
                modelVersions={
                    "regression": model_service.regression_version,
                    "nlp": model_service.nlp_version,
                },
            )
        )

        # Regression Analysis
        predicted_days: Optional[List[float]] = None
        if request.options.predictResolutionTime:
            logger.info("Running regression analysis...")
            predictions = model_service.predict_regression(data)
            predicted_days = [p.predictedDays for p in predictions]
            loaded = bool(model_service.model_loaded)
            mm = model_service.model_metrics or {}
            _mae = _first_present_metric(mm, "test_mae", "mae")
            _rmse = _first_present_metric(mm, "test_rmse", "rmse")
            _r2 = _first_present_metric(mm, "test_r2", "r2")
            metrics = {
                "mae": round(_mae, 3) if _mae is not None else None,
                "rmse": round(_rmse, 3) if _rmse is not None else None,
                "r2": round(_r2, 3) if _r2 is not None else None,
                "model_loaded": loaded,
                "note": "Using trained model" if loaded else "Model not available - using fallback predictions",
            }
            response.regression = RegressionResult(
                predictions=predictions,
                modelMetrics=metrics,
            )

        # NLP Analysis
        if any(
            [
                request.options.classifySeverity,
                request.options.extractEntities,
                request.options.generateSummary,
            ]
        ):
            logger.info("Running NLP analysis...")
            classifications = []
            entities = []
            summaries = []
            if request.options.classifySeverity:
                classifications = model_service.classify_text(data)
            if request.options.extractEntities:
                entities = model_service.extract_entities(data)
            if request.options.generateSummary:
                summaries = model_service.generate_summary(data)
            sentiment = model_service.analyze_sentiment(data)
            response.nlp = NLPResult(
                classifications=classifications,
                entities=entities,
                summaries=summaries,
                sentiment=sentiment,
            )

        # Trend Analysis
        if request.options.analyzeTrends:
            logger.info("Running trend analysis...")
            response.trends = _build_trends(data, predicted_days)

        # Processing time in milliseconds
        processing_time = (datetime.now() - start_time).total_seconds() * 1000
        response.metadata.processingTime = round(processing_time, 2)
        logger.info(f"Analysis completed in {processing_time:.2f}ms")
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ai/predict-single")
async def predict_single(report: IrregularityReport):
    """
    Predict for a single report in real-time.

    The report is dumped to a plain dict (fields that are ``None`` are
    dropped). Each model-service call receives it as a one-element batch,
    and the first element of each result is returned.

    Raises:
        HTTPException: 500 carrying the error text if any model call fails.
    """
    try:
        record = report.model_dump(exclude_none=True)
        regression_out = model_service.predict_regression([record])
        classification_out = model_service.classify_text([record])
        entity_out = model_service.extract_entities([record])
        summary_out = model_service.generate_summary([record])
        sentiment_out = model_service.analyze_sentiment([record])
        payload = {
            "prediction": regression_out[0],
            "classification": classification_out[0],
            "entities": entity_out[0],
            "summary": summary_out[0],
            "sentiment": sentiment_out[0],
            "modelLoaded": model_service.model_loaded,
        }
        return payload
    except Exception as e:
        logger.error(f"Single prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ai/train")
async def train_models(background_tasks: BackgroundTasks, force: bool = False):
    """
    Queue a model retraining run and acknowledge immediately.

    Args:
        force: If True, force training regardless of conditions.

    Training runs as a FastAPI background task after the response is sent;
    its result is only logged.
    """
    from scripts.scheduled_training import TrainingScheduler

    def _retrain() -> None:
        # Executed after the response; `force` is captured from this request.
        outcome = TrainingScheduler().run_training(force=force)
        logger.info(f"Training completed: {outcome}")

    background_tasks.add_task(_retrain)
    return dict(
        status="training_queued",
        message="Model retraining has been started in the background",
        force=force,
        timestamp=datetime.now().isoformat(),
    )
@app.get("/api/ai/train/status")
async def training_status():
    """Return the training scheduler's status/history in a success envelope."""
    from scripts.scheduled_training import TrainingScheduler

    return {
        "status": "success",
        "data": TrainingScheduler().get_status(),
        "timestamp": datetime.now().isoformat(),
    }
@app.get("/api/ai/model-info")
async def model_info():
    """Describe the regression model and NLP pipeline currently being served."""
    loaded = model_service.model_loaded
    regression_section = {
        "version": model_service.regression_version,
        "type": "GradientBoostingRegressor",
        "status": "loaded" if loaded else "unavailable",
        # NOTE(review): hard-coded date, not read from the loaded model.
        "last_trained": "2025-01-15",
        "metrics": model_service.model_metrics if loaded else None,
    }
    nlp_section = {
        "version": model_service.nlp_version,
        "type": "Rule-based + Keyword extraction",
        "status": "active",
        "tasks": ["classification", "ner", "summarization", "sentiment"],
        "note": "Full ML NLP models coming soon",
    }
    return {"regression": regression_section, "nlp": nlp_section}
@app.post("/api/ai/cache/invalidate")
async def invalidate_cache(sheet_name: Optional[str] = None):
    """
    Delete cached Google Sheets entries.

    Args:
        sheet_name: When non-empty, only keys matching
            ``sheets:*<sheet_name>*`` are removed; otherwise every
            ``sheets:*`` key is removed.
    """
    cache = get_cache()
    if not sheet_name:
        removed = cache.delete_pattern("sheets:*")
        return {
            "status": "success",
            "message": "Invalidated all sheets cache",
            "keys_deleted": removed,
        }
    removed = cache.delete_pattern(f"sheets:*{sheet_name}*")
    return {
        "status": "success",
        "message": f"Invalidated cache for sheet: {sheet_name}",
        "keys_deleted": removed,
    }
@app.get("/api/ai/cache/status")
async def cache_status():
    """Return the cache's own health_check() report (status and statistics)."""
    return get_cache().health_check()
class AnalyzeAllResponse(BaseModel):
    # Response envelope returned by GET /api/ai/analyze-all.
    status: str  # "success" on the normal path
    metadata: Dict[str, Any]  # totals, processing time, records/sec, model versions
    sheets: Dict[str, Any]  # per-sheet fetch outcome keyed by sheet name
    results: List[Dict[str, Any]]  # one entry per analysed row
    summary: Dict[str, Any]  # aggregate statistics across all rows
    timestamp: str  # datetime.now().isoformat() when the response was built
@app.get("/api/ai/analyze-all", response_model=AnalyzeAllResponse)
async def analyze_all_sheets(
    bypass_cache: bool = False,
    include_regression: bool = True,
    include_nlp: bool = True,
    include_trends: bool = True,
    max_rows_per_sheet: int = 10000,
):
    """
    Analyze ALL rows from all Google Sheets.

    Fetches data from both NON CARGO and CGO sheets, runs the enabled models
    over every row in batches of 100, and returns per-row results plus an
    aggregate summary.

    Args:
        bypass_cache: Skip cache and fetch fresh data.
        include_regression: Include regression predictions (predicted days,
            confidence interval, SHAP explanation, anomaly detection).
        include_nlp: Include NLP analysis (classification/severity, entities,
            summary, sentiment).
        include_trends: Currently unused by this handler.
        max_rows_per_sheet: Maximum rows to process per sheet.

    Returns:
        AnalyzeAllResponse.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is unset or analysis fails;
            404 if neither sheet returned any rows. A failure fetching one
            sheet is recorded in ``sheets`` and does not abort the request.
    """
    start_time = datetime.now()
    try:
        from data.sheets_service import GoogleSheetsService

        # No cache object at all when the caller asked to bypass it.
        cache = get_cache() if not bypass_cache else None
        sheets_service = GoogleSheetsService(cache=cache)
        spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
        if not spreadsheet_id:
            raise HTTPException(
                status_code=500, detail="GOOGLE_SHEET_ID not configured"
            )
        all_data = []
        sheet_info = {}
        # End row is max_rows + 1 — presumably to account for a header row
        # in row 1 (TODO confirm). NON CARGO spans one more column than CGO.
        sheets_to_fetch = [
            {"name": "NON CARGO", "range": f"A1:AA{max_rows_per_sheet + 1}"},
            {"name": "CGO", "range": f"A1:Z{max_rows_per_sheet + 1}"},
        ]
        for sheet in sheets_to_fetch:
            try:
                sheet_name = sheet["name"]
                range_str = sheet["range"]
                logger.info(f"Fetching {sheet_name}...")
                data = sheets_service.fetch_sheet_data(
                    spreadsheet_id, sheet_name, range_str, bypass_cache=bypass_cache
                )
                # Tag each row with its origin so results can report it.
                for row in data:
                    row["_source_sheet"] = sheet_name
                    all_data.append(row)
                sheet_info[sheet_name] = {
                    "rows_fetched": len(data),
                    "status": "success",
                }
            except Exception as e:
                # Per-sheet failure is recorded, not raised; the other sheet
                # can still be analysed.
                logger.error(f"Failed to fetch {sheet['name']}: {e}")
                sheet_info[sheet["name"]] = {
                    "rows_fetched": 0,
                    "status": "error",
                    "error": str(e),
                }
        total_records = len(all_data)
        if total_records == 0:
            raise HTTPException(status_code=404, detail="No data found in any sheet")
        logger.info(f"Analyzing {total_records} total records...")
        results = []
        batch_size = 100
        for i in range(0, total_records, batch_size):
            batch = all_data[i : i + batch_size]
            # Disabled stages are padded with None so index j lines up below.
            if include_regression:
                regression_preds = model_service.predict_regression(batch)
            else:
                regression_preds = [None] * len(batch)
            if include_nlp:
                classifications = model_service.classify_text(batch)
                entities = model_service.extract_entities(batch)
                summaries = model_service.generate_summary(batch)
                sentiments = model_service.analyze_sentiment(batch)
            else:
                classifications = [None] * len(batch)
                entities = [None] * len(batch)
                summaries = [None] * len(batch)
                sentiments = [None] * len(batch)
            for j, row in enumerate(batch):
                result = {
                    # Falls back to the global row index when no _row_id exists.
                    "rowId": row.get("_row_id", f"row_{i + j}"),
                    "sourceSheet": row.get("_source_sheet", "Unknown"),
                    "originalData": {
                        "date": row.get("Date_of_Event"),
                        "airline": row.get("Airlines"),
                        "flightNumber": row.get("Flight_Number"),
                        "branch": row.get("Branch"),
                        "hub": row.get("HUB"),
                        "route": row.get("Route"),
                        "category": row.get("Report_Category"),
                        "issueType": row.get("Irregularity_Complain_Category"),
                        "report": row.get("Report"),
                        "status": row.get("Status"),
                    },
                }
                if regression_preds[j]:
                    # Optional sub-objects are dumped to dicts, or None if absent.
                    result["prediction"] = {
                        "predictedDays": regression_preds[j].predictedDays,
                        "confidenceInterval": regression_preds[j].confidenceInterval,
                        "hasUnknownCategories": regression_preds[
                            j
                        ].hasUnknownCategories,
                        "shapExplanation": regression_preds[
                            j
                        ].shapExplanation.model_dump()
                        if regression_preds[j].shapExplanation
                        else None,
                        "anomalyDetection": regression_preds[
                            j
                        ].anomalyDetection.model_dump()
                        if regression_preds[j].anomalyDetection
                        else None,
                    }
                if classifications[j]:
                    result["classification"] = classifications[j].model_dump()
                if entities[j]:
                    result["entities"] = entities[j].model_dump()
                if summaries[j]:
                    result["summary"] = summaries[j].model_dump()
                if sentiments[j]:
                    result["sentiment"] = sentiments[j].model_dump()
                results.append(result)
        summary = {
            "totalRecords": total_records,
            "sheetsProcessed": len(
                [s for s in sheet_info.values() if s["status"] == "success"]
            ),
            "regressionEnabled": include_regression,
            "nlpEnabled": include_nlp,
        }
        if include_nlp and results:
            # Rows without a classification are counted under "Unknown".
            severity_counts = {}
            for r in results:
                sev = r.get("classification", {}).get("severity", "Unknown")
                severity_counts[sev] = severity_counts.get(sev, 0) + 1
            summary["severityDistribution"] = severity_counts
        if include_regression and results:
            predictions = [
                r["prediction"]["predictedDays"] for r in results if r.get("prediction")
            ]
            if predictions:
                summary["predictionStats"] = {
                    "min": round(min(predictions), 2),
                    "max": round(max(predictions), 2),
                    "mean": round(sum(predictions) / len(predictions), 2),
                }
        processing_time = (datetime.now() - start_time).total_seconds()
        return AnalyzeAllResponse(
            status="success",
            metadata={
                "totalRecords": total_records,
                "processingTimeSeconds": round(processing_time, 2),
                # Guard against division by zero on a sub-resolution run.
                "recordsPerSecond": round(total_records / processing_time, 2)
                if processing_time > 0
                else 0,
                "modelVersions": {
                    "regression": model_service.regression_version,
                    "nlp": model_service.nlp_version,
                },
            },
            sheets=sheet_info,
            results=results,
            summary=summary,
            timestamp=datetime.now().isoformat(),
        )
    except HTTPException:
        # Preserve deliberate 404/500 responses raised above.
        raise
    except Exception as e:
        logger.error(f"Analyze all error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ============== Risk Scoring Endpoints ==============
@app.get("/api/ai/risk/summary")
async def risk_summary(
    bypass_cache: bool = False,
    max_rows_per_sheet: int = 10000,
    fast: bool = True,
    branch: Optional[str] = None,
):
    """
    Recalculate risk scores from both sheets and return the risk summary.

    Args:
        bypass_cache: Skip the sheets cache and fetch fresh data.
        max_rows_per_sheet: Row cap for the sheet fetches.
        fast: Fetch only the columns listed in ``required_headers`` instead
            of full row ranges.
        branch: Optional Branch filter (case-insensitive, whitespace-trimmed)
            applied before scoring.

    Returns:
        The risk service's ``get_risk_summary()`` after recalculation.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.

    NOTE(review): scores are recalculated on the instance returned by
    ``get_risk_service()``. If that is a shared singleton, a branch-filtered
    call also changes what the other /api/ai/risk/* endpoints return.
    Confirm this is intended.
    """
    from data.risk_service import get_risk_service
    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    # Columns requested in fast mode; identical for both sheets.
    required_headers = [
        "Report",
        "Root_Caused",
        "Action_Taken",
        "Irregularity_Complain_Category",
        "Area",
        "Airlines",
        "Branch",
        "HUB",
        "Status",
    ]
    reqs = [
        {
            "name": sheet_name,
            "required_headers": list(required_headers),
            "max_rows": max_rows_per_sheet,
        }
        for sheet_name in ("NON CARGO", "CGO")
    ]

    def _fetch_full_ranges(fresh: bool):
        # Full-range fetch. NON CARGO spans A:AA, matching every other
        # full NON CARGO fetch in this module (previously A:Z here, which
        # dropped column AA in non-fast mode). CGO spans A:Z.
        nc_rows = sheets_service.fetch_sheet_data(
            spreadsheet_id,
            "NON CARGO",
            f"A1:AA{max_rows_per_sheet}",
            bypass_cache=fresh,
        )
        cgo_rows = sheets_service.fetch_sheet_data(
            spreadsheet_id, "CGO", f"A1:Z{max_rows_per_sheet}", bypass_cache=fresh
        )
        return nc_rows, cgo_rows

    if fast:
        data_map = sheets_service.fetch_sheets_selected_columns(
            spreadsheet_id, reqs, bypass_cache=bypass_cache
        )
        non_cargo = data_map.get("NON CARGO", [])
        cargo = data_map.get("CGO", [])
    else:
        non_cargo, cargo = _fetch_full_ranges(bypass_cache)

    # Nothing came back (presumably a stale/empty cache entry): retry uncached.
    if (len(non_cargo) + len(cargo)) == 0:
        if fast:
            data_map_retry = sheets_service.fetch_sheets_selected_columns(
                spreadsheet_id, reqs, bypass_cache=True
            )
            non_cargo = data_map_retry.get("NON CARGO", [])
            cargo = data_map_retry.get("CGO", [])
            # Still empty: fall back to a plain batched range fetch.
            if (len(non_cargo) + len(cargo)) == 0:
                sheet_ranges = [
                    {"name": "NON CARGO", "range": f"A1:Z{max_rows_per_sheet}"},
                    {"name": "CGO", "range": f"A1:Z{max_rows_per_sheet}"},
                ]
                data_map_wide = sheets_service.fetch_sheets_batch_data(
                    spreadsheet_id, sheet_ranges, bypass_cache=True
                )
                non_cargo = data_map_wide.get("NON CARGO", [])
                cargo = data_map_wide.get("CGO", [])
        else:
            non_cargo, cargo = _fetch_full_ranges(True)

    for row in non_cargo:
        row["_source_sheet"] = "NON CARGO"
    for row in cargo:
        row["_source_sheet"] = "CGO"
    all_data = non_cargo + cargo
    if branch:
        wanted_branch = branch.strip().lower()
        all_data = [
            r
            for r in all_data
            if (r.get("Branch") or "").strip().lower() == wanted_branch
        ]
    risk_service = get_risk_service()
    risk_service.calculate_all_risk_scores(all_data)
    return risk_service.get_risk_summary()
@app.get("/api/ai/risk/airlines")
async def airline_risks():
    """Return the risk service's scores for every airline."""
    from data.risk_service import get_risk_service as _risk_service_factory

    return _risk_service_factory().get_all_airline_risks()
@app.get("/api/ai/risk/airlines/{airline_name}")
async def airline_risk(airline_name: str):
    """
    Return one airline's risk data together with its recommendations.

    Raises:
        HTTPException: 404 when the risk service returns no (falsy) data for
            the airline.
    """
    from data.risk_service import get_risk_service

    service = get_risk_service()
    airline_data = service.get_airline_risk(airline_name)
    if airline_data:
        return {
            "airline": airline_name,
            "risk_data": airline_data,
            "recommendations": service.get_risk_recommendations(
                "airline", airline_name
            ),
        }
    raise HTTPException(
        status_code=404, detail=f"Airline '{airline_name}' not found"
    )
@app.get("/api/ai/risk/branches")
async def branch_risks():
    """Return the risk service's scores for every branch."""
    from data.risk_service import get_risk_service as _risk_service_factory

    return _risk_service_factory().get_all_branch_risks()
@app.get("/api/ai/risk/hubs")
async def hub_risks():
    """Return the risk service's scores for every hub."""
    from data.risk_service import get_risk_service as _risk_service_factory

    return _risk_service_factory().get_all_hub_risks()
@app.post("/api/ai/risk/calculate")
async def calculate_risk_scores(bypass_cache: bool = False):
    """
    Calculate risk scores from current Google Sheets data.

    Reads rows up to 2000 of NON CARGO (A:AA) and CGO (A:Z). Every row is
    passed to the risk service's ``calculate_all_risk_scores``, and the
    service's refreshed summary is returned.

    Args:
        bypass_cache: Skip the sheets cache and fetch fresh data.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.risk_service import get_risk_service
    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo

    risk_service = get_risk_service()
    # Results are read back through get_risk_summary(); the direct return
    # value of the calculation is not needed here.
    risk_service.calculate_all_risk_scores(all_data)
    return {
        "status": "success",
        "records_processed": len(all_data),
        "risk_summary": risk_service.get_risk_summary(),
    }
# ============== Subcategory Classification Endpoints ==============
@app.post("/api/ai/subcategory")
async def classify_subcategory(
    report: str,
    area: Optional[str] = None,
    issue_type: Optional[str] = None,
    root_cause: Optional[str] = None,
):
    """
    Classify report text into a subcategory.

    ``area``, ``issue_type`` and ``root_cause`` are passed through to the
    classifier as optional context.
    """
    from data.subcategory_service import get_subcategory_classifier

    return get_subcategory_classifier().classify(report, area, issue_type, root_cause)
@app.get("/api/ai/subcategory/categories")
async def get_subcategories(area: Optional[str] = None):
    """List the subcategories the classifier knows, optionally for one area."""
    from data.subcategory_service import get_subcategory_classifier

    return get_subcategory_classifier().get_available_categories(area)
# ============== Action Recommendation Endpoints ==============
@app.post("/api/ai/action/recommend")
async def recommend_actions(
    report: str,
    issue_type: str,
    severity: str = "Medium",
    area: Optional[str] = None,
    airline: Optional[str] = None,
    top_n: int = 5,
):
    """
    Get up to ``top_n`` action recommendations for an issue.

    All arguments are forwarded unchanged, as keywords, to the action
    service's ``recommend``.
    """
    from data.action_service import get_action_service

    recommend = get_action_service().recommend
    return recommend(
        report=report,
        issue_type=issue_type,
        severity=severity,
        area=area,
        airline=airline,
        top_n=top_n,
    )
@app.post("/api/ai/action/train")
async def train_action_recommender(
    bypass_cache: bool = False, background_tasks: BackgroundTasks = None
):
    """
    Train the action recommender from historical sheet data.

    Training runs inline, before the response is returned. The
    ``background_tasks`` parameter is accepted but not used.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.action_service import get_action_service
    from data.sheets_service import GoogleSheetsService

    sheet_client = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    nc_rows = sheet_client.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cgo_rows = sheet_client.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    combined = nc_rows + cgo_rows

    get_action_service().train_from_data(combined)
    return {"status": "success", "records_processed": len(combined)}
# ============== Action Summary Endpoint ==============
class ActionCategorySummary(BaseModel):
    # Per-category aggregate produced by GET /api/ai/action-summary.
    count: int  # rows in this category after branch/status filtering
    severityDistribution: Dict[str, int] = {}  # severity label -> row count
    topActions: List[Dict[str, Any]] = []  # action/priority/source/rationale/confidence dicts
    avgResolutionDays: Optional[float] = None  # mean of collected predicted days; None if none collected
    topHubs: List[str] = []  # up to 5 most frequent HUB values
    topAirlines: List[str] = []  # up to 5 most frequent Airlines values
    effectivenessScore: float = 0.0  # action service's effectiveness_score, rounded to 3 places
    openCount: int = 0  # rows whose Status is not "closed"
    closedCount: int = 0  # rows whose Status is "closed" (case-insensitive)
    highPriorityCount: int = 0  # rows classified "High" or "Critical"
class ActionSummaryResponse(BaseModel):
    # Response envelope for GET /api/ai/action-summary.
    status: str  # "success" on the normal path
    totalRecords: int  # rows aggregated after filtering
    categories: Dict[str, ActionCategorySummary] = {}  # keyed by Irregularity_Complain_Category
    overallSummary: Dict[str, Any] = {}  # totals, severity distribution, average days
    topCategoriesByCount: List[Dict[str, Any]] = []  # up to 10, most rows first
    topCategoriesByRisk: List[Dict[str, Any]] = []  # up to 10, highest high-priority share first
    globalRecommendations: List[Dict[str, Any]] = []  # up to 10 de-duplicated actions across categories
    timestamp: str  # datetime.now().isoformat() when the response was built
@app.get("/api/ai/action-summary", response_model=ActionSummaryResponse)
async def get_action_summary(
    bypass_cache: bool = False,
    branch: Optional[str] = None,
    top_n_per_category: int = 5,
    include_closed: bool = True,
    max_rows_per_sheet: int = 5000,
    fast: bool = True,
    approximate_avg_days: bool = True,
):
    """
    Get action recommendations summary aggregated by category.

    Fetches rows from both sheets (NON CARGO & CGO) and classifies each row's
    severity. It optionally estimates resolution days, then aggregates counts
    and recommended actions per ``Irregularity_Complain_Category``.

    Args:
        bypass_cache: Skip both the response cache and the sheets cache.
        branch: Optional Branch filter (case-insensitive, whitespace-trimmed).
        top_n_per_category: Number of recommended actions kept per category.
        include_closed: When False, rows whose Status is "closed" are dropped.
        max_rows_per_sheet: Row cap used for the sheet fetches.
        fast: Fetch only the needed columns and use the rule-based severity
            fallback. When False, ``classify_text`` is used and, if a model
            is loaded, regression predictions as well.
        approximate_avg_days: When no regression prediction is available,
            derive resolution days from the severity label.

    Returns:
        ActionSummaryResponse. A cached copy (as a dict) is served for
        ``ACTION_SUMMARY_TTL`` seconds (default 600).

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.action_service import get_action_service
    from data.sheets_service import GoogleSheetsService

    start_time = datetime.now()
    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)

    # Response-level cache; the key covers every output-shaping parameter.
    cache_ttl = int(os.getenv("ACTION_SUMMARY_TTL", "600"))
    cache_key_parts = [
        "action_summary",
        str(branch or ""),
        "closed" if include_closed else "open_only",
        str(top_n_per_category),
        str(max_rows_per_sheet),
        "fast" if fast else "full",
        "approx" if approximate_avg_days else "no_approx",
    ]
    cache_key = "as:" + hashlib.md5(":".join(cache_key_parts).encode()).hexdigest()
    if cache and not bypass_cache:
        cached = cache.get(cache_key)
        if cached:
            return cached

    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    if fast:
        reqs = [
            {
                "name": "NON CARGO",
                "required_headers": [
                    "Irregularity_Complain_Category",
                    "Report",
                    "Root_Caused",
                    "Action_Taken",
                    "Area",
                    "Airlines",
                    "Branch",
                    "Status",
                    "HUB",
                ],
                "max_rows": max_rows_per_sheet,
            },
            {
                "name": "CGO",
                "required_headers": [
                    "Irregularity_Complain_Category",
                    "Report",
                    "Root_Caused",
                    "Action_Taken",
                    "Area",
                    "Airlines",
                    "Branch",
                    "Status",
                    "HUB",
                ],
                "max_rows": max_rows_per_sheet,
            },
        ]
        data_map = sheets_service.fetch_sheets_selected_columns(
            spreadsheet_id, reqs, bypass_cache=bypass_cache
        )
        non_cargo = data_map.get("NON CARGO", [])
        cargo = data_map.get("CGO", [])
    else:
        non_cargo = sheets_service.fetch_sheet_data(
            spreadsheet_id,
            "NON CARGO",
            f"A1:AA{max_rows_per_sheet}",
            bypass_cache=bypass_cache,
        )
        cargo = sheets_service.fetch_sheet_data(
            spreadsheet_id, "CGO", f"A1:Z{max_rows_per_sheet}", bypass_cache=bypass_cache
        )

    # Empty result: retry uncached, then (fast mode) fall back to a batch fetch.
    if (len(non_cargo) + len(cargo)) == 0:
        if fast:
            data_map_retry = sheets_service.fetch_sheets_selected_columns(
                spreadsheet_id, reqs, bypass_cache=True
            )
            non_cargo = data_map_retry.get("NON CARGO", [])
            cargo = data_map_retry.get("CGO", [])
            if (len(non_cargo) + len(cargo)) == 0:
                sheet_ranges = [
                    {"name": "NON CARGO", "range": f"A1:Z{max_rows_per_sheet}"},
                    {"name": "CGO", "range": f"A1:Z{max_rows_per_sheet}"},
                ]
                data_map_wide = sheets_service.fetch_sheets_batch_data(
                    spreadsheet_id, sheet_ranges, bypass_cache=True
                )
                non_cargo = data_map_wide.get("NON CARGO", [])
                cargo = data_map_wide.get("CGO", [])
        else:
            non_cargo = sheets_service.fetch_sheet_data(
                spreadsheet_id,
                "NON CARGO",
                f"A1:AA{max_rows_per_sheet}",
                bypass_cache=True,
            )
            cargo = sheets_service.fetch_sheet_data(
                spreadsheet_id, "CGO", f"A1:Z{max_rows_per_sheet}", bypass_cache=True
            )

    for row in non_cargo:
        row["_source_sheet"] = "NON CARGO"
    for row in cargo:
        row["_source_sheet"] = "CGO"
    all_data = non_cargo + cargo
    if branch:
        all_data = [
            r
            for r in all_data
            if (r.get("Branch") or "").strip().lower() == branch.strip().lower()
        ]
    if not include_closed:
        all_data = [
            r for r in all_data if (r.get("Status") or "").strip().lower() != "closed"
        ]
    total_records = len(all_data)
    if total_records == 0:
        return ActionSummaryResponse(
            status="success",
            totalRecords=0,
            categories={},
            overallSummary={},
            topCategoriesByCount=[],
            topCategoriesByRisk=[],
            globalRecommendations=[],
            timestamp=datetime.now().isoformat(),
        )

    action_service = get_action_service()
    categories_agg: Dict[str, Dict[str, Any]] = {}
    all_severities = Counter()
    all_actions: List[Dict[str, Any]] = []
    total_open = 0
    total_closed = 0
    total_high_priority = 0
    total_resolution_days = []
    # Regression only in full mode with a loaded model; disabled for the rest
    # of the run if any batch fails.
    use_regression = (not fast) and model_service.model_loaded
    batch_size = 200
    for i in range(0, total_records, batch_size):
        batch = all_data[i : i + batch_size]
        if use_regression:
            try:
                predictions = model_service.predict_regression(batch)
            except Exception:
                predictions = []
                use_regression = False
        else:
            predictions = []
        if fast:
            classifications = model_service._classify_severity_fallback(
                [
                    (r.get("Report", "") or "") + " " + (r.get("Root_Caused", "") or "")
                    for r in batch
                ]
            )
        else:
            try:
                classifications = model_service.classify_text(batch)
            except Exception:
                classifications = model_service._classify_severity_fallback(
                    [
                        (r.get("Report", "") or "")
                        + " "
                        + (r.get("Root_Caused", "") or "")
                        for r in batch
                    ]
                )
        for j, record in enumerate(batch):
            category = record.get("Irregularity_Complain_Category") or "Unknown"
            if not category or category == "#N/A":
                category = "Unknown"
            # Classifier results may be dicts or objects; default to "Low".
            if j < len(classifications):
                sev_obj = classifications[j]
                if isinstance(sev_obj, dict):
                    severity = sev_obj.get("severity", "Low")
                else:
                    severity = getattr(sev_obj, "severity", "Low")
            else:
                severity = "Low"
            predicted_days = 0.0
            if predictions and j < len(predictions):
                pred_obj = predictions[j]
                if isinstance(pred_obj, RegressionPrediction):
                    predicted_days = pred_obj.predictedDays
                elif isinstance(pred_obj, dict):
                    predicted_days = pred_obj.get("predictedDays", 0.0)
            elif approximate_avg_days:
                # Heuristic days per severity label when no prediction exists.
                sev_map = {
                    "Low": 1.2,
                    "Medium": 2.2,
                    "High": 3.0,
                    "Critical": 4.0,
                }
                predicted_days = sev_map.get(severity, 2.0)
            status = (record.get("Status") or "").strip().lower()
            if category not in categories_agg:
                categories_agg[category] = {
                    "count": 0,
                    "severities": Counter(),
                    "actions": [],
                    "resolution_days": [],
                    "hubs": Counter(),
                    "airlines": Counter(),
                    "open_count": 0,
                    "closed_count": 0,
                    "high_priority_count": 0,
                    "records": [],
                }
            cat_data = categories_agg[category]
            cat_data["count"] += 1
            cat_data["severities"][severity] += 1
            if use_regression or approximate_avg_days:
                cat_data["resolution_days"].append(predicted_days)
            all_severities[severity] += 1
            if use_regression or approximate_avg_days:
                total_resolution_days.append(predicted_days)
            hub = record.get("HUB") or "Unknown"
            airline = record.get("Airlines") or "Unknown"
            cat_data["hubs"][hub] += 1
            cat_data["airlines"][airline] += 1
            if status == "closed":
                cat_data["closed_count"] += 1
                total_closed += 1
            else:
                cat_data["open_count"] += 1
                total_open += 1
            if severity in ("Critical", "High"):
                cat_data["high_priority_count"] += 1
                total_high_priority += 1
            # Keep a small sample per category; only records[0] is used below.
            if len(cat_data["records"]) < 20:
                cat_data["records"].append(record)

    # One recommendation query per category, seeded from its first sample row
    # and its most common severity.
    for category, cat_data in categories_agg.items():
        records = cat_data["records"]
        if records:
            sample_record = records[0]
            report = sample_record.get("Report", "") or ""
            area = (sample_record.get("Area", "") or "").replace(" Area", "")
            airline = sample_record.get("Airlines", "")
            branch_val = sample_record.get("Branch", "")
            severity_counts = cat_data["severities"]
            dominant_severity = (
                severity_counts.most_common(1)[0][0] if severity_counts else "Medium"
            )
            recs = action_service.recommend(
                report=report,
                issue_type=category,
                severity=dominant_severity,
                area=area if area else None,
                airline=airline if airline else None,
                branch=branch_val if branch_val else None,
                top_n=top_n_per_category,
            )
            cat_data["actions"] = recs.get("recommendations", [])
            cat_data["effectiveness"] = recs.get("effectiveness_score", 0.5)
        else:
            cat_data["actions"] = []
            cat_data["effectiveness"] = 0.5

    category_summaries: Dict[str, ActionCategorySummary] = {}
    for category, cat_data in categories_agg.items():
        avg_days = None
        if cat_data["resolution_days"]:
            avg_days = round(
                sum(cat_data["resolution_days"]) / len(cat_data["resolution_days"]), 2
            )
        top_actions = []
        for action in cat_data["actions"][:top_n_per_category]:
            top_actions.append(
                {
                    "action": action.get("action", ""),
                    "priority": action.get("priority", "MEDIUM"),
                    "source": action.get("source", "template"),
                    "rationale": action.get("rationale", ""),
                    "confidence": action.get("confidence", 0.5),
                }
            )
        category_summaries[category] = ActionCategorySummary(
            count=cat_data["count"],
            severityDistribution=dict(cat_data["severities"]),
            topActions=top_actions,
            avgResolutionDays=avg_days,
            topHubs=[h for h, _ in cat_data["hubs"].most_common(5)],
            topAirlines=[a for a, _ in cat_data["airlines"].most_common(5)],
            effectivenessScore=round(cat_data["effectiveness"], 3),
            openCount=cat_data["open_count"],
            closedCount=cat_data["closed_count"],
            highPriorityCount=cat_data["high_priority_count"],
        )
        for action in cat_data["actions"]:
            # Copy before tagging. These dicts come from the action service and
            # may be shared objects, so writing "category" in place could leak
            # a stale category into other categories or later requests.
            all_actions.append({**action, "category": category})

    overall_avg_days = None
    if total_resolution_days:
        overall_avg_days = round(
            sum(total_resolution_days) / len(total_resolution_days), 2
        )
    overall_summary = {
        "totalRecords": total_records,
        "openCount": total_open,
        "closedCount": total_closed,
        "highPriorityCount": total_high_priority,
        "severityDistribution": dict(all_severities),
        "avgResolutionDays": overall_avg_days,
        "categoriesCount": len(categories_agg),
        "avgDaysSource": ("approx" if (fast and approximate_avg_days and not use_regression) else ("model" if use_regression else None)),
    }
    top_by_count = sorted(
        [
            {"category": k, "count": v.count, "highPriority": v.highPriorityCount}
            for k, v in category_summaries.items()
        ],
        key=lambda x: -x["count"],
    )[:10]
    # Risk = share of High/Critical rows; ties broken by larger count.
    top_by_risk = sorted(
        [
            {
                "category": k,
                "riskScore": v.highPriorityCount / max(v.count, 1),
                "count": v.count,
            }
            for k, v in category_summaries.items()
        ],
        key=lambda x: (-x["riskScore"], -x["count"]),
    )[:10]

    # Global list: priority first, then confidence; de-duplicated on the first
    # 50 characters of the action text, capped at 10.
    global_recs = []
    action_priority_order = {"HIGH": 0, "MEDIUM": 1, "LOW": 2}
    all_actions.sort(
        key=lambda x: (
            action_priority_order.get(x.get("priority", "LOW"), 2),
            -x.get("confidence", 0),
        )
    )
    seen_actions = set()
    for action in all_actions:
        key = action.get("action", "")[:50]
        if key not in seen_actions and len(global_recs) < 10:
            seen_actions.add(key)
            global_recs.append(
                {
                    "action": action.get("action", ""),
                    "priority": action.get("priority", "MEDIUM"),
                    "category": action.get("category", "Unknown"),
                    "rationale": action.get("rationale", ""),
                    "confidence": action.get("confidence", 0.5),
                }
            )

    processing_time = (datetime.now() - start_time).total_seconds()
    logger.info(
        f"Action summary completed in {processing_time:.2f}s for {total_records} records"
    )
    resp = ActionSummaryResponse(
        status="success",
        totalRecords=total_records,
        categories=category_summaries,
        overallSummary=overall_summary,
        topCategoriesByCount=top_by_count,
        topCategoriesByRisk=top_by_risk,
        globalRecommendations=global_recs,
        timestamp=datetime.now().isoformat(),
    )
    if cache and not bypass_cache and total_records > 0:
        try:
            cache.set(cache_key, resp.model_dump(), cache_ttl)
        except Exception:
            # Caching is best-effort: never fail the request, but don't hide it.
            logger.warning("Failed to cache action summary", exc_info=True)
    return resp
# ============== Advanced NER Endpoints ==============
@app.post("/api/ai/ner/extract")
async def extract_entities(text: str):
    """Run the advanced NER service on ``text``; return its entities and summary."""
    from data.advanced_ner_service import get_advanced_ner

    extractor = get_advanced_ner()
    found = extractor.extract(text)
    digest = extractor.extract_summary(text)
    return {"entities": found, "summary": digest}
# ============== Similarity Endpoints ==============
@app.post("/api/ai/similar")
async def find_similar_reports(
    text: str,
    top_k: int = 5,
    threshold: float = 0.3,
):
    """
    Find reports similar to ``text``.

    Returns at most ``top_k`` matches, filtered by ``threshold`` inside the
    similarity service, plus the first 100 characters of the query.
    """
    from data.similarity_service import get_similarity_service

    matches = get_similarity_service().find_similar(text, top_k, threshold)
    return {"query_preview": text[:100], "similar_reports": matches}
@app.post("/api/ai/similar/build-index")
async def build_similarity_index(bypass_cache: bool = False):
    """
    Rebuild the similarity index from Google Sheets data.

    Reads rows up to 2000 of NON CARGO (A:AA) and CGO (A:Z) and passes them to
    the similarity service's ``build_index``.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.similarity_service import get_similarity_service
    from data.sheets_service import GoogleSheetsService

    sheet_client = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    nc_rows = sheet_client.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cgo_rows = sheet_client.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    combined = nc_rows + cgo_rows

    get_similarity_service().build_index(combined)
    return {"status": "success", "records_indexed": len(combined)}
# ============== Forecasting Endpoints ==============
@app.get("/api/ai/forecast/issues")
async def forecast_issues(periods: int = 4):
    """Forecast issue volume for the next ``periods`` periods."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().forecast_issues(periods)
@app.get("/api/ai/forecast/trends")
async def predict_trends():
    """Return the forecast service's per-category trend predictions."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().predict_category_trends()
@app.get("/api/ai/forecast/seasonal")
async def get_seasonal_patterns():
    """Return the seasonal patterns computed by the forecast service."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().get_seasonal_patterns()
@app.post("/api/ai/forecast/build")
async def build_forecast_data(bypass_cache: bool = False):
    """
    Build the forecast service's historical data from Google Sheets.

    Reads rows up to 2000 of NON CARGO (A:AA) and CGO (A:Z), passes them to
    ``build_historical_data``, and returns the resulting forecast summary.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    sheet_client = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    nc_rows = sheet_client.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cgo_rows = sheet_client.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    combined = nc_rows + cgo_rows

    forecaster = get_forecast_service()
    forecaster.build_historical_data(combined)
    return {
        "status": "success",
        "records_processed": len(combined),
        "forecast_summary": forecaster.get_forecast_summary(),
    }
# ============== Report Generation Endpoints ==============
@app.post("/api/ai/report/generate")
async def generate_report(
    row_id: str,
    bypass_cache: bool = False,
):
    """
    Generate a formal incident report for one sheet row.

    Searches rows up to 2000 of NON CARGO (A:AA) and then CGO (A:Z) for the
    first row whose ``_row_id`` equals ``row_id``. It classifies that row's
    severity with the rule-based fallback and looks up the airline's risk
    data. It then returns formal, executive and structured reports.

    Args:
        row_id: Value of the row's ``_row_id`` field.
        bypass_cache: Skip the sheets cache and fetch fresh data.

    Raises:
        HTTPException: 500 if ``GOOGLE_SHEET_ID`` is not configured; 404 if no
            row has the requested ``_row_id``.
    """
    from data.report_generator_service import get_report_generator
    from data.sheets_service import GoogleSheetsService
    from data.risk_service import get_risk_service

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    # Fetch all data and find the record (first match wins).
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    record = next((r for r in all_data if r.get("_row_id") == row_id), None)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Record '{row_id}' not found")

    # Cells can be present but None; coerce before concatenating. The plain
    # `record.get(k, "") + ...` form raised TypeError for such rows.
    report_text = (
        (record.get("Report", "") or "") + " " + (record.get("Root_Caused", "") or "")
    )
    severity_result = model_service._classify_severity_fallback([report_text])[0]
    # Accept dict or attribute-style results, as get_action_summary does.
    if isinstance(severity_result, dict):
        severity = severity_result.get("severity", "Medium")
    else:
        severity = getattr(severity_result, "severity", "Medium")
    analysis = {
        "severity": severity,
        "issueType": record.get("Irregularity_Complain_Category", ""),
    }

    risk_service = get_risk_service()
    airline = record.get("Airlines", "")
    risk_data = risk_service.get_airline_risk(airline)

    report_gen = get_report_generator()
    formal_report = report_gen.generate_incident_report(record, analysis, risk_data)
    exec_summary = report_gen.generate_executive_summary(record, analysis)
    json_report = report_gen.generate_json_report(record, analysis, risk_data)
    return {
        "row_id": row_id,
        "formal_report": formal_report,
        "executive_summary": exec_summary,
        "structured_report": json_report,
    }
# ============== Dashboard Endpoints ==============
@app.get("/api/ai/dashboard/summary")
async def dashboard_summary(bypass_cache: bool = False):
    """Aggregate counts, distributions and service summaries for the dashboard.

    Reads rows 1-2000 of the NON CARGO (A:AA) and CGO (A:Z) tabs and labels
    each row's severity with the fallback classifier. It then combines the
    tallies with the risk and forecast service summaries.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.risk_service import get_risk_service
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    records = non_cargo_rows + cargo_rows

    risk_summary = get_risk_service().get_risk_summary()
    forecast_summary = get_forecast_service().get_forecast_summary()

    # A classifier result without a "severity" key is tallied as "Low".
    severity_counts = Counter()
    for row in records:
        text = row.get("Report", "") + " " + row.get("Root_Caused", "")
        verdict = model_service._classify_severity_fallback([text])[0]
        severity_counts[verdict.get("severity", "Low")] += 1
    category_counts = Counter(
        row.get("Irregularity_Complain_Category", "Unknown") for row in records
    )
    airline_counts = Counter(row.get("Airlines", "Unknown") for row in records)

    return {
        "total_records": len(records),
        "sheets": {"non_cargo": len(non_cargo_rows), "cargo": len(cargo_rows)},
        "severity_distribution": dict(severity_counts),
        "category_distribution": dict(category_counts.most_common(10)),
        "top_airlines": dict(airline_counts.most_common(10)),
        "risk_summary": risk_summary,
        "forecast_summary": forecast_summary,
        "model_status": {
            "regression": model_service.model_loaded,
            "nlp": model_service.nlp_service is not None,
        },
        "last_updated": datetime.now().isoformat(),
    }
# ============== Seasonality Endpoints ==============
@app.get("/api/ai/seasonality/summary")
async def seasonality_summary(category_type: Optional[str] = None):
    """
    Return the seasonality summary computed by the seasonality service.

    Args:
        category_type: "landside_airside", "cgo", or None to cover both
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().get_seasonality_summary(category_type)
@app.get("/api/ai/seasonality/forecast")
async def seasonality_forecast(
    category_type: Optional[str] = None,
    periods: int = 4,
    granularity: str = "weekly",
):
    """
    Forecast issue volumes via the seasonality service.

    Args:
        category_type: "landside_airside", "cgo", or None to cover both
        periods: how many future periods to forecast
        granularity: "daily", "weekly", or "monthly"
    """
    from data.seasonality_service import get_seasonality_service

    forecaster = get_seasonality_service()
    return forecaster.forecast(category_type, periods, granularity)
@app.get("/api/ai/seasonality/peaks")
async def seasonality_peaks(
    category_type: Optional[str] = None, threshold: float = 1.2
):
    """
    List peak periods reported by the seasonality service.

    Args:
        category_type: "landside_airside", "cgo", or None to cover both
        threshold: multiplier over the average that counts as a peak
            (e.g. 1.2 means 20% above average)
    """
    from data.seasonality_service import get_seasonality_service

    seasonality = get_seasonality_service()
    return seasonality.get_peak_periods(category_type, threshold)
@app.post("/api/ai/seasonality/build")
async def build_seasonality_patterns(bypass_cache: bool = False):
    """Build seasonality patterns from both Google Sheets tabs.

    Rows 1-5000 of NON CARGO (A:AA) and CGO (A:Z) are fetched. Each row is
    tagged in place with its source tab under ``_sheet_name`` before being
    passed to the seasonality service.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.seasonality_service import get_seasonality_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    for tab_rows, tab_name in ((non_cargo_rows, "NON CARGO"), (cargo_rows, "CGO")):
        for row in tab_rows:
            row["_sheet_name"] = tab_name
    records = non_cargo_rows + cargo_rows

    patterns = get_seasonality_service().build_patterns(records)
    return {
        "status": "success",
        "records_processed": len(records),
        "patterns": patterns,
    }
# ============== Root Cause Endpoints ==============
@app.post("/api/ai/root-cause/classify")
async def classify_root_cause(
    root_cause: str,
    report: Optional[str] = None,
    area: Optional[str] = None,
    category: Optional[str] = None,
):
    """
    Classify a single root-cause text into a root-cause category.

    Categories: Equipment Failure, Staff Competency, Process/Procedure,
    Communication, External Factors, Documentation, Training Gap, Resource/Manpower

    An omitted ``report`` is sent to the classifier as an empty string.
    ``area`` and ``category`` are passed along as a context dict.
    """
    from data.root_cause_service import get_root_cause_service

    classifier = get_root_cause_service()
    return classifier.classify(
        root_cause, report or "", {"area": area, "category": category}
    )
@app.post("/api/ai/root-cause/classify-batch")
async def classify_root_cause_batch(bypass_cache: bool = False):
    """Classify the root cause of every record in both sheet tabs.

    The response carries only the first 100 classifications. ``total_classified``
    counts every result whose ``primary_category`` is not "Unknown".

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    records = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    ) + sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )

    classifications = get_root_cause_service().classify_batch(records)
    known_count = sum(
        1 for item in classifications if item["primary_category"] != "Unknown"
    )
    return {
        "status": "success",
        "records_processed": len(records),
        "classifications": classifications[:100],
        "total_classified": known_count,
    }
@app.get("/api/ai/root-cause/categories")
async def get_root_cause_categories():
    """Return the root-cause categories known to the root-cause service."""
    from data.root_cause_service import get_root_cause_service

    return get_root_cause_service().get_categories()
@app.get("/api/ai/root-cause/stats")
async def get_root_cause_stats(bypass_cache: bool = False):
    """Compute root-cause statistics over rows 1-5000 of both sheet tabs.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )

    analyzer = get_root_cause_service()
    return analyzer.get_statistics(non_cargo_rows + cargo_rows)
@app.post("/api/ai/root-cause/train")
async def train_root_cause_classifier(bypass_cache: bool = False):
    """Train the root-cause classifier on rows 1-5000 of both sheet tabs.

    Returns whatever ``train_from_data`` reports.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    training_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    ) + sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )

    trainer = get_root_cause_service()
    return trainer.train_from_data(training_rows)
# ============== Category Summarization Endpoints ==============
# Response schema for GET /api/ai/summarize (declared as its response_model).
# Fields are kept in this order because it determines the JSON key order.
class CategorySummaryResponse(BaseModel):
    # "success" on the happy path (the only value set by summarize_by_category).
    status: str
    # Lower-cased category that was summarized: "non_cargo", "cgo" or "all".
    category_type: str
    # Opaque payload returned by the category summarization service.
    summary: Dict[str, Any]
    # ISO-8601 timestamp (naive local time) of when the summary was produced.
    timestamp: str
@app.get("/api/ai/summarize", response_model=CategorySummaryResponse)
async def summarize_by_category(category: str = "all", bypass_cache: bool = False):
    """
    Summarize irregularity data for Non-cargo, CGO, or both.

    Query Parameters:
        category: "non_cargo", "cgo", or "all" (case-insensitive; default "all")
        bypass_cache: fetch fresh sheet data instead of cached data (default false)

    The summary produced by the category summarization service includes:
    - Severity distribution
    - Top categories, airlines, hubs, branches
    - Key insights and recommendations
    - Common issues
    - Monthly trends

    Raises:
        HTTPException(400): ``category`` is not one of the accepted values.
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.category_summarization_service import get_category_summarization_service
    from data.sheets_service import GoogleSheetsService

    valid_categories = ["non_cargo", "cgo", "all"]
    requested = category.lower()
    if requested not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}",
        )

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    # Both tabs are fetched regardless of ``category``. Each row is tagged in
    # place with its source tab under ``_sheet_name``.
    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    for tab_rows, tab_name in ((non_cargo_rows, "NON CARGO"), (cargo_rows, "CGO")):
        for row in tab_rows:
            row["_sheet_name"] = tab_name

    summarizer = get_category_summarization_service()
    summary = summarizer.summarize_category(non_cargo_rows + cargo_rows, requested)
    return CategorySummaryResponse(
        status="success",
        category_type=requested,
        summary=summary,
        timestamp=datetime.now().isoformat(),
    )
@app.get("/api/ai/summarize/non-cargo")
async def summarize_non_cargo(bypass_cache: bool = False):
    """Shortcut for ``/api/ai/summarize?category=non_cargo``."""
    result = await summarize_by_category(category="non_cargo", bypass_cache=bypass_cache)
    return result
@app.get("/api/ai/summarize/cgo")
async def summarize_cgo(bypass_cache: bool = False):
    """Shortcut for ``/api/ai/summarize?category=cgo`` (CGO / cargo tab)."""
    result = await summarize_by_category(category="cgo", bypass_cache=bypass_cache)
    return result
@app.get("/api/ai/summarize/compare")
async def compare_categories(bypass_cache: bool = False):
    """Summarize Non-cargo and CGO together (same as ``category=all``)."""
    result = await summarize_by_category(category="all", bypass_cache=bypass_cache)
    return result
# ============== Branch Analytics Endpoints ==============
@app.get("/api/ai/branch/summary")
async def branch_analytics_summary(category_type: Optional[str] = None):
    """
    Return the branch analytics summary.

    Args:
        category_type: "landside_airside", "cgo", or None to cover both
    """
    from data.branch_analytics_service import get_branch_analytics_service

    analytics = get_branch_analytics_service()
    return analytics.get_summary(category_type)
# NOTE: the fixed-path branch routes below are registered BEFORE the
# parameterized "/api/ai/branch/{branch_name}" route on purpose. Routes are
# matched in registration order. If "{branch_name}" came first,
# GET /api/ai/branch/ranking and GET /api/ai/branch/comparison would be
# captured as branch_name="ranking"/"comparison" and never reach their
# handlers.
@app.get("/api/ai/branch/ranking")
async def branch_ranking(
    category_type: Optional[str] = None,
    sort_by: str = "risk_score",
    limit: int = 20,
):
    """
    Get branch ranking.

    Args:
        category_type: "landside_airside", "cgo", or None for both
        sort_by: Field to sort by (risk_score, total_issues, critical_high_count)
        limit: Maximum branches to return
    """
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    return service.get_ranking(category_type, sort_by, limit)


@app.get("/api/ai/branch/comparison")
async def branch_comparison():
    """Compare all branches across category types."""
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    return service.get_comparison()


@app.get("/api/ai/branch/{branch_name}")
async def get_branch_metrics(branch_name: str, category_type: Optional[str] = None):
    """
    Get metrics for a specific branch.

    Args:
        branch_name: Branch name
        category_type: "landside_airside", "cgo", or None for combined

    Raises:
        HTTPException(404): the service returned no (falsy) data for the branch.
    """
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    data = service.get_branch(branch_name, category_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Branch '{branch_name}' not found")
    return data
@app.post("/api/ai/branch/calculate")
async def calculate_branch_metrics(bypass_cache: bool = False):
    """Recalculate branch metrics from both Google Sheets tabs.

    Rows 1-5000 of NON CARGO (A:AA) and CGO (A:Z) are fetched. Each row is
    tagged in place with its source tab under ``_sheet_name``, then the rows
    are passed to the branch analytics service.

    Raises:
        HTTPException(500): GOOGLE_SHEET_ID is not set.
    """
    from data.branch_analytics_service import get_branch_analytics_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    for tab_rows, tab_name in ((non_cargo_rows, "NON CARGO"), (cargo_rows, "CGO")):
        for row in tab_rows:
            row["_sheet_name"] = tab_name
    records = non_cargo_rows + cargo_rows

    metrics = get_branch_analytics_service().calculate_branch_metrics(records)
    return {
        "status": "success",
        "records_processed": len(records),
        "metrics": metrics,
    }
# ============== Main ==============
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; API_PORT overrides the default port 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("API_PORT", 8000)))