"""Flask backend for the base-station traffic dashboard.

Merges per-station traffic records (NPZ) with station locations (JSON),
exposes them through a small JSON API, and optionally serves traffic
predictions via ``prediction_backend.TrafficPredictor``.
"""
import os
import json
import numpy as np
import math
from flask import Flask, jsonify, send_from_directory, request
from flask_cors import CORS

# Import the custom prediction backend module. It is optional: if it cannot be
# imported, the server still runs and the prediction route answers 503.
try:
    from prediction_backend import TrafficPredictor
except ImportError as e:
    # ImportError is also raised when prediction_backend.py exists but one of
    # *its* own imports fails, so only report "not found" for the module itself.
    if isinstance(e, ModuleNotFoundError) and e.name == "prediction_backend":
        print("Warning: prediction_backend.py not found. Prediction features will be disabled.")
    else:
        print(f"Warning: Failed to import prediction_backend: {e}")
    TrafficPredictor = None
except Exception as e:
    print(f"Warning: Failed to import prediction_backend: {e}")
    TrafficPredictor = None

# ==========================================
# Flask Server
# ==========================================
app = Flask(__name__, static_folder='.')
CORS(app)

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Data directory path
DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, 'data'))

# File path configurations
JSON_PATH = os.path.join(DATA_DIR, 'base2info.json')
TRAFFIC_PATH = os.path.join(DATA_DIR, 'bs_record_energy_normalized_sampled.npz')
SPATIAL_PATH = os.path.join(DATA_DIR, 'spatial_features.npz')
MODEL_PATH = os.path.join(BASE_DIR, 'best_corr_model.pt')


# ==========================================
# Utility Functions
# ==========================================
def calculate_std_dev(records, avg):
    """Return the population standard deviation of *records* around *avg*.

    Returns 0 when there are fewer than two records.
    """
    if not records or len(records) < 2:
        return 0
    variance = sum((x - avg) ** 2 for x in records) / len(records)
    return math.sqrt(variance)


def _record_stats(records):
    """Return ``(avg, std)`` for one station's record list; ``(0, 0)`` if empty."""
    if not records:
        return 0, 0
    avg = sum(records) / len(records)
    return avg, calculate_std_dev(records, avg)


def _percentile_brackets(values):
    """Summarise *values* as min/max plus 20/40/60/80% thresholds ``t1``..``t4``.

    Thresholds index the sorted values at ``int(n * q)`` (no interpolation).
    An empty input yields all-zero brackets.
    """
    ordered = sorted(values)
    n = len(ordered)
    if n == 0:
        return {k: 0 for k in ['min', 'max', 't1', 't2', 't3', 't4']}
    return {
        "min": ordered[0],
        "max": ordered[-1],
        "t1": ordered[int(n * 0.2)],
        "t2": ordered[int(n * 0.4)],
        "t3": ordered[int(n * 0.6)],
        "t4": ordered[int(n * 0.8)],
    }


def calculate_stats(data_list):
    """Calculate global statistics for frontend normalization.

    For each station, compute the mean and standard deviation of its
    ``bs_record`` series. Summarise each distribution with
    ``_percentile_brackets``.

    Returns:
        ``(stats_h, stats_c)``: brackets of the per-station averages (pillar
        heights) and of the per-station std devs (pillar colours / stability).
    """
    print("Calculating statistical distribution (Avg & Std)...")
    avgs = []
    stds = []
    for item in data_list:
        avg, std = _record_stats(item.get('bs_record', []))
        avgs.append(avg)
        stds.append(std)

    stats_h = _percentile_brackets(avgs)  # Statistics for pillar heights
    stats_c = _percentile_brackets(stds)  # Statistics for pillar colors (stability)
    return stats_h, stats_c


def _convert_numpy_type(val):
    """Convert a NumPy value into a JSON-serialisable Python value.

    Arrays become (nested) lists, bytes are decoded as UTF-8, and NumPy
    bool/int/float scalars become their Python builtins. Anything else is
    returned unchanged.
    """
    if isinstance(val, np.ndarray):
        return val.tolist()
    if isinstance(val, bytes):  # also covers np.bytes_
        return val.decode('utf-8')
    if isinstance(val, np.bool_):  # jsonify cannot serialise np.bool_
        return bool(val)
    if isinstance(val, np.integer):
        return int(val)
    if isinstance(val, np.floating):
        return float(val)
    return val


def load_and_process_data(json_path, npz_path):
    """Merge per-station NPZ attributes with station locations from JSON.

    A station is kept when the JSON object contains the key
    ``"Base_<bs_id>"``. Each merged entry has ``id``, ``npz_index`` (its row in
    the NPZ, used for prediction lookups), ``loc`` (from the JSON), plus one key
    per NPZ array whose first dimension equals the number of stations.

    Returns:
        The list of merged entries, or ``[]`` if an input file is missing,
        unreadable, or the NPZ has no ``bs_id`` array.
    """
    print("[DataLoader] Loading basic data...")
    print(f" - JSON: {json_path}")
    print(f" - Traffic NPZ : {npz_path}")

    if not os.path.exists(json_path) or not os.path.exists(npz_path):
        print("[DataLoader] Error: Input files not found.")
        return []

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            json_map = json.load(f)
        # NpzFile is lazy: every ``npz[key]`` access re-reads (and decompresses)
        # the whole array from the archive. Materialise each array exactly once,
        # and close the archive when done.
        with np.load(npz_path) as npz_data:
            arrays = {key: npz_data[key] for key in npz_data.files}
    except Exception as e:
        print(f"[DataLoader] Read error: {e}")
        return []

    if 'bs_id' not in arrays:
        print("[DataLoader] Error: NPZ file has no 'bs_id' array.")
        return []

    # Handle binary strings if present in NPZ
    bs_ids = [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in arrays['bs_id']]
    num_stations = len(bs_ids)

    # Per-station attributes: non-scalar arrays aligned with bs_id on axis 0.
    station_attributes = {
        key: arr
        for key, arr in arrays.items()
        if key != 'bs_id' and arr.ndim > 0 and arr.shape[0] == num_stations
    }

    merged_data = []
    for i, current_id in enumerate(bs_ids):
        json_key = f"Base_{current_id}"
        if json_key not in json_map:
            continue
        entry = {
            "id": current_id,
            "npz_index": i,  # Store original index for prediction lookups
            "loc": json_map[json_key]["loc"],
        }
        for attr, arr in station_attributes.items():
            entry[attr] = _convert_numpy_type(arr[i])
        merged_data.append(entry)

    print(f"[DataLoader] Merge complete! Matched: {len(merged_data)}/{num_stations}")
    return merged_data


def _index_stations(data_list):
    """Map ``str(id)`` -> station entry, keeping the first entry for duplicate ids."""
    index = {}
    for item in data_list:
        index.setdefault(str(item['id']), item)
    return index


def _build_station_summaries(data_list):
    """Build the lightweight per-station payload for /api/stations/locations."""
    summaries = []
    for item in data_list:
        records = item.get('bs_record', [])
        avg, std = _record_stats(records)
        summaries.append({
            "id": item['id'],
            "loc": item['loc'],
            "val_h": avg,
            "val_c": std,
            "vals": records,
        })
    return summaries


# ==========================================
# Initialization Sequence
# ==========================================
print("Server Initializing...")

# 1. Load basic station data for frontend display
ALL_DATA = load_and_process_data(JSON_PATH, TRAFFIC_PATH)

STATS_HEIGHT = {}
STATS_COLOR = {}
if ALL_DATA:
    STATS_HEIGHT, STATS_COLOR = calculate_stats(ALL_DATA)
else:
    print("⚠️ CRITICAL WARNING: Data list is empty!")

# Nothing in this module mutates ALL_DATA after loading, so the id lookup table
# and the per-station summaries are computed once instead of on every request.
STATIONS_BY_ID = _index_stations(ALL_DATA)
STATION_SUMMARIES = _build_station_summaries(ALL_DATA)

# 2. Initialize AI Predictor with Spatial Features
predictor = None
if TrafficPredictor:
    try:
        print(f"[AI] Initializing Predictor with model: {MODEL_PATH}")
        # Initialize the predictor using the model and spatial feature files
        predictor = TrafficPredictor(
            model_path=MODEL_PATH,
            spatial_path=SPATIAL_PATH,
            traffic_path=TRAFFIC_PATH,
        )
        print("[AI] Predictor loaded successfully.")
    except Exception as e:
        print(f"[AI] Failed to load predictor: {e}")


# ==========================================
# API Routes
# ==========================================
@app.route('/')
def index():
    """Serves the main dashboard page."""
    return send_from_directory('.', 'index.html')


@app.route('/<path:path>')
def serve_static(path):
    """Serves static assets (JS, CSS, Images).

    NOTE(review): this serves *any* file under the application directory,
    including this server's source, the model checkpoint and data/ files.
    send_from_directory rejects paths escaping the directory, but consider
    moving frontend assets into a dedicated folder.
    """
    return send_from_directory('.', path)


@app.route('/api/stations/locations')
def get_station_locations():
    """Returns a lightweight list of station coordinates and statistical summaries."""
    return jsonify({
        "stats_height": STATS_HEIGHT,
        "stats_color": STATS_COLOR,
        "stations": STATION_SUMMARIES,
    })


@app.route('/api/stations/detail/<station_id>')
def get_station_detail(station_id):
    """Returns detailed metadata and stats for a specific station (404 if unknown)."""
    item = STATIONS_BY_ID.get(str(station_id))
    if item is None:
        return jsonify({"error": "Station not found"}), 404

    avg, std = _record_stats(item.get('bs_record', []))
    response = item.copy()  # shallow copy so the cached entry is not modified
    response['stats'] = {"avg": avg, "std": std}
    return jsonify(response)


@app.route('/api/predict/<station_id>')
def predict_traffic(station_id):
    """Triggers the ML model to predict future traffic for a specific station.

    Responses: 503 if no predictor is loaded, 404 if the id is unknown and not
    numeric, 500 if prediction fails or the result contains an ``"error"`` key.
    """
    if predictor is None:
        return jsonify({"error": "Prediction service not available"}), 503

    try:
        # Map Station ID to its internal index in the NPZ file
        item = STATIONS_BY_ID.get(str(station_id))
        target_idx = item.get('npz_index', -1) if item is not None else -1

        if target_idx == -1:
            # Fallback: a purely numeric ID is treated as a raw NPZ row index.
            # isdecimal (not isdigit) so that int() below is guaranteed to parse.
            if str(station_id).isdecimal():
                target_idx = int(station_id)
            else:
                return jsonify({"error": "Station ID not found in mapping"}), 404

        # Execute prediction through the ML backend
        result = predictor.predict(target_idx)
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)
    except Exception as e:
        print(f"Prediction Error: {e}")
        return jsonify({"error": str(e)}), 500


# FOR ONLINE
if __name__ == '__main__':
    # PORT may be overridden via the environment; 7860 remains the default.
    port = int(os.environ.get("PORT", "7860"))
    print(f"Monitoring Data Directory: {DATA_DIR}")
    print(f"Server running on port {port}...")
    app.run(host='0.0.0.0', port=port)