# gapura-ai-api / api/main.py
# Muhammad Ridzki Nugraha
# Upload folder using huggingface_hub
# c5af9d3 verified
"""
Gapura AI Analysis API
FastAPI server for regression and NLP analysis of irregularity reports
Uses real trained models from ai-model/models/
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
from pydantic_core import ValidationError
from typing import List, Optional, Dict, Any, Tuple
from collections import Counter
import os
import json
import logging
from datetime import datetime
import numpy as np
import pickle
import pandas as pd
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.cache_service import get_cache, CacheService
from data.nlp_service import NLPModelService
from data.shap_service import get_shap_explainer
from data.anomaly_service import get_anomaly_detector
# Process-wide logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# OpenAPI tag groups, in the order they are listed in the interactive docs.
tags_metadata = [
    {"name": tag_name, "description": tag_description}
    for tag_name, tag_description in (
        ("Analysis", "Core AI analysis endpoints for irregularity reports."),
        ("Health", "System health and model status checks."),
        ("Jobs", "Asynchronous job management."),
        ("Training", "Model retraining and lifecycle management."),
    )
]
# The FastAPI application. Interactive docs are served at /docs (Swagger UI)
# and /redoc; endpoints are grouped by the tags declared in `tags_metadata`.
app = FastAPI(
    title="Gapura AI Analysis API",
    version="2.1.0",
    openapi_tags=tags_metadata,
    docs_url="/docs",
    redoc_url="/redoc",
    # Markdown rendered at the top of the generated API docs.
    description="""
Gapura AI Analysis API provides advanced machine learning capabilities for analyzing irregularity reports.
## Features
* **Regression Analysis**: Predict resolution time (days) based on report details.
* **NLP Classification**: Determine severity (Critical, High, Medium, Low) and categorize issues.
* **Entity Extraction**: Extract key entities like Airlines, Flight Numbers, and Dates.
* **Summarization**: Generate executive summaries and key points from long reports.
* **Trend Analysis**: Analyze trends by Airline, Hub, and Category.
* **Anomaly Detection**: Identify unusual patterns in resolution times.
## Models
* **Regression**: Random Forest Regressor (v1.0.0-trained)
* **NLP**: Hybrid Transformer + Rule-based System (v4.0.0-onnx)
""",
)
# CORS middleware.
# Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated); the default "*"
# keeps the API callable from any origin. Credentialed (cookie) CORS requests are
# only enabled when explicit origins are configured, because combining a wildcard
# origin with allow_credentials would let any website make credentialed calls.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Compress responses larger than 500 bytes.
app.add_middleware(GZipMiddleware, minimum_size=500)
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Render a pydantic ``ValidationError`` raised inside a handler as a 422 response.

    ``exc.errors()`` can embed the raw exception object under ``ctx["error"]``
    (e.g. the ``ValueError`` raised by a ``field_validator``), which the JSON
    encoder used by ``JSONResponse`` cannot serialise, turning the 422 into a 500.
    ``exc.json()`` applies pydantic's own serialisation, so the error list is
    parsed back from it to guarantee JSON-safe content.
    """
    errors_json = exc.json()
    return JSONResponse(
        status_code=422,
        content={
            "detail": "Validation error",
            "errors": json.loads(errors_json),
            # Kept for backward compatibility: this is the serialised error list
            # (a JSON string), not the request body.
            "body": errors_json,
        },
    )
# ============== Pydantic Models ==============
from enum import Enum
from datetime import date as date_type
class ReportCategoryEnum(str, Enum):
    """Allowed report categories.

    The ``str`` mixin makes each member compare equal to its raw string value.
    NOTE(review): not referenced in the visible part of this module.
    """
    IRREGULARITY = "Irregularity"
    COMPLAINT = "Complaint"
class AreaEnum(str, Enum):
    """Allowed operational areas for a report.

    The ``str`` mixin makes each member compare equal to its raw string value.
    NOTE(review): not referenced in the visible part of this module.
    """
    APRON = "Apron Area"
    TERMINAL = "Terminal Area"
    GENERAL = "General"
class StatusEnum(str, Enum):
    """Allowed report workflow statuses.

    The ``str`` mixin makes each member compare equal to its raw string value.
    NOTE(review): not referenced in the visible part of this module.
    """
    OPEN = "Open"
    CLOSED = "Closed"
    IN_PROGRESS = "In Progress"
class IrregularityReport(BaseModel):
    # A single irregularity/complaint report supplied directly in a request body.
    # Every field is optional; the field names are the dict keys the ModelService
    # methods read (e.g. "Report", "HUB", "Irregularity_Complain_Category").
    Date_of_Event: Optional[str] = Field(None, description="Date of the event")
    Airlines: Optional[str] = Field(None, max_length=100)
    Flight_Number: Optional[str] = Field(None, max_length=20)
    Branch: Optional[str] = Field(None, max_length=10)
    HUB: Optional[str] = Field(None, max_length=20)
    Route: Optional[str] = Field(None, max_length=50)
    Report_Category: Optional[str] = Field(None, max_length=50)
    Irregularity_Complain_Category: Optional[str] = Field(None, max_length=100)
    # Free-text fields, each capped at 2000 characters.
    Report: Optional[str] = Field(None, max_length=2000)
    Root_Caused: Optional[str] = Field(None, max_length=2000)
    Action_Taken: Optional[str] = Field(None, max_length=2000)
    Area: Optional[str] = Field(None, max_length=50)
    Status: Optional[str] = Field(None, max_length=50)
    Reported_By: Optional[str] = Field(None, max_length=100)
    # Any non-empty value counts as "has photos" for the regression features.
    Upload_Irregularity_Photo: Optional[str] = Field(None)
    # Keep unrecognised columns instead of rejecting the payload.
    model_config = {"extra": "allow"}
class AnalysisOptions(BaseModel):
    # Per-request toggles selecting which analyses run. Everything defaults to on
    # except cache bypass and risk assessment.
    predictResolutionTime: bool = Field(
        default=True, description="Run regression model"
    )
    classifySeverity: bool = Field(
        default=True, description="Classify severity using NLP"
    )
    extractEntities: bool = Field(
        default=True, description="Extract entities using NER"
    )
    generateSummary: bool = Field(default=True, description="Generate text summaries")
    analyzeTrends: bool = Field(default=True, description="Analyze trends")
    bypassCache: bool = Field(
        default=False, description="Bypass cache and fetch fresh data"
    )
    # When set, perform_analysis also runs regression and severity classification,
    # because the risk score is built from their outputs.
    includeRisk: bool = Field(default=False, description="Include risk assessment in analysis")
class AnalysisRequest(BaseModel):
    # Analysis input: either a Google Sheet reference (sheetId + sheetName +
    # rowRange) or rows uploaded directly in `data`.
    sheetId: Optional[str] = Field(None, description="Google Sheet ID")
    sheetName: Optional[str] = Field(None, description="Sheet name (NON CARGO or CGO)")
    rowRange: Optional[str] = Field(None, description="Row range (e.g., A2:Z100)")
    data: Optional[List[IrregularityReport]] = Field(
        None, description="Direct data upload"
    )
    options: AnalysisOptions = Field(default_factory=AnalysisOptions)
    @field_validator("data")
    @classmethod
    def validate_data(cls, v):
        """Allow ``data`` to be omitted (``None``) but reject an explicit empty list."""
        if v is not None and len(v) == 0:
            raise ValueError("data array cannot be empty")
        return v
class ShapExplanation(BaseModel):
    # SHAP attribution attached to a single regression prediction.
    baseValue: float = Field(description="Base/expected value from model")
    predictionExplained: bool = Field(
        description="Whether SHAP explanation is available"
    )
    # ModelService keeps at most the first five factors returned by the explainer.
    topFactors: List[Dict[str, Any]] = Field(
        default_factory=list, description="Top contributing features"
    )
    explanation: str = Field(default="", description="Human-readable explanation")
class AnomalyResult(BaseModel):
    # Result of checking a predicted resolution time against the anomaly detector.
    isAnomaly: bool = Field(description="Whether prediction is anomalous")
    anomalyScore: float = Field(description="Anomaly score (0-1)")
    anomalies: List[Dict[str, Any]] = Field(
        default_factory=list, description="List of detected anomalies"
    )
class RegressionPrediction(BaseModel):
    # Predicted resolution time for one input row (reportId is "row_<index>").
    reportId: str
    predictedDays: float
    # (lower, upper) bounds in days around predictedDays.
    confidenceInterval: Tuple[float, float]
    featureImportance: Dict[str, float]
    hasUnknownCategories: bool = Field(
        default=False, description="True if unknown categories were used in prediction"
    )
    shapExplanation: Optional[ShapExplanation] = Field(
        default=None, description="SHAP-based explanation for prediction"
    )
    anomalyDetection: Optional[AnomalyResult] = Field(
        default=None, description="Anomaly detection results"
    )
class RegressionResult(BaseModel):
    # All regression predictions for a request plus the model's evaluation metrics.
    predictions: List[RegressionPrediction]
    modelMetrics: Dict[str, Any]
class ClassificationResult(BaseModel):
    # NLP classification of one input row: severity, area and issue type.
    reportId: str
    severity: str
    severityConfidence: float
    # Area label with any trailing " Area" suffix removed (e.g. "Apron").
    areaType: str
    issueType: str
    issueTypeConfidence: float
class Entity(BaseModel):
    # One extracted entity. start/end are intended as character offsets into the
    # report text; producers may fall back to 0 when the entity is not located.
    text: str
    label: str
    start: int
    end: int
    confidence: float
class EntityResult(BaseModel):
    # Entities extracted for one input row.
    reportId: str
    entities: List[Entity]
class SummaryResult(BaseModel):
    # Short summary and bullet points for one input row.
    reportId: str
    executiveSummary: str
    keyPoints: List[str]
class SentimentResult(BaseModel):
    # Urgency/sentiment analysis for one input row; keywords are the matched terms.
    reportId: str
    urgencyScore: float
    sentiment: str
    keywords: List[str]
class NLPResult(BaseModel):
    # Aggregated NLP outputs; each list holds per-row results (empty when that
    # analysis was not requested).
    classifications: List[ClassificationResult]
    entities: List[EntityResult]
    summaries: List[SummaryResult]
    sentiment: List[SentimentResult]
class TrendData(BaseModel):
    # Aggregate statistics for one group (e.g. one airline or hub).
    count: int
    # Required but nullable: None when no resolution time is available for the group.
    avgResolutionDays: Optional[float]
    topIssues: List[str]
class TrendResult(BaseModel):
    # Trend breakdowns keyed by group name, plus a time series of data points.
    byAirline: Dict[str, TrendData]
    byHub: Dict[str, TrendData]
    byCategory: Dict[str, Dict[str, Any]]
    timeSeries: List[Dict[str, Any]]
class Metadata(BaseModel):
    # Bookkeeping for an analysis response.
    totalRecords: int
    # NOTE(review): presumably seconds of wall-clock time — confirm in perform_analysis.
    processingTime: float
    # Model name -> version string (e.g. "regression", "nlp").
    modelVersions: Dict[str, str]
# The risk models are declared before AnalysisResponse on purpose: its `risk`
# annotation is evaluated when the class body runs, so referencing
# RiskAssessmentResponse before it exists raised NameError at import time.
class RiskItem(BaseModel):
    # Risk assessment for one input row (reportId is "row_<index>").
    reportId: str
    severity: str
    severityConfidence: float
    predictedDays: float
    anomalyScore: float
    category: str
    hub: str
    area: str
    # Weighted score in [0, 1]; priority is derived from it (HIGH / MEDIUM / LOW).
    riskScore: float
    priority: str
    recommendedActions: List[Dict[str, Any]] = Field(default_factory=list)
    preventiveSuggestions: List[str] = Field(default_factory=list)


class RiskAssessmentResponse(BaseModel):
    # Risk items plus the riskiest category/hub/area groupings.
    items: List[RiskItem]
    topPatterns: List[Dict[str, Any]]
    metadata: Dict[str, Any]


class AnalysisResponse(BaseModel):
    # Full analysis response; each section is None when it was not requested.
    regression: Optional[RegressionResult] = None
    nlp: Optional[NLPResult] = None
    trends: Optional[TrendResult] = None
    risk: Optional[RiskAssessmentResponse] = None
    metadata: Metadata
def _severity_to_score(level: str) -> float:
m = {"Critical": 1.0, "High": 0.8, "Medium": 0.5, "Low": 0.2}
return m.get(level, 0.3)
def _normalize_days(d: float) -> float:
return max(0.0, min(1.0, float(d) / 7.0))
def _priority_from_score(s: float) -> str:
if s >= 0.75:
return "HIGH"
if s >= 0.45:
return "MEDIUM"
return "LOW"
def _extract_prevention(texts: List[str]) -> List[str]:
kws = ["review", "prosedur", "procedure", "training", "pelatihan", "prevent", "pencegahan", "maintenance", "inspection", "inspeksi", "briefing", "supervision", "checklist", "verify", "verifikasi"]
out = []
seen = set()
for t in texts:
lt = t.lower()
for k in kws:
if k in lt:
if t not in seen:
seen.add(t)
out.append(t)
return out[:5]
# ============== Real Model Service ==============
class ModelService:
    """Service that loads and uses real trained models.

    On construction it loads the pickled regression bundle (plus an optional ONNX
    export used for faster inference) and the NLP service. Every analysis method
    takes a list of report dicts and returns exactly one result per input row,
    identified as ``row_<index>``. Rule-based fallbacks are used whenever a model
    is missing.
    """

    def __init__(self):
        self.regression_version = "1.0.0-trained"
        # Replaced by the NLP service's own version string once it loads.
        self.nlp_version = "1.0.0-mock"
        self.regression_model = None
        self.regression_onnx_session = None
        self.label_encoders = {}
        self.scaler = None
        self.feature_names = []
        self.model_metrics = {}
        self.model_loaded = False
        self.nlp_service = None
        self._load_regression_model()
        self._load_nlp_service()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_text(value: Any, default: str = "") -> str:
        """Coerce a report cell to ``str``; ``None`` (a missing cell) becomes ``default``."""
        return default if value is None else str(value)

    # ------------------------------------------------------------------ loading

    def _load_nlp_service(self):
        """Load the NLP service; on failure ``nlp_service`` stays ``None`` and fallbacks are used."""
        try:
            from data.nlp_service import get_nlp_service

            self.nlp_service = get_nlp_service()
            self.nlp_version = self.nlp_service.version
            logger.info(f"NLP service loaded (version: {self.nlp_version})")
        except Exception as e:
            logger.warning(f"Failed to load NLP service: {e}")

    def _load_regression_model(self):
        """Load the trained regression bundle and, if present, its ONNX export.

        The pickle is expected to be a dict with ``model``, ``label_encoders``,
        ``scaler``, ``feature_names`` and ``metrics`` entries. Any failure is logged
        and leaves ``model_loaded`` False, so predictions use the fallback heuristic.
        """
        model_dir = os.path.join(os.path.dirname(__file__), "..", "models", "regression")
        try:
            model_path = os.path.join(model_dir, "resolution_predictor_latest.pkl")
            if not os.path.exists(model_path):
                logger.warning(f"Model file not found at {model_path}")
                return
            logger.info(f"Loading regression model from {model_path}")
            # SECURITY: unpickling executes arbitrary code. This file must be a
            # trusted build artifact and must never be user-supplied.
            with open(model_path, "rb") as f:
                model_data = pickle.load(f)
            if not isinstance(model_data, dict):
                raise TypeError(
                    f"expected a dict bundle, got {type(model_data).__name__}"
                )
            self.regression_model = model_data.get("model")
            self.label_encoders = model_data.get("label_encoders", {})
            self.scaler = model_data.get("scaler")
            self.feature_names = model_data.get("feature_names", [])
            self.model_metrics = model_data.get("metrics", {})
            self.model_loaded = True
            logger.info("✓ Regression model loaded successfully")
            logger.info(f" - Features: {len(self.feature_names)}")
            logger.info(f" - Metrics: MAE={self.model_metrics.get('test_mae', 'N/A')}")

            # Optional ONNX export of the same model, preferred at inference time.
            onnx_path = os.path.join(model_dir, "resolution_predictor.onnx")
            if os.path.exists(onnx_path):
                try:
                    import onnxruntime as ort

                    sess_options = ort.SessionOptions()
                    sess_options.intra_op_num_threads = 1
                    sess_options.graph_optimization_level = (
                        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
                    )
                    self.regression_onnx_session = ort.InferenceSession(
                        onnx_path, sess_options
                    )
                    logger.info("✓ Regression ONNX model loaded successfully")
                except Exception as e:
                    logger.warning(f"Failed to load Regression ONNX model: {e}")
        except Exception as e:
            logger.error(f"Failed to load regression model: {e}")
            self.model_loaded = False

    # --------------------------------------------------------- feature building

    def _extract_features(self, report: Dict) -> Tuple[Optional[np.ndarray], bool]:
        """Build the scaled feature row for one report.

        Produces the same features as :meth:`_extract_features_batch`.

        Returns:
            ``(X, has_unknown_categories)``. ``X`` has one row with columns in
            ``self.feature_names`` order, and the flag is True when any categorical
            value was outside its encoder's classes. Returns ``(None, True)`` on failure.
        """
        try:
            try:
                date_obj = pd.to_datetime(report.get("Date_of_Event", ""), errors="coerce")
                if pd.isna(date_obj):
                    # Missing/unparseable dates default to "now", like the batch path.
                    # pd.Timestamp (unlike datetime) provides .dayofweek / .dayofyear.
                    date_obj = pd.Timestamp.now()
                day_of_week = date_obj.dayofweek
                month = date_obj.month
                # Index access works for both plain tuples and IsoCalendarDate.
                week_of_year = int(date_obj.isocalendar()[1])
                day_of_year = date_obj.dayofyear
            except Exception:
                day_of_week, month, week_of_year, day_of_year = 0, 1, 1, 1
            is_weekend = day_of_week in (5, 6)

            # Text features (None -> "", matching the batch path's fillna("")).
            report_text = self._as_text(report.get("Report"))
            root_cause = self._as_text(report.get("Root_Caused"))
            action_taken = self._as_text(report.get("Action_Taken"))

            # Categorical features (None -> "Unknown", matching fillna("Unknown")).
            categorical_values = {
                "airline": self._as_text(report.get("Airlines"), "Unknown"),
                "hub": self._as_text(report.get("HUB"), "Unknown"),
                "branch": self._as_text(report.get("Branch"), "Unknown"),
                "category": self._as_text(
                    report.get("Irregularity_Complain_Category"), "Unknown"
                ),
                "area": self._as_text(report.get("Area"), "Unknown"),
            }
            encoded_values: Dict[str, int] = {}
            has_unknown_categories = False
            for col, value in categorical_values.items():
                le = self.label_encoders.get(col)
                if le is None:
                    # No encoder for this column: encode as 0 and flag it.
                    encoded_values[f"{col}_encoded"] = 0
                    has_unknown_categories = True
                    continue
                if value in le.classes_:
                    encoded_values[f"{col}_encoded"] = int(le.transform([value])[0])
                else:
                    encoded_values[f"{col}_encoded"] = (
                        int(le.transform(["Unknown"])[0]) if "Unknown" in le.classes_ else 0
                    )
                    has_unknown_categories = True
                    logger.warning(
                        f"Unknown {col} value: '{value}' - using Unknown category"
                    )

            word_count = len(report_text.split()) if report_text else 0
            feature_dict = {
                "day_of_week": day_of_week,
                "month": month,
                "is_weekend": int(is_weekend),
                "week_of_year": week_of_year,
                "day_of_year": day_of_year,
                "sin_day_of_week": np.sin(2 * np.pi * day_of_week / 7),
                "cos_day_of_week": np.cos(2 * np.pi * day_of_week / 7),
                "sin_month": np.sin(2 * np.pi * month / 12),
                "cos_month": np.cos(2 * np.pi * month / 12),
                "sin_day_of_year": np.sin(2 * np.pi * day_of_year / 365),
                "cos_day_of_year": np.cos(2 * np.pi * day_of_year / 365),
                "report_length": len(report_text),
                "report_word_count": word_count,
                "root_cause_length": len(root_cause),
                "action_taken_length": len(action_taken),
                "has_photos": int(bool(report.get("Upload_Irregularity_Photo", ""))),
                "is_complaint": int(report.get("Report_Category", "") == "Complaint"),
                "text_complexity": (len(report_text) * word_count / 100)
                if report_text
                else 0,
                "has_root_cause": int(bool(root_cause)),
                "has_action_taken": int(bool(action_taken)),
            }
            feature_dict.update(encoded_values)

            # Features the model expects but we do not compute default to 0.
            X = np.array([[feature_dict.get(name, 0) for name in self.feature_names]])
            if self.scaler is not None:
                X = self.scaler.transform(X)
            return X, has_unknown_categories
        except Exception as e:
            logger.error(f"Feature extraction error: {e}")
            return None, True

    def _extract_features_batch(self, df: pd.DataFrame) -> Tuple[Optional[np.ndarray], np.ndarray]:
        """Extract features for many reports at once (batch version of ``_extract_features``).

        Args:
            df: One row per report, columns named like the ``IrregularityReport``
                fields. Missing columns are treated as empty. The caller's frame is
                not modified.

        Returns:
            ``(X, unknown_flags)``. ``X`` is the scaled matrix with columns in
            ``self.feature_names`` order, and ``unknown_flags[i]`` is True when row ``i``
            used a categorical value outside its encoder's classes (or no encoder
            exists). On failure returns ``(None, all-True flags)``.
        """
        try:
            # Copy first so adding missing columns below never touches the caller's frame.
            df = df.copy()
            required_cols = [
                "Date_of_Event", "Report", "Root_Caused", "Action_Taken",
                "Upload_Irregularity_Photo", "Report_Category",
                "Airlines", "HUB", "Branch", "Irregularity_Complain_Category", "Area",
            ]
            for col in required_cols:
                if col not in df.columns:
                    df[col] = None

            # Date features; unparseable dates default to "now".
            dates = pd.to_datetime(df["Date_of_Event"], errors="coerce").fillna(datetime.now())
            df["Date_of_Event"] = dates
            df["day_of_week"] = dates.dt.dayofweek
            df["month"] = dates.dt.month
            df["is_weekend"] = df["day_of_week"].isin([5, 6]).astype(int)
            df["week_of_year"] = dates.dt.isocalendar().week.astype(int)
            df["day_of_year"] = dates.dt.dayofyear

            # Cyclical encodings so e.g. December and January are close together.
            df["sin_day_of_week"] = np.sin(2 * np.pi * df["day_of_week"] / 7)
            df["cos_day_of_week"] = np.cos(2 * np.pi * df["day_of_week"] / 7)
            df["sin_month"] = np.sin(2 * np.pi * df["month"] / 12)
            df["cos_month"] = np.cos(2 * np.pi * df["month"] / 12)
            df["sin_day_of_year"] = np.sin(2 * np.pi * df["day_of_year"] / 365)
            df["cos_day_of_year"] = np.cos(2 * np.pi * df["day_of_year"] / 365)

            # Text features.
            for text_col in ("Report", "Root_Caused", "Action_Taken"):
                df[text_col] = df[text_col].fillna("").astype(str)
            df["report_length"] = df["Report"].str.len()
            df["report_word_count"] = df["Report"].map(lambda x: len(x.split()) if x else 0)
            df["root_cause_length"] = df["Root_Caused"].str.len()
            df["action_taken_length"] = df["Action_Taken"].str.len()
            df["has_photos"] = df["Upload_Irregularity_Photo"].fillna("").astype(bool).astype(int)
            df["is_complaint"] = (df["Report_Category"] == "Complaint").astype(int)
            df["text_complexity"] = np.where(
                df["report_length"] > 0,
                df["report_length"] * df["report_word_count"] / 100,
                0,
            )
            df["has_root_cause"] = (df["root_cause_length"] > 0).astype(int)
            df["has_action_taken"] = (df["action_taken_length"] > 0).astype(int)

            # Categorical encoding via a label -> index dict (same indices as le.transform).
            categorical_cols = {
                "airline": "Airlines",
                "hub": "HUB",
                "branch": "Branch",
                "category": "Irregularity_Complain_Category",
                "area": "Area",
            }
            unknown_flags = np.zeros(len(df), dtype=bool)
            for feature_name, col_name in categorical_cols.items():
                df[col_name] = df[col_name].fillna("Unknown").astype(str)
                le = self.label_encoders.get(feature_name)
                if le is None:
                    df[f"{feature_name}_encoded"] = 0
                    unknown_flags[:] = True
                    continue
                mapping = {label: idx for idx, label in enumerate(le.classes_)}
                encoded = df[col_name].map(mapping)
                # Values missing from the mapping come back as NaN.
                unknown_flags |= encoded.isna().to_numpy()
                df[f"{feature_name}_encoded"] = encoded.fillna(mapping.get("Unknown", 0)).astype(int)

            # Features the model expects but we do not compute default to 0.
            for name in self.feature_names:
                if name not in df.columns:
                    df[name] = 0
            X = df[self.feature_names].values
            if self.scaler is not None:
                X = self.scaler.transform(X)
            return X, unknown_flags
        except Exception as e:
            logger.error(f"Batch feature extraction error: {e}")
            return None, np.ones(len(df), dtype=bool)

    # --------------------------------------------------------------- regression

    def _run_regression_batch(self, data: List[Dict]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], np.ndarray]:
        """Run batch feature extraction and model inference.

        Returns:
            ``(X, predictions, unknown_flags)``. ``predictions`` is a 1-D float array
            with one value per row, or ``None`` when no model is usable or inference
            failed (``unknown_flags`` is then all True).
        """
        X_batch: Optional[np.ndarray] = None
        all_unknown = np.ones(len(data), dtype=bool)
        try:
            X_batch, extracted_flags = self._extract_features_batch(pd.DataFrame(data))
            if X_batch is None:
                return None, None, all_unknown
            if self.regression_onnx_session is not None:
                input_name = self.regression_onnx_session.get_inputs()[0].name
                raw = self.regression_onnx_session.run(
                    None, {input_name: X_batch.astype(np.float32)}
                )[0]
            elif self.regression_model is not None:
                raw = self.regression_model.predict(X_batch)
            else:
                return X_batch, None, all_unknown
            # ONNX returns an (n, 1) float32 array; flatten and use plain floats.
            predicted = np.asarray(raw, dtype=float).ravel()
            if len(predicted) != len(data):
                logger.error(
                    f"Model returned {len(predicted)} predictions for {len(data)} rows"
                )
                return X_batch, None, all_unknown
            return X_batch, predicted, np.asarray(extracted_flags, dtype=bool)
        except Exception as e:
            logger.error(f"Batch prediction setup failed: {e}")
            return X_batch, None, all_unknown

    @staticmethod
    def _explain_with_shap(shap_explainer, features: Optional[np.ndarray]) -> Optional[ShapExplanation]:
        """Best-effort SHAP explanation for one feature row; ``None`` if unavailable."""
        if shap_explainer.explainer is None or features is None:
            return None
        try:
            shap_result = shap_explainer.explain_prediction(features)
            return ShapExplanation(
                baseValue=shap_result.get("base_value", 0),
                predictionExplained=shap_result.get("prediction_explained", False),
                topFactors=shap_result.get("top_factors", [])[:5],
                explanation=shap_result.get("explanation", ""),
            )
        except Exception as e:
            logger.debug(f"SHAP explanation failed: {e}")
            return None

    @staticmethod
    def _detect_anomaly(anomaly_detector, predicted: float, category, hub) -> Optional[AnomalyResult]:
        """Best-effort anomaly check of a prediction; ``None`` if detection fails."""
        try:
            anomaly_data = anomaly_detector.detect_prediction_anomaly(predicted, category, hub)
            return AnomalyResult(
                isAnomaly=anomaly_data.get("is_anomaly", False),
                anomalyScore=anomaly_data.get("anomaly_score", 0),
                anomalies=anomaly_data.get("anomalies", []),
            )
        except Exception as e:
            logger.debug(f"Anomaly detection failed: {e}")
            return None

    def predict_regression(self, data: List[Dict]) -> List[RegressionPrediction]:
        """Predict resolution time (days) for each report.

        Uses the ONNX session when loaded, otherwise the pickled model. When neither
        is usable (or batch inference fails), each row falls back to a per-category
        base value plus Gaussian noise (sigma 0.3) and is flagged
        ``hasUnknownCategories=True``; SHAP and anomaly results are then omitted.

        Args:
            data: Report dicts; one prediction is returned per entry, in order.
        """
        if not data:
            return []

        shap_explainer = get_shap_explainer()
        anomaly_detector = get_anomaly_detector()
        X_batch, predicted_batch, unknown_flags = self._run_regression_batch(data)

        # Loop-invariant values.
        mae = self.model_metrics.get("test_mae", 0.5)
        mae = 0.5 if mae is None else float(mae)
        if self.model_metrics and "feature_importance" in self.model_metrics:
            importance = self.model_metrics["feature_importance"]
        else:
            importance = {
                "category": 0.35,
                "airline": 0.28,
                "hub": 0.15,
                "reportLength": 0.12,
                "hasPhotos": 0.10,
            }
        fallback_base_days = {
            "Cargo Problems": 2.5,
            "Pax Handling": 1.8,
            "GSE": 3.2,
            "Operation": 2.1,
            "Baggage Handling": 1.5,
        }

        predictions = []
        for i, item in enumerate(data):
            category = item.get("Irregularity_Complain_Category", "Unknown")
            hub = item.get("HUB", "Unknown")
            if predicted_batch is not None:
                predicted = float(predicted_batch[i])
                lower = max(0.1, predicted - mae)
                upper = predicted + mae
                # Plain bool: numpy bool_ is not a Python bool for pydantic.
                has_unknown = bool(unknown_flags[i])
                features = X_batch[i:i + 1] if X_batch is not None else None
                shap_exp = self._explain_with_shap(shap_explainer, features)
                anomaly_result = self._detect_anomaly(anomaly_detector, predicted, category, hub)
            else:
                predicted = fallback_base_days.get(category, 2.0) + float(np.random.normal(0, 0.3))
                lower = max(0.1, predicted - 0.5)
                upper = predicted + 0.5
                has_unknown = True
                shap_exp = None
                anomaly_result = None

            predictions.append(
                RegressionPrediction(
                    reportId=f"row_{i}",
                    predictedDays=round(max(0.1, predicted), 2),
                    confidenceInterval=(round(lower, 2), round(upper, 2)),
                    featureImportance=importance,
                    hasUnknownCategories=has_unknown,
                    shapExplanation=shap_exp,
                    anomalyDetection=anomaly_result,
                )
            )
        return predictions

    # ---------------------------------------------------------------------- NLP

    def classify_text(self, data: List[Dict]) -> List[ClassificationResult]:
        """Classify severity, area and issue type for each report.

        Severity comes from the NLP service (or keyword fallback). Area and issue type
        come from the multi-task model when available, otherwise from the report's own
        ``Area`` / ``Irregularity_Complain_Category`` fields. Always returns one result
        per input row; rows the NLP service did not score default to "Low"/0.8.
        """
        texts = [
            f"{self._as_text(item.get('Report'))} {self._as_text(item.get('Root_Caused'))}"
            for item in data
        ]

        mt_results = None
        if self.nlp_service is not None:
            mt_results = self.nlp_service.predict_multi_task(texts)
            severity_results = self.nlp_service.classify_severity(texts)
        else:
            severity_results = self._classify_severity_fallback(texts)
        severity_results = list(severity_results or [])
        if len(severity_results) != len(data):
            logger.warning(
                f"Severity classifier returned {len(severity_results)} results for {len(data)} rows"
            )

        results = []
        for i, item in enumerate(data):
            sev_result = severity_results[i] if i < len(severity_results) else {}
            mt_res = mt_results[i] if mt_results and i < len(mt_results) else None

            if isinstance(mt_res, dict):
                area_pred = mt_res.get("area") or {}
                issue_pred = mt_res.get("irregularity") or {}
                area = area_pred.get("label", item.get("Area", "Unknown"))
                issue = issue_pred.get("label", item.get("Irregularity_Complain_Category", "Unknown"))
                issue_conf = issue_pred.get("confidence", 0.85)
            else:
                area = item.get("Area", "Unknown")
                issue = item.get("Irregularity_Complain_Category", "Unknown")
                issue_conf = 0.85

            results.append(
                ClassificationResult(
                    reportId=f"row_{i}",
                    severity=sev_result.get("severity", "Low"),
                    severityConfidence=sev_result.get("confidence", 0.8),
                    # "Apron Area" -> "Apron"; None labels become "Unknown".
                    areaType=self._as_text(area, "Unknown").replace(" Area", ""),
                    issueType=self._as_text(issue, "Unknown"),
                    issueTypeConfidence=issue_conf,
                )
            )
        return results

    def _classify_severity_fallback(self, texts: List[str]) -> List[Dict]:
        """Keyword-based severity used when no NLP service is loaded.

        Damage/emergency words -> High (0.89); delay/error words -> Medium (0.75);
        otherwise Low (0.82).
        """
        high_keywords = ("damage", "torn", "broken", "critical", "emergency")
        medium_keywords = ("delay", "late", "wrong", "error")
        results = []
        for text in texts:
            report = text.lower()
            if any(kw in report for kw in high_keywords):
                results.append({"severity": "High", "confidence": 0.89})
            elif any(kw in report for kw in medium_keywords):
                results.append({"severity": "Medium", "confidence": 0.75})
            else:
                results.append({"severity": "Low", "confidence": 0.82})
        return results

    @staticmethod
    def _make_entity(value: str, label: str, text_lower: str, confidence: float) -> Entity:
        """Build an entity, locating ``value`` case-insensitively in the text (offset 0 if absent)."""
        idx = text_lower.find(value.lower())
        start = idx if idx >= 0 else 0
        return Entity(
            text=value,
            label=label,
            start=start,
            end=start + len(value),
            confidence=confidence,
        )

    def extract_entities(self, data: List[Dict]) -> List[EntityResult]:
        """Extract airline, flight-number and date entities from each report's fields.

        Entity values come from the structured fields. Their offsets point into
        ``"<Report> <Root_Caused>"`` when the value occurs there, otherwise 0.
        """
        results = []
        for i, item in enumerate(data):
            report_text = (
                f"{self._as_text(item.get('Report'))} {self._as_text(item.get('Root_Caused'))}"
            )
            text_lower = report_text.lower()
            entities = []

            airline = self._as_text(item.get("Airlines"))
            if airline and airline != "Unknown":
                entities.append(self._make_entity(airline, "AIRLINE", text_lower, 0.95))

            flight = self._as_text(item.get("Flight_Number"))
            # "#N/A" is a spreadsheet error placeholder, not a flight number.
            if flight and flight != "#N/A":
                entities.append(self._make_entity(flight, "FLIGHT_NUMBER", text_lower, 0.92))

            date_str = self._as_text(item.get("Date_of_Event"))
            if date_str:
                entities.append(self._make_entity(date_str, "DATE", text_lower, 0.90))

            results.append(EntityResult(reportId=f"row_{i}", entities=entities))
        return results

    def generate_summary(self, data: List[Dict]) -> List[SummaryResult]:
        """Summarise each report.

        Texts longer than 100 characters go to the NLP summariser when it is loaded.
        Otherwise a template summary is built from the category, truncated report and
        root cause, plus category/status/area/action key points.
        """
        results = []
        for i, item in enumerate(data):
            report = self._as_text(item.get("Report"))
            root_cause = self._as_text(item.get("Root_Caused"))
            action = self._as_text(item.get("Action_Taken"))
            combined_text = f"{report} {root_cause} {action}"

            if self.nlp_service is not None and len(combined_text) > 100:
                summary_result = self.nlp_service.summarize(combined_text) or {}
                executive_summary = summary_result.get("executiveSummary", "")
                key_points = summary_result.get("keyPoints", [])
            else:
                category = self._as_text(item.get("Irregularity_Complain_Category"), "Issue")
                executive_summary = f"{category}: {report[:120]}"
                if root_cause:
                    executive_summary += f" Root cause: {root_cause[:80]}."
                key_points = [
                    f"Category: {category}",
                    f"Status: {self._as_text(item.get('Status'), 'Unknown')}",
                    f"Area: {self._as_text(item.get('Area'), 'Unknown')}",
                ]
                if action:
                    key_points.append(f"Action: {action[:50]}...")

            results.append(
                SummaryResult(
                    reportId=f"row_{i}",
                    executiveSummary=executive_summary,
                    keyPoints=key_points,
                )
            )
        return results

    def analyze_sentiment(self, data: List[Dict]) -> List[SentimentResult]:
        """Score urgency/sentiment for each report (NLP service or keyword fallback).

        Always returns one result per input row; rows the service did not score
        default to urgency 0.0 / "Neutral".
        """
        texts = [
            f"{self._as_text(item.get('Report'))} {self._as_text(item.get('Root_Caused'))}"
            for item in data
        ]
        if self.nlp_service is not None:
            urgency_results = self.nlp_service.analyze_urgency(texts)
        else:
            urgency_results = self._analyze_urgency_fallback(texts)
        urgency_results = list(urgency_results or [])

        results = []
        for i in range(len(data)):
            urg_result = urgency_results[i] if i < len(urgency_results) else {}
            results.append(
                SentimentResult(
                    reportId=f"row_{i}",
                    urgencyScore=urg_result.get("urgency_score", 0.0),
                    sentiment=urg_result.get("sentiment", "Neutral"),
                    keywords=urg_result.get("keywords", []),
                )
            )
        return results

    def _analyze_urgency_fallback(self, texts: List[str]) -> List[Dict]:
        """Keyword-count urgency used when no NLP service is loaded.

        Score = matched urgency keywords / 3, capped at 1.0. Scores above 0.3
        (i.e. two or more matches) are labelled "Negative", otherwise "Neutral".
        """
        urgency_keywords = (
            "damage",
            "broken",
            "emergency",
            "critical",
            "urgent",
            "torn",
            "severe",
        )
        results = []
        for text in texts:
            report = text.lower()
            keyword_matches = [kw for kw in urgency_keywords if kw in report]
            urgency_score = min(1.0, len(keyword_matches) / 3.0)
            results.append(
                {
                    "urgency_score": round(urgency_score, 2),
                    "sentiment": "Negative" if urgency_score > 0.3 else "Neutral",
                    "keywords": keyword_matches,
                }
            )
        return results
# Initialize model service.
# Created once at import time: ModelService.__init__ loads the regression model
# (and optional ONNX session) and the NLP service, so that cost is paid at startup
# rather than per request. All endpoints share this instance.
model_service = ModelService()
# ============== API Endpoints ==============
@app.get(
    "/",
    tags=["Health"],
    summary="API Root & Status",
)
async def root():
    """Returns basic API status, version, and model availability."""
    nlp = model_service.nlp_service
    # getattr: tolerate NLP service implementations without a models_loaded flag.
    nlp_loaded = nlp is not None and bool(getattr(nlp, "models_loaded", False))
    return {
        "status": "healthy",
        "service": "Gapura AI Analysis API",
        # Report the version declared on the FastAPI app instead of a stale literal.
        "version": app.version,
        "models": {
            "regression": "loaded" if model_service.model_loaded else "unavailable",
            "nlp": nlp.version if nlp_loaded else "unavailable",
        },
        "timestamp": datetime.now().isoformat(),
    }
@app.get(
    "/health",
    tags=["Health"],
    summary="Detailed Health Check",
)
async def health_check():
    """
    Returns detailed health status including:
    - **Models**: Version and load status of Regression and NLP models.
    - **Cache**: Redis/Local cache connectivity.
    - **Metrics**: Current model performance metrics (MAE, RMSE, R2).
    """
    overall_status = "healthy"
    # A failing cache must not turn the health endpoint itself into a 500.
    try:
        cache_health = get_cache().health_check()
    except Exception as e:
        logger.warning(f"Cache health check failed: {e}")
        cache_health = {"status": "unavailable", "error": str(e)}
        overall_status = "degraded"

    nlp = model_service.nlp_service
    nlp_models_loaded = nlp is not None and bool(getattr(nlp, "models_loaded", False))
    return {
        "status": overall_status,
        "models": {
            "regression": {
                "version": model_service.regression_version,
                "loaded": model_service.model_loaded,
                "metrics": model_service.model_metrics
                if model_service.model_loaded
                else None,
            },
            "nlp": {
                "version": model_service.nlp_version,
                # Reflect whether trained NLP models are actually in use.
                "status": "loaded" if nlp_models_loaded else "rule_based",
            },
        },
        "cache": cache_health,
        "timestamp": datetime.now().isoformat(),
    }
def _coerce_effectiveness(value: Any, default: float = 0.8) -> float:
    """Convert an action-effectiveness score to a float clamped to [0, 1].

    The value type of the action service's effectiveness table is not guaranteed
    here, so non-numeric or NaN values fall back to ``default``.
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return default
    if score != score:  # NaN
        return default
    return max(0.0, min(1.0, score))


def _rank_risk_patterns(items: List[RiskItem], top_k: int) -> List[Dict[str, Any]]:
    """Aggregate risk items per (category, hub, area) and return the ``top_k`` riskiest groups.

    Groups are ordered by average risk, then the share of Critical/High items, then
    average predicted days, then group size (all descending). Ties keep first-seen order.
    """
    groups: Dict[str, Dict[str, Any]] = {}
    for it in items:
        key = f"{it.category}|{it.hub}|{it.area}"
        g = groups.setdefault(
            key,
            {
                "key": key,
                "category": it.category,
                "hub": it.hub,
                "area": it.area,
                "count": 0,
                "avgRisk": 0.0,
                "avgDays": 0.0,
                "highSeverityShare": 0.0,
            },
        )
        g["count"] += 1
        g["avgRisk"] += it.riskScore
        g["avgDays"] += it.predictedDays
        if it.severity in ("Critical", "High"):
            g["highSeverityShare"] += 1.0
    # Turn the running sums into averages (count is always >= 1 here).
    for g in groups.values():
        c = g["count"]
        g["avgRisk"] = round(g["avgRisk"] / c, 3)
        g["avgDays"] = round(g["avgDays"] / c, 2)
        g["highSeverityShare"] = round(g["highSeverityShare"] / c, 3)
    patterns = sorted(
        groups.values(),
        key=lambda p: (-p["avgRisk"], -p["highSeverityShare"], -p["avgDays"], -p["count"]),
    )
    return patterns[: max(0, top_k)]


@app.post("/api/ai/risk/assess", response_model=RiskAssessmentResponse, tags=["Analysis"])
async def assess_risk(
    request: Optional[AnalysisRequest] = Body(None),
    sheetId: Optional[str] = None,
    sheetName: Optional[str] = None,
    rowRange: Optional[str] = None,
    bypass_cache: bool = False,
    top_k_patterns: int = 5,
):
    """Score each report's operational risk and surface the riskiest patterns.

    Input is either ``data`` in the body or a sheet reference (query parameters, or
    the body's sheetId/sheetName/rowRange when the query omits them). Per report:
    risk = 0.5*severity + 0.25*normalised predicted days + 0.15*anomaly score
    + 0.10*(1 - category action effectiveness), clamped to [0, 1].

    Raises:
        HTTPException: 400 when neither data nor a complete sheet reference is given.
    """
    # NOTE(review): the model calls below are synchronous CPU work inside an async
    # endpoint and block the event loop; consider running them in a threadpool.
    from data.sheets_service import GoogleSheetsService
    from data.action_service import get_action_service

    # Body fields act as a fallback for the query parameters.
    if request is not None:
        sheetId = sheetId or request.sheetId
        sheetName = sheetName or request.sheetName
        rowRange = rowRange or request.rowRange
        bypass_cache = bypass_cache or request.options.bypassCache

    if request is not None and request.data:
        items_data: List[Dict[str, Any]] = [r.model_dump(exclude_none=True) for r in request.data]
    elif sheetId and sheetName and rowRange:
        cache = None if bypass_cache else get_cache()
        sheets_service = GoogleSheetsService(cache=cache)
        items_data = sheets_service.fetch_sheet_data(
            sheetId, sheetName, rowRange, bypass_cache=bypass_cache
        )
    else:
        raise HTTPException(status_code=400, detail="sheetId, sheetName, and rowRange are required, or provide data in body")

    if not items_data:
        return RiskAssessmentResponse(items=[], topPatterns=[], metadata={"count": 0})

    preds = model_service.predict_regression(items_data)
    classes = model_service.classify_text(items_data)

    # The action service is optional: without it, category weights use the
    # default effectiveness and no recommendations are attached.
    action_service = None
    eff: Dict[str, Any] = {}
    try:
        action_service = get_action_service()
        eff = action_service.action_effectiveness or {}
    except Exception as e:
        logger.warning(f"Action service unavailable: {e}")
    if not isinstance(eff, dict):
        eff = {}

    items: List[RiskItem] = []
    for i, item in enumerate(items_data):
        cat = item.get("Irregularity_Complain_Category") or "Unknown"
        hub = item.get("HUB") or "Unknown"
        area = str(item.get("Area") or "Unknown").replace(" Area", "")
        pr = preds[i]
        cl = classes[i]
        sev = cl.severity
        pdays = pr.predictedDays
        anom = pr.anomalyDetection.anomalyScore if pr.anomalyDetection else 0.0

        # Categories whose past actions were less effective carry more risk.
        cat_w = 1.0 - _coerce_effectiveness(eff.get(cat, 0.8))
        risk = (
            0.5 * _severity_to_score(sev)
            + 0.25 * _normalize_days(pdays)
            + 0.15 * anom
            + 0.10 * cat_w
        )
        risk = max(0.0, min(1.0, risk))

        recs: List[Dict[str, Any]] = []
        if action_service is not None:
            try:
                recs_resp = action_service.recommend(
                    report=item.get("Report", "") or "",
                    issue_type=cat,
                    severity=sev,
                    area=area if area else None,
                    airline=item.get("Airlines") or None,
                    top_n=5,
                )
                recs = [r for r in (recs_resp.get("recommendations") or []) if isinstance(r, dict)]
            except Exception as e:
                logger.debug(f"Recommendation failed for row_{i}: {e}")
                recs = []
        prev = _extract_prevention([str(r.get("action") or "") for r in recs])

        items.append(
            RiskItem(
                reportId=f"row_{i}",
                severity=sev,
                severityConfidence=cl.severityConfidence,
                predictedDays=pdays,
                anomalyScore=anom,
                category=cat,
                hub=hub,
                area=area,
                riskScore=round(risk, 3),
                priority=_priority_from_score(risk),
                recommendedActions=recs[:5],
                preventiveSuggestions=prev,
            )
        )

    return RiskAssessmentResponse(
        items=sorted(items, key=lambda x: -x.riskScore),
        topPatterns=_rank_risk_patterns(items, top_k_patterns),
        metadata={"count": len(items)},
    )
from data.job_service import JobService, JobStatus
# Initialize job service.
# One module-level JobService instance, created at import time and shared by
# every request handled by this process.
job_service = JobService()
def _perform_analysis_trends(
    data: List[Dict], predictions: "List[RegressionPrediction]"
) -> "TrendResult":
    """Aggregate per-airline, per-hub and per-category trend statistics.

    Args:
        data: Report rows as plain dicts.
        predictions: Regression predictions aligned index-by-index with
            ``data``. The list may be shorter than ``data`` or empty.

    Returns:
        A ``TrendResult`` with these fields:

        * ``avgResolutionDays``: the mean predicted days of the rows in each
          group, or 0.0 when no prediction exists for the group.
        * ``topIssues``: up to three issue categories, most frequent first.
        * per-category ``trend``: a fixed ``"stable"`` placeholder.
    """
    by_airline: Dict[str, Dict[str, Any]] = {}
    by_hub: Dict[str, Dict[str, Any]] = {}
    category_counts: Counter = Counter()
    for i, item in enumerate(data):
        # `or` also maps empty strings to "Unknown", matching the risk section.
        airline = item.get("Airlines") or "Unknown"
        hub = item.get("HUB") or "Unknown"
        category = item.get("Irregularity_Complain_Category") or "Unknown"
        days = predictions[i].predictedDays if i < len(predictions) else None
        for groups, key in ((by_airline, airline), (by_hub, hub)):
            bucket = groups.setdefault(key, {"count": 0, "issues": [], "days": []})
            bucket["count"] += 1
            bucket["issues"].append(category)
            if days is not None:
                bucket["days"].append(days)
        category_counts[category] += 1

    def _to_trend(bucket: Dict[str, Any]) -> "TrendData":
        days_list = bucket["days"]
        avg_days = round(sum(days_list) / len(days_list), 2) if days_list else 0.0
        return TrendData(
            count=bucket["count"],
            avgResolutionDays=avg_days,
            topIssues=[issue for issue, _ in Counter(bucket["issues"]).most_common(3)],
        )

    return TrendResult(
        byAirline={k: _to_trend(v) for k, v in by_airline.items()},
        byHub={k: _to_trend(v) for k, v in by_hub.items()},
        byCategory={
            k: {"count": c, "trend": "stable"} for k, c in category_counts.items()
        },
        timeSeries=[],
    )


def _perform_analysis_risk(
    data: List[Dict],
    predictions: "List[RegressionPrediction]",
    classifications: "List[ClassificationResult]",
    top_k_patterns: int = 5,
) -> "RiskAssessmentResponse":
    """Score each report's risk and rank recurring category/hub/area patterns.

    The risk score is computed as
    ``0.5*severity + 0.25*normalized_days + 0.15*anomaly + 0.10*(1 - action_effectiveness)``
    and capped at 1.0. Action recommendations are best-effort: if the action
    service is unavailable or fails for a row, that row gets none.
    """
    try:
        from data.action_service import get_action_service

        action_service = get_action_service()
        eff = action_service.action_effectiveness or {}
    except Exception as exc:
        # Scoring still works without the action service; only recommendations
        # and effectiveness weights are lost.
        logger.warning("Action service unavailable for risk assessment: %s", exc)
        action_service = None
        eff = {}

    items: List[RiskItem] = []
    for i, item in enumerate(data):
        cat = item.get("Irregularity_Complain_Category", "Unknown") or "Unknown"
        hub = item.get("HUB", "Unknown") or "Unknown"
        area = (item.get("Area", "Unknown") or "Unknown").replace(" Area", "")
        pr = predictions[i] if i < len(predictions) else None
        cl = classifications[i] if i < len(classifications) else None
        sev = cl.severity if cl else "Low"
        sev_conf = cl.severityConfidence if cl else 0.6
        pdays = pr.predictedDays if pr else 0.0
        anom = pr.anomalyDetection.anomalyScore if pr and pr.anomalyDetection else 0.0
        # Categories with highly effective known actions contribute less risk.
        cat_w = 1.0 - float(eff.get(cat, 0.8))
        risk = min(
            1.0,
            0.5 * _severity_to_score(sev)
            + 0.25 * _normalize_days(pdays)
            + 0.15 * anom
            + 0.10 * cat_w,
        )
        recs: List[Dict[str, Any]] = []
        if action_service is not None:
            try:
                recs_resp = action_service.recommend(
                    report=item.get("Report", "") or "",
                    issue_type=cat,
                    severity=sev,
                    area=area if area else None,
                    airline=item.get("Airlines") or None,
                    top_n=5,
                )
                recs = recs_resp.get("recommendations", [])
            except Exception as exc:
                logger.debug("Action recommendation failed for row_%d: %s", i, exc)
                recs = []
        items.append(
            RiskItem(
                reportId=f"row_{i}",
                severity=sev,
                severityConfidence=sev_conf,
                predictedDays=pdays,
                anomalyScore=anom,
                category=cat,
                hub=hub,
                area=area,
                riskScore=round(risk, 3),
                priority=_priority_from_score(risk),
                recommendedActions=recs[:5],
                preventiveSuggestions=_extract_prevention(
                    [r.get("action", "") for r in recs]
                ),
            )
        )

    # Aggregate items sharing category/hub/area into patterns (sums first,
    # turned into averages below).
    groups: Dict[str, Dict[str, Any]] = {}
    for it in items:
        key = f"{it.category}|{it.hub}|{it.area}"
        g = groups.setdefault(
            key,
            {
                "key": key,
                "category": it.category,
                "hub": it.hub,
                "area": it.area,
                "count": 0,
                "avgRisk": 0.0,
                "avgDays": 0.0,
                "highSeverityShare": 0.0,
            },
        )
        g["count"] += 1
        g["avgRisk"] += it.riskScore
        g["avgDays"] += it.predictedDays
        g["highSeverityShare"] += 1.0 if it.severity in ("Critical", "High") else 0.0

    patterns = []
    for g in groups.values():
        c = max(1, g["count"])
        g["avgRisk"] = round(g["avgRisk"] / c, 3)
        g["avgDays"] = round(g["avgDays"] / c, 2)
        g["highSeverityShare"] = round(g["highSeverityShare"] / c, 3)
        patterns.append(g)
    patterns.sort(
        key=lambda x: (-x["avgRisk"], -x["highSeverityShare"], -x["avgDays"], -x["count"])
    )
    return RiskAssessmentResponse(
        items=sorted(items, key=lambda x: -x.riskScore),
        topPatterns=patterns[:top_k_patterns],
        metadata={"count": len(items)},
    )


def perform_analysis(data: List[Dict], options: AnalysisOptions, compact: bool) -> AnalysisResponse:
    """Run the requested analysis stages over a batch of report dicts.

    The synchronous ``/api/ai/analyze`` endpoint and the background job
    runner both use this function, so they produce identical results.

    Args:
        data: Report rows as plain dicts, as produced by
            ``IrregularityReport.model_dump(exclude_none=True)``.
        options: Flags that select which stages run (regression, NLP,
            trends, risk).
        compact: When True, strip heavy per-row payloads from the response:
            SHAP explanations, anomaly details, entities and summaries.

    Returns:
        An ``AnalysisResponse``. Only the sections for the requested stages
        are populated.
    """
    start_time = datetime.now()
    total_records = len(data)
    logger.info("Analyzing %d records...", total_records)

    response = AnalysisResponse(
        metadata=Metadata(
            totalRecords=total_records,
            processingTime=0.0,
            modelVersions={
                "regression": model_service.regression_version,
                "nlp": model_service.nlp_version,
            },
        )
    )

    # Regression also runs for risk and trends: both consume predicted days.
    predictions: List[RegressionPrediction] = []
    if options.predictResolutionTime or options.includeRisk or options.analyzeTrends:
        logger.info("Running regression analysis...")
        predictions = model_service.predict_regression(data)
        if options.predictResolutionTime:
            if model_service.model_loaded and model_service.model_metrics:
                model_metrics = model_service.model_metrics
                metrics = {
                    "mae": round(model_metrics.get("test_mae", 1.2), 3),
                    "rmse": round(model_metrics.get("test_rmse", 1.8), 3),
                    "r2": round(model_metrics.get("test_r2", 0.78), 3),
                    "model_loaded": True,
                    "note": "Using trained model",
                }
            else:
                metrics = {
                    "mae": None,
                    "rmse": None,
                    "r2": None,
                    "model_loaded": False,
                    "note": "Model not available - using fallback predictions",
                }
            response.regression = RegressionResult(
                predictions=predictions,
                modelMetrics=metrics,
            )

    # NLP: classification is also needed when risk scoring is requested.
    classifications: List[ClassificationResult] = []
    if (
        options.classifySeverity
        or options.extractEntities
        or options.generateSummary
        or options.includeRisk
    ):
        logger.info("Running NLP analysis...")
        entities = []
        summaries = []
        if options.classifySeverity or options.includeRisk:
            classifications = model_service.classify_text(data)
        if options.extractEntities:
            entities = model_service.extract_entities(data)
        if options.generateSummary:
            summaries = model_service.generate_summary(data)
        sentiment = model_service.analyze_sentiment(data)
        response.nlp = NLPResult(
            classifications=classifications,
            entities=entities,
            summaries=summaries,
            sentiment=sentiment,
        )

    if options.analyzeTrends:
        logger.info("Running trend analysis...")
        response.trends = _perform_analysis_trends(data, predictions)

    if options.includeRisk:
        response.risk = _perform_analysis_risk(data, predictions, classifications)

    if compact:
        if response.regression and response.regression.predictions:
            for p in response.regression.predictions:
                p.shapExplanation = None
                p.anomalyDetection = None
        if response.nlp:
            response.nlp.entities = []
            response.nlp.summaries = []

    processing_time = (datetime.now() - start_time).total_seconds() * 1000
    response.metadata.processingTime = round(processing_time, 2)
    logger.info("Analysis completed in %.2fms", processing_time)
    return response
@app.post(
    "/api/ai/analyze",
    response_model=AnalysisResponse,
    tags=["Analysis"],
    summary="Analyze Irregularity Reports",
    response_description="Analysis results including predictions, severity, and entities.",
)
async def analyze_reports(request: AnalysisRequest, compact: bool = False):
    """
    Perform comprehensive AI analysis on a batch of irregularity reports.
    - **Regression**: Predicts days to resolve based on category and description.
    - **NLP**: Classifies severity, extracts entities (Flight No, Airline), and summarizes text.
    - **Trends**: Aggregates data by Airline, Hub, and Category.
    The endpoint accepts a list of `IrregularityReport` objects.
    Responds with 400 when no data is supplied and 500 on analysis failure.
    """
    # FastAPI helper that runs a blocking callable in its worker threadpool.
    from fastapi.concurrency import run_in_threadpool

    try:
        if not request.data:
            raise HTTPException(
                status_code=400,
                detail="No data provided. Either sheetId or data must be specified.",
            )
        data = [report.model_dump(exclude_none=True) for report in request.data]
        # perform_analysis does synchronous model inference. Running it in the
        # threadpool keeps the event loop responsive for other requests.
        return await run_in_threadpool(perform_analysis, data, request.options, compact)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
def background_analysis_task(job_id: str, data: List[Dict], options: AnalysisOptions, compact: bool):
    """Run ``perform_analysis`` for a queued job and record the outcome.

    The job is marked PROCESSING, then COMPLETED with the dumped response,
    or FAILED with the error message. Exceptions are not re-raised.
    """
    try:
        job_service.update_job(job_id, JobStatus.PROCESSING)
        response = perform_analysis(data, options, compact)
        job_service.update_job(job_id, JobStatus.COMPLETED, result=response.model_dump())
    except Exception as e:
        # logger.exception keeps the traceback, which plain error() dropped.
        logger.exception(f"Job {job_id} failed: {e}")
        job_service.update_job(job_id, JobStatus.FAILED, error=str(e))
@app.post(
    "/api/ai/analyze-async",
    response_model=Dict[str, str],
    tags=["Analysis", "Jobs"],
    summary="Start Async Analysis Job",
)
async def analyze_async(
    request: AnalysisRequest, background_tasks: BackgroundTasks, compact: bool = False
):
    """
    Start a background analysis job for large datasets.
    Returns `{"job_id": ..., "status": "queued"}` immediately. Poll
    `/api/ai/jobs/{job_id}` with that `job_id` for status and results.
    Responds with 400 when `data` is empty.
    """
    if not request.data:
        raise HTTPException(status_code=400, detail="No data provided")
    # Convert to plain dicts, the format perform_analysis consumes.
    data = [report.model_dump(exclude_none=True) for report in request.data]
    job_id = job_service.create_job()
    background_tasks.add_task(background_analysis_task, job_id, data, request.options, compact)
    return {"job_id": job_id, "status": "queued"}
@app.get(
    "/api/ai/jobs/{job_id}",
    tags=["Jobs"],
    summary="Get Job Status",
)
async def get_job_status(job_id: str):
    """
    Look up a background analysis job by id and return its stored record.
    Possible statuses: `queued`, `processing`, `completed`, `failed`.
    Unknown ids yield a 404.
    """
    stored_job = job_service.get_job(job_id)
    if stored_job:
        return stored_job
    raise HTTPException(status_code=404, detail="Job not found")
@app.post(
    "/api/ai/predict-single",
    tags=["Analysis"],
    summary="Real-time Single Prediction",
)
async def predict_single(report: IrregularityReport):
    """
    Run every model stage on one irregularity report and return the results.
    Intended for real-time validation or "what-if" analysis in the UI.
    Any failure is logged and surfaced as a 500.
    """
    try:
        batch = [report.model_dump(exclude_none=True)]
        # Each stage takes a batch; run all of them, then take the single row.
        stage_outputs = (
            model_service.predict_regression(batch),
            model_service.classify_text(batch),
            model_service.extract_entities(batch),
            model_service.generate_summary(batch),
            model_service.analyze_sentiment(batch),
        )
        prediction, classification, entity_set, summary, mood = (
            output[0] for output in stage_outputs
        )
        return {
            "prediction": prediction,
            "classification": classification,
            "entities": entity_set,
            "summary": summary,
            "sentiment": mood,
            "modelLoaded": model_service.model_loaded,
        }
    except Exception as e:
        logger.error(f"Single prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post(
    "/api/ai/train",
    tags=["Training"],
    summary="Trigger Model Retraining",
)
async def train_models(background_tasks: BackgroundTasks, force: bool = False):
    """
    Queue a background retraining run via `TrainingScheduler.run_training`.
    The `force` flag is passed through to the scheduler.
    """
    from scripts.scheduled_training import TrainingScheduler

    def _retrain():
        outcome = TrainingScheduler().run_training(force=force)
        logger.info(f"Training completed: {outcome}")

    background_tasks.add_task(_retrain)
    payload = {
        "status": "training_queued",
        "message": "Model retraining has been started in the background",
        "force": force,
    }
    payload["timestamp"] = datetime.now().isoformat()
    return payload
@app.get(
    "/api/ai/train/status",
    tags=["Training"],
    summary="Get Training Status",
)
async def training_status():
    """
    Report the training scheduler's status (latest run and history).
    """
    from scripts.scheduled_training import TrainingScheduler

    scheduler_state = TrainingScheduler().get_status()
    return dict(
        status="success",
        data=scheduler_state,
        timestamp=datetime.now().isoformat(),
    )
@app.get("/api/ai/model-info")
async def model_info():
    """Describe the regression and NLP models currently served."""
    is_loaded = model_service.model_loaded
    regression_section = {
        "version": model_service.regression_version,
        "type": "GradientBoostingRegressor",
        "status": "loaded" if is_loaded else "unavailable",
        # NOTE(review): hard-coded; not read from the trained model artefacts.
        "last_trained": "2025-01-15",
        "metrics": model_service.model_metrics if is_loaded else None,
    }
    nlp_section = {
        "version": model_service.nlp_version,
        "type": "Rule-based + Keyword extraction",
        "status": "active",
        "tasks": ["classification", "ner", "summarization", "sentiment"],
        "note": "Full ML NLP models coming soon",
    }
    return {"regression": regression_section, "nlp": nlp_section}
@app.post("/api/ai/cache/invalidate")
async def invalidate_cache(sheet_name: Optional[str] = None):
    """Delete cached sheet data for one sheet, or for all sheets.

    ``sheet_name`` is interpolated into a ``sheets:*<name>*`` key pattern.
    Without a name, every ``sheets:*`` key is removed.
    """
    store = get_cache()
    if not sheet_name:
        key_pattern = "sheets:*"
        message = "Invalidated all sheets cache"
    else:
        key_pattern = f"sheets:*{sheet_name}*"
        message = f"Invalidated cache for sheet: {sheet_name}"
    removed = store.delete_pattern(key_pattern)
    return {"status": "success", "message": message, "keys_deleted": removed}
@app.get("/api/ai/cache/status")
async def cache_status():
    """Return the cache backend's health-check report."""
    return get_cache().health_check()
class AnalyzeAllResponse(BaseModel):
    """Response body of ``GET /api/ai/analyze-all``."""

    # "success" when the analysis ran to completion.
    status: str
    # Run-level info: record count, timing, throughput, model versions.
    metadata: Dict[str, Any]
    # Per-sheet fetch outcome keyed by sheet name (rows fetched, status, error).
    sheets: Dict[str, Any]
    # One dict per analyzed row (original fields plus model outputs).
    results: List[Dict[str, Any]]
    # Aggregates over ``results`` (e.g. severity distribution, prediction stats).
    summary: Dict[str, Any]
    # ISO-8601 time at which the response was built.
    timestamp: str
def _analyze_all_sheets_sync(
    bypass_cache: bool,
    include_regression: bool,
    include_nlp: bool,
    include_trends: bool,
    max_rows_per_sheet: int,
    compact: bool,
) -> AnalyzeAllResponse:
    """Blocking worker behind ``analyze_all_sheets``.

    It fetches both sheets, runs the models in batches and builds the summary.

    Raises:
        HTTPException: 400 for ``max_rows_per_sheet < 1``, 500 when
            ``GOOGLE_SHEET_ID`` is unset, and 404 when no rows were fetched.
    """
    start_time = datetime.now()
    if max_rows_per_sheet < 1:
        raise HTTPException(
            status_code=400, detail="max_rows_per_sheet must be at least 1"
        )

    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    all_data: List[Dict[str, Any]] = []
    sheet_info: Dict[str, Dict[str, Any]] = {}
    # The ranges start at A1 and fetch max_rows_per_sheet + 1 rows
    # (presumably a header row plus the data rows).
    sheets_to_fetch = [
        {"name": "NON CARGO", "range": f"A1:AA{max_rows_per_sheet + 1}"},
        {"name": "CGO", "range": f"A1:Z{max_rows_per_sheet + 1}"},
    ]
    for sheet in sheets_to_fetch:
        sheet_name = sheet["name"]
        try:
            logger.info(f"Fetching {sheet_name}...")
            data = sheets_service.fetch_sheet_data(
                spreadsheet_id, sheet_name, sheet["range"], bypass_cache=bypass_cache
            )
            # Tag copies so rows possibly held by the cache are not mutated, and
            # so a failure mid-sheet cannot leave a partial sheet in all_data.
            tagged = [{**row, "_source_sheet": sheet_name} for row in data]
        except Exception as e:
            # One failing sheet is reported in `sheets` instead of aborting.
            logger.error(f"Failed to fetch {sheet_name}: {e}")
            sheet_info[sheet_name] = {
                "rows_fetched": 0,
                "status": "error",
                "error": str(e),
            }
            continue
        all_data.extend(tagged)
        sheet_info[sheet_name] = {"rows_fetched": len(data), "status": "success"}

    total_records = len(all_data)
    if total_records == 0:
        raise HTTPException(status_code=404, detail="No data found in any sheet")
    logger.info(f"Analyzing {total_records} total records...")

    results: List[Dict[str, Any]] = []
    batch_size = 100
    for start in range(0, total_records, batch_size):
        batch = all_data[start : start + batch_size]
        size = len(batch)
        if include_regression:
            regression_preds = model_service.predict_regression(batch)
        else:
            regression_preds = [None] * size
        if include_nlp:
            classifications = model_service.classify_text(batch)
            entities = model_service.extract_entities(batch)
            summaries = model_service.generate_summary(batch)
            sentiments = model_service.analyze_sentiment(batch)
        else:
            classifications = [None] * size
            entities = [None] * size
            summaries = [None] * size
            sentiments = [None] * size

        for j, row in enumerate(batch):
            result: Dict[str, Any] = {
                "rowId": row.get("_row_id", f"row_{start + j}"),
                "sourceSheet": row.get("_source_sheet", "Unknown"),
                "originalData": {
                    "date": row.get("Date_of_Event"),
                    "airline": row.get("Airlines"),
                    "flightNumber": row.get("Flight_Number"),
                    "branch": row.get("Branch"),
                    "hub": row.get("HUB"),
                    "route": row.get("Route"),
                    "category": row.get("Report_Category"),
                    "issueType": row.get("Irregularity_Complain_Category"),
                    "report": row.get("Report"),
                    "status": row.get("Status"),
                },
            }
            reg = regression_preds[j]
            if reg:
                pred = {
                    "predictedDays": reg.predictedDays,
                    "confidenceInterval": reg.confidenceInterval,
                    "hasUnknownCategories": reg.hasUnknownCategories,
                }
                if not compact:
                    pred["shapExplanation"] = (
                        reg.shapExplanation.model_dump() if reg.shapExplanation else None
                    )
                    pred["anomalyDetection"] = (
                        reg.anomalyDetection.model_dump() if reg.anomalyDetection else None
                    )
                result["prediction"] = pred
            if classifications[j]:
                result["classification"] = classifications[j].model_dump()
            if not compact:
                if entities[j]:
                    result["entities"] = entities[j].model_dump()
                if summaries[j]:
                    result["summary"] = summaries[j].model_dump()
                if sentiments[j]:
                    result["sentiment"] = sentiments[j].model_dump()
            results.append(result)

    summary: Dict[str, Any] = {
        "totalRecords": total_records,
        "sheetsProcessed": sum(
            1 for s in sheet_info.values() if s["status"] == "success"
        ),
        "regressionEnabled": include_regression,
        "nlpEnabled": include_nlp,
        "trendsEnabled": include_trends,
    }
    if include_nlp and results:
        severity_counts = Counter(
            (r.get("classification") or {}).get("severity", "Unknown") for r in results
        )
        summary["severityDistribution"] = dict(severity_counts)
    if include_regression and results:
        predictions = [
            r["prediction"]["predictedDays"] for r in results if r.get("prediction")
        ]
        if predictions:
            summary["predictionStats"] = {
                "min": round(min(predictions), 2),
                "max": round(max(predictions), 2),
                "mean": round(sum(predictions) / len(predictions), 2),
            }
    if include_trends and results:
        # Lightweight volume trends: the 10 most frequent values per dimension.
        def _top(field: str) -> Dict[str, int]:
            counts = Counter(
                r["originalData"].get(field) or "Unknown" for r in results
            )
            return dict(counts.most_common(10))

        summary["trends"] = {
            "byAirline": _top("airline"),
            "byHub": _top("hub"),
            "byIssueType": _top("issueType"),
        }

    processing_time = (datetime.now() - start_time).total_seconds()
    return AnalyzeAllResponse(
        status="success",
        metadata={
            "totalRecords": total_records,
            "processingTimeSeconds": round(processing_time, 2),
            "recordsPerSecond": round(total_records / processing_time, 2)
            if processing_time > 0
            else 0,
            "modelVersions": {
                "regression": model_service.regression_version,
                "nlp": model_service.nlp_version,
            },
        },
        sheets=sheet_info,
        results=results,
        summary=summary,
        timestamp=datetime.now().isoformat(),
    )


@app.get("/api/ai/analyze-all", response_model=AnalyzeAllResponse)
async def analyze_all_sheets(
    bypass_cache: bool = False,
    include_regression: bool = True,
    include_nlp: bool = True,
    include_trends: bool = True,
    max_rows_per_sheet: int = 10000,
    compact: bool = False,
):
    """
    Analyze all rows from the NON CARGO and CGO Google Sheets.
    Fetches both sheets, analyzes each row, and returns per-row results plus
    an aggregate summary.
    Args:
        bypass_cache: Skip cache and fetch fresh data
        include_regression: Include regression predictions
        include_nlp: Include NLP analysis (severity, entities, summary)
        include_trends: Include top-10 airline/hub/issue-type counts in summary
        max_rows_per_sheet: Maximum rows to process per sheet (must be >= 1)
        compact: Omit SHAP, anomaly, entity, summary and sentiment payloads
    """
    # FastAPI helper that runs a blocking callable in its worker threadpool.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Sheet fetching and model inference are blocking; keep them off the
        # event loop.
        return await run_in_threadpool(
            _analyze_all_sheets_sync,
            bypass_cache=bypass_cache,
            include_regression=include_regression,
            include_nlp=include_nlp,
            include_trends=include_trends,
            max_rows_per_sheet=max_rows_per_sheet,
            compact=compact,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Analyze all error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# ============== Risk Scoring Endpoints ==============
@app.get("/api/ai/risk/summary")
async def risk_summary():
    """Return the risk service's overall summary across all entities."""
    from data.risk_service import get_risk_service

    return get_risk_service().get_risk_summary()
@app.get("/api/ai/risk/airlines")
async def airline_risks():
    """Return risk scores for every airline known to the risk service."""
    from data.risk_service import get_risk_service

    return get_risk_service().get_all_airline_risks()
@app.get("/api/ai/risk/airlines/{airline_name}")
async def airline_risk(airline_name: str):
    """Return one airline's risk data with recommendations; 404 if unknown."""
    from data.risk_service import get_risk_service

    service = get_risk_service()
    details = service.get_airline_risk(airline_name)
    if not details:
        raise HTTPException(
            status_code=404, detail=f"Airline '{airline_name}' not found"
        )
    advice = service.get_risk_recommendations("airline", airline_name)
    return dict(airline=airline_name, risk_data=details, recommendations=advice)
@app.get("/api/ai/risk/branches")
async def branch_risks():
    """Return risk scores for every branch known to the risk service."""
    from data.risk_service import get_risk_service

    return get_risk_service().get_all_branch_risks()
@app.get("/api/ai/risk/hubs")
async def hub_risks():
    """Return risk scores for every hub known to the risk service."""
    from data.risk_service import get_risk_service

    return get_risk_service().get_all_hub_risks()
@app.post("/api/ai/risk/calculate")
async def calculate_risk_scores(bypass_cache: bool = False):
    """Recalculate risk scores from the NON CARGO and CGO Google Sheets.

    Reads ranges ``A1:AA2000`` (NON CARGO) and ``A1:Z2000`` (CGO), passes
    the combined rows to the risk service, and returns its refreshed summary.

    Args:
        bypass_cache: Skip the sheets cache and fetch fresh data.

    Raises:
        HTTPException: 500 when ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.risk_service import get_risk_service
    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    risk_service = get_risk_service()
    # Called for its side effect. The return value was never used; the
    # summary below is read back from the service (presumably from the state
    # this call refreshes).
    risk_service.calculate_all_risk_scores(all_data)
    return {
        "status": "success",
        "records_processed": len(all_data),
        "risk_summary": risk_service.get_risk_summary(),
    }
# ============== Subcategory Classification Endpoints ==============
@app.post("/api/ai/subcategory")
async def classify_subcategory(
    report: str,
    area: Optional[str] = None,
    issue_type: Optional[str] = None,
    root_cause: Optional[str] = None,
):
    """Assign a subcategory to a report using the subcategory classifier.

    Optional area, issue type and root cause are passed through as context.
    """
    from data.subcategory_service import get_subcategory_classifier

    return get_subcategory_classifier().classify(report, area, issue_type, root_cause)
@app.get("/api/ai/subcategory/categories")
async def get_subcategories(area: Optional[str] = None):
    """List the subcategories the classifier supports, optionally for one area."""
    from data.subcategory_service import get_subcategory_classifier

    return get_subcategory_classifier().get_available_categories(area)
# ============== Action Recommendation Endpoints ==============
@app.post("/api/ai/action/recommend")
async def recommend_actions(
    report: str,
    issue_type: str,
    severity: str = "Medium",
    area: Optional[str] = None,
    airline: Optional[str] = None,
    top_n: int = 5,
):
    """Return the action service's top-N recommendations for an issue."""
    from data.action_service import get_action_service

    request_args = dict(
        report=report,
        issue_type=issue_type,
        severity=severity,
        area=area,
        airline=airline,
        top_n=top_n,
    )
    return get_action_service().recommend(**request_args)
@app.post("/api/ai/action/train")
async def train_action_recommender(
    bypass_cache: bool = False, background_tasks: BackgroundTasks = None
):
    """Train the action recommender on historical sheet data.

    Uses the first 2000 rows of NON CARGO and CGO. The similarity index is
    rebuilt from the same rows before training. ``background_tasks`` is
    accepted but not used.
    """
    from data.action_service import get_action_service
    from data.sheets_service import GoogleSheetsService
    from data.similarity_service import get_similarity_service

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    rows = rows + sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    get_similarity_service().build_index(rows)
    get_action_service().train_from_data(rows)
    return {"status": "success", "records_processed": len(rows)}
# ============== Advanced NER Endpoints ==============
@app.post("/api/ai/ner/extract")
async def extract_entities(text: str):
    """Run the advanced NER extractor and its summarizer over ``text``."""
    from data.advanced_ner_service import get_advanced_ner

    extractor = get_advanced_ner()
    found = extractor.extract(text)
    digest = extractor.extract_summary(text)
    return dict(entities=found, summary=digest)
# ============== Similarity Endpoints ==============
@app.post("/api/ai/similar")
async def find_similar_reports(
    text: str,
    top_k: int = 5,
    threshold: float = 0.3,
):
    """Return up to ``top_k`` indexed reports similar to ``text``.

    ``threshold`` is passed to the similarity service. The response also
    echoes the first 100 characters of the query.
    """
    from data.similarity_service import get_similarity_service

    matches = get_similarity_service().find_similar(text, top_k, threshold)
    return dict(query_preview=text[:100], similar_reports=matches)
@app.post("/api/ai/similar/build-index")
async def build_similarity_index(bypass_cache: bool = False):
    """Rebuild the similarity index from the first 2000 rows of both sheets."""
    from data.similarity_service import get_similarity_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    rows = rows + sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    get_similarity_service().build_index(rows)
    return {"status": "success", "records_indexed": len(rows)}
# ============== Forecasting Endpoints ==============
@app.get("/api/ai/forecast/issues")
async def forecast_issues(periods: int = 4):
    """Forecast issue volume for the next ``periods`` periods."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().forecast_issues(periods)
@app.get("/api/ai/forecast/trends")
async def predict_trends():
    """Return the forecast service's per-category trend predictions."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().predict_category_trends()
@app.get("/api/ai/forecast/seasonal")
async def get_seasonal_patterns():
    """Return the seasonal patterns held by the forecast service."""
    from data.forecast_service import get_forecast_service

    return get_forecast_service().get_seasonal_patterns()
@app.post("/api/ai/forecast/build")
async def build_forecast_data(bypass_cache: bool = False):
    """Load both sheets (first 2000 rows) into the forecast service's history."""
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    rows = rows + sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    forecaster = get_forecast_service()
    forecaster.build_historical_data(rows)
    return dict(
        status="success",
        records_processed=len(rows),
        forecast_summary=forecaster.get_forecast_summary(),
    )
# ============== Report Generation Endpoints ==============
@app.post("/api/ai/report/generate")
async def generate_report(
    row_id: str,
    bypass_cache: bool = False,
):
    """Generate formal, executive and structured reports for one sheet row.

    Looks up ``row_id`` (matched against ``_row_id``) in the first 2000 rows of
    both sheets. Severity comes from the model service's
    ``_classify_severity_fallback`` over ``Report`` + ``Root_Caused``.

    Raises:
        HTTPException: 500 when ``GOOGLE_SHEET_ID`` is unset, 404 when the
            row is not found.
    """
    from data.report_generator_service import get_report_generator
    from data.sheets_service import GoogleSheetsService
    from data.risk_service import get_risk_service

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo
    record = next((r for r in all_data if r.get("_row_id") == row_id), None)
    if not record:
        raise HTTPException(status_code=404, detail=f"Record '{row_id}' not found")

    # `or ""` guards cells that are present but None, which made the old
    # `str + str` concatenation raise TypeError.
    report_text = f"{record.get('Report') or ''} {record.get('Root_Caused') or ''}"
    analysis = {
        "severity": model_service._classify_severity_fallback([report_text])[0].get(
            "severity", "Medium"
        ),
        "issueType": record.get("Irregularity_Complain_Category", ""),
    }

    risk_service = get_risk_service()
    airline = record.get("Airlines", "")
    risk_data = risk_service.get_airline_risk(airline)

    report_gen = get_report_generator()
    formal_report = report_gen.generate_incident_report(record, analysis, risk_data)
    exec_summary = report_gen.generate_executive_summary(record, analysis)
    json_report = report_gen.generate_json_report(record, analysis, risk_data)
    return {
        "row_id": row_id,
        "formal_report": formal_report,
        "executive_summary": exec_summary,
        "structured_report": json_report,
    }
# ============== Dashboard Endpoints ==============
@app.get("/api/ai/dashboard/summary")
async def dashboard_summary(bypass_cache: bool = False):
    """Build the dashboard overview from both sheets plus risk and forecast data.

    Severity is assigned per record with the model service's
    ``_classify_severity_fallback`` over ``Report`` + ``Root_Caused``.
    Category and airline distributions list the 10 most common values.

    Raises:
        HTTPException: 500 when ``GOOGLE_SHEET_ID`` is not configured.
    """
    from data.risk_service import get_risk_service
    from data.forecast_service import get_forecast_service
    from data.sheets_service import GoogleSheetsService

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    non_cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "NON CARGO", "A1:AA2000", bypass_cache=bypass_cache
    )
    cargo = sheets_service.fetch_sheet_data(
        spreadsheet_id, "CGO", "A1:Z2000", bypass_cache=bypass_cache
    )
    all_data = non_cargo + cargo

    risk_summary = get_risk_service().get_risk_summary()
    forecast_summary = get_forecast_service().get_forecast_summary()

    severity_dist = Counter()
    category_dist = Counter()
    airline_dist = Counter()
    for record in all_data:
        # `or ""` guards None cells, which crashed the old string concatenation.
        report_text = f"{record.get('Report') or ''} {record.get('Root_Caused') or ''}"
        sev = model_service._classify_severity_fallback([report_text])[0].get(
            "severity", "Low"
        )
        severity_dist[sev] += 1
        # Treat missing, None and empty values alike as "Unknown".
        category_dist[record.get("Irregularity_Complain_Category") or "Unknown"] += 1
        airline_dist[record.get("Airlines") or "Unknown"] += 1

    return {
        "total_records": len(all_data),
        "sheets": {
            "non_cargo": len(non_cargo),
            "cargo": len(cargo),
        },
        "severity_distribution": dict(severity_dist),
        "category_distribution": dict(category_dist.most_common(10)),
        "top_airlines": dict(airline_dist.most_common(10)),
        "risk_summary": risk_summary,
        "forecast_summary": forecast_summary,
        "model_status": {
            "regression": model_service.model_loaded,
            "nlp": model_service.nlp_service is not None,
        },
        "last_updated": datetime.now().isoformat(),
    }
# ============== Seasonality Endpoints ==============
@app.get("/api/ai/seasonality/summary")
async def seasonality_summary(category_type: Optional[str] = None):
    """
    Return the seasonality service's summary of learned patterns.
    Args:
        category_type: "landside_airside", "cgo", or None for both
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().get_seasonality_summary(category_type)
@app.get("/api/ai/seasonality/forecast")
async def seasonality_forecast(
    category_type: Optional[str] = None,
    periods: int = 4,
    granularity: str = "weekly",
):
    """
    Project future issue volumes from the seasonality service.
    Args:
        category_type: "landside_airside", "cgo", or None for both
        periods: How many future periods to forecast
        granularity: "daily", "weekly", or "monthly"
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().forecast(category_type, periods, granularity)
@app.get("/api/ai/seasonality/peaks")
async def seasonality_peaks(
    category_type: Optional[str] = None, threshold: float = 1.2
):
    """
    Find periods whose volume exceeds the average by ``threshold``.
    Args:
        category_type: "landside_airside", "cgo", or None for both
        threshold: Multiplier over the average (1.2 means 20% above)
    """
    from data.seasonality_service import get_seasonality_service

    return get_seasonality_service().get_peak_periods(category_type, threshold)
@app.post("/api/ai/seasonality/build")
async def build_seasonality_patterns(bypass_cache: bool = False):
    """Learn seasonality patterns from the first 5000 rows of both sheets.

    Each fetched row is tagged in place with ``_sheet_name`` before the rows
    are handed to the seasonality service.
    """
    from data.seasonality_service import get_seasonality_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")
    non_cargo = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cargo = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    for label, rows in (("NON CARGO", non_cargo), ("CGO", cargo)):
        for row in rows:
            row["_sheet_name"] = label
    combined = non_cargo + cargo
    learned = get_seasonality_service().build_patterns(combined)
    return dict(status="success", records_processed=len(combined), patterns=learned)
# ============== Root Cause Endpoints ==============
@app.post("/api/ai/root-cause/classify")
async def classify_root_cause(
    root_cause: str,
    report: Optional[str] = None,
    area: Optional[str] = None,
    category: Optional[str] = None,
):
    """
    Classify a single root-cause text into one of the known categories.

    Categories: Equipment Failure, Staff Competency, Process/Procedure,
    Communication, External Factors, Documentation, Training Gap, Resource/Manpower

    Args:
        root_cause: Free-text root cause to classify.
        report: Optional report text passed along as extra input ("" when absent).
        area: Optional area, forwarded in the context dict.
        category: Optional category, forwarded in the context dict.
    """
    from data.root_cause_service import get_root_cause_service

    classifier = get_root_cause_service()
    return classifier.classify(
        root_cause,
        report if report else "",
        {"area": area, "category": category},
    )
@app.post("/api/ai/root-cause/classify-batch")
async def classify_root_cause_batch(bypass_cache: bool = False):
    """Classify the root cause of every record in the NON CARGO and CGO sheets.

    The response carries at most the first 100 classifications, plus a count
    of how many records received a category other than "Unknown".

    Args:
        bypass_cache: When true, the sheets client is created without a cache
            and fresh data is requested.

    Raises:
        HTTPException: 500 if GOOGLE_SHEET_ID is not configured.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cgo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    records = non_cargo_rows + cgo_rows

    classifications = get_root_cause_service().classify_batch(records)
    classified_count = sum(
        1 for item in classifications if item["primary_category"] != "Unknown"
    )
    return {
        "status": "success",
        "records_processed": len(records),
        "classifications": classifications[:100],
        "total_classified": classified_count,
    }
@app.get("/api/ai/root-cause/categories")
async def get_root_cause_categories():
    """List every root cause category the classifier knows about."""
    from data.root_cause_service import get_root_cause_service

    categories = get_root_cause_service().get_categories()
    return categories
@app.get("/api/ai/root-cause/stats")
async def get_root_cause_stats(source: Optional[str] = None, bypass_cache: bool = False):
    """
    Get root cause statistics from the Google Sheets data.

    Args:
        source: "NON CARGO", "CGO", or None/empty for both sheets. Matching is
            case-insensitive, and "_" or "-" are treated as spaces, so
            "non_cargo" and "NON-CARGO" select the NON CARGO sheet.
        bypass_cache: Skip cache and fetch fresh data

    Raises:
        HTTPException: 400 if ``source`` names an unknown sheet;
            500 if GOOGLE_SHEET_ID is not configured.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    # Canonicalise the sheet selector: upper-case it, treat "_"/"-" as spaces
    # and collapse whitespace. An empty result means "both sheets".
    # Previously any other spelling (e.g. "non_cargo", as accepted by the
    # /api/ai/summarize endpoints) matched neither sheet and silently produced
    # statistics over an empty dataset.
    selected = " ".join(
        (source or "").replace("_", " ").replace("-", " ").upper().split()
    )
    valid_sources = ["NON CARGO", "CGO"]
    if selected and selected not in valid_sources:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid source. Must be one of: {valid_sources}",
        )

    cache = get_cache() if not bypass_cache else None
    sheets_service = GoogleSheetsService(cache=cache)
    spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not spreadsheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    all_data = []
    # Conditional fetching based on source to reduce I/O and processing
    if selected in ("", "NON CARGO"):
        all_data.extend(
            sheets_service.fetch_sheet_data(
                spreadsheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
            )
        )
    if selected in ("", "CGO"):
        all_data.extend(
            sheets_service.fetch_sheet_data(
                spreadsheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
            )
        )

    service = get_root_cause_service()
    return service.get_statistics(all_data)
@app.post("/api/ai/root-cause/train")
async def train_root_cause_classifier(background_tasks: BackgroundTasks, bypass_cache: bool = False):
    """Train the root cause classifier on the historical sheet data.

    The sheet data is fetched during the request. Training itself is queued on
    ``background_tasks``, so the response is returned before it finishes.

    Args:
        background_tasks: FastAPI background task queue used to run training.
        bypass_cache: When true, the sheets client is created without a cache
            and fresh data is requested.

    Raises:
        HTTPException: 500 if GOOGLE_SHEET_ID is not configured.
    """
    from data.root_cause_service import get_root_cause_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cgo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    records = non_cargo_rows + cgo_rows

    trainer = get_root_cause_service()
    # Defer the expensive training step until after the response is sent.
    background_tasks.add_task(trainer.train_from_data, records)
    return {
        "status": "training_started",
        "records_fetched": len(records),
        "message": "Classification training is now running in the background. The model will be automatically updated once complete."
    }
# ============== Category Summarization Endpoints ==============
class CategorySummaryResponse(BaseModel):
    # Response envelope used as the response_model of GET /api/ai/summarize.
    status: str  # set to "success" by summarize_by_category
    category_type: str  # lower-cased category that was requested: "non_cargo", "cgo", or "all"
    summary: Dict[str, Any]  # whatever summarize_category() returned for that category
    timestamp: str  # datetime.now().isoformat(): naive server-local time, no UTC offset
@app.get("/api/ai/summarize", response_model=CategorySummaryResponse)
async def summarize_by_category(category: str = "all", bypass_cache: bool = False):
    """
    Get summarized insights for Non-cargo and/or CGO categories

    Query Parameters:
        category: "non_cargo", "cgo", or "all" (default: "all"); case-insensitive
        bypass_cache: Skip cache and fetch fresh data (default: false)

    Returns aggregated summary including:
    - Severity distribution
    - Top categories, airlines, hubs, branches
    - Key insights and recommendations
    - Common issues
    - Monthly trends
    """
    from data.category_summarization_service import get_category_summarization_service
    from data.sheets_service import GoogleSheetsService

    allowed = ["non_cargo", "cgo", "all"]
    category_key = category.lower()
    if category_key not in allowed:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {allowed}",
        )

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cgo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    # Both sheets are always fetched; each row is tagged with its origin sheet.
    for label, rows in (("NON CARGO", non_cargo_rows), ("CGO", cgo_rows)):
        for row in rows:
            row["_sheet_name"] = label
    records = non_cargo_rows + cgo_rows

    summarizer = get_category_summarization_service()
    return CategorySummaryResponse(
        status="success",
        category_type=category_key,
        summary=summarizer.summarize_category(records, category_key),
        timestamp=datetime.now().isoformat(),
    )
@app.get("/api/ai/summarize/non-cargo")
async def summarize_non_cargo(bypass_cache: bool = False):
    """Shortcut for the Non-cargo summary (same as ?category=non_cargo)."""
    summary = await summarize_by_category(bypass_cache=bypass_cache, category="non_cargo")
    return summary
@app.get("/api/ai/summarize/cgo")
async def summarize_cgo(bypass_cache: bool = False):
    """Shortcut for the CGO (Cargo) summary (same as ?category=cgo)."""
    summary = await summarize_by_category(bypass_cache=bypass_cache, category="cgo")
    return summary
@app.get("/api/ai/summarize/compare")
async def compare_categories(bypass_cache: bool = False):
    """Side-by-side view of Non-cargo and CGO (same as ?category=all)."""
    summary = await summarize_by_category(bypass_cache=bypass_cache, category="all")
    return summary
# ============== Branch Analytics Endpoints ==============
@app.get("/api/ai/branch/summary")
async def branch_analytics_summary(category_type: Optional[str] = None):
    """
    Return the branch analytics summary.

    Args:
        category_type: "landside_airside", "cgo", or None to cover both.
    """
    from data.branch_analytics_service import get_branch_analytics_service

    analytics = get_branch_analytics_service()
    summary = analytics.get_summary(category_type)
    return summary
@app.get("/api/ai/branch/ranking")
async def branch_ranking(
    category_type: Optional[str] = None,
    sort_by: str = "risk_score",
    limit: int = 20,
):
    """
    Get branch ranking

    Args:
        category_type: "landside_airside", "cgo", or None for both
        sort_by: Field to sort by (risk_score, total_issues, critical_high_count)
        limit: Maximum branches to return
    """
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    return service.get_ranking(category_type, sort_by, limit)


@app.get("/api/ai/branch/comparison")
async def branch_comparison():
    """Compare all branches across category types"""
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    return service.get_comparison()


# NOTE: this parameterised route must be registered AFTER the fixed
# /api/ai/branch/ranking and /api/ai/branch/comparison routes above.
# FastAPI (Starlette) matches routes in registration order, so when it was
# declared first, "ranking" and "comparison" were captured as branch names.
# That made those two endpoints unreachable (typically a 404 from here).
@app.get("/api/ai/branch/{branch_name}")
async def get_branch_metrics(branch_name: str, category_type: Optional[str] = None):
    """
    Get metrics for a specific branch

    Args:
        branch_name: Branch name
        category_type: "landside_airside", "cgo", or None for combined

    Raises:
        HTTPException: 404 if the analytics service returns no data for the branch.
    """
    from data.branch_analytics_service import get_branch_analytics_service

    service = get_branch_analytics_service()
    data = service.get_branch(branch_name, category_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Branch '{branch_name}' not found")
    return data
@app.post("/api/ai/branch/calculate")
async def calculate_branch_metrics(bypass_cache: bool = False):
    """Recompute per-branch metrics from the NON CARGO and CGO Google Sheets.

    Args:
        bypass_cache: When true, the sheets client is created without a cache
            and fresh data is requested.

    Raises:
        HTTPException: 500 if GOOGLE_SHEET_ID is not configured.
    """
    from data.branch_analytics_service import get_branch_analytics_service
    from data.sheets_service import GoogleSheetsService

    sheets = GoogleSheetsService(cache=None if bypass_cache else get_cache())
    sheet_id = os.getenv("GOOGLE_SHEET_ID")
    if not sheet_id:
        raise HTTPException(status_code=500, detail="GOOGLE_SHEET_ID not configured")

    non_cargo_rows = sheets.fetch_sheet_data(
        sheet_id, "NON CARGO", "A1:AA5000", bypass_cache=bypass_cache
    )
    cgo_rows = sheets.fetch_sheet_data(
        sheet_id, "CGO", "A1:Z5000", bypass_cache=bypass_cache
    )
    # Tag each row with the sheet it came from before the two sheets are merged.
    for label, rows in (("NON CARGO", non_cargo_rows), ("CGO", cgo_rows)):
        for row in rows:
            row["_sheet_name"] = label
    records = non_cargo_rows + cgo_rows

    metrics = get_branch_analytics_service().calculate_branch_metrics(records)
    return {
        "status": "success",
        "records_processed": len(records),
        "metrics": metrics,
    }
# ============== Main ==============
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; the port comes from API_PORT (default 8000).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("API_PORT", 8000)))