Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import numpy as np | |
| import math | |
| from flask import Flask, jsonify, send_from_directory, request | |
| from flask_cors import CORS | |
# Import the custom prediction backend module
# Optional dependency: the server still starts (with prediction routes
# degraded) when the backend module is missing or fails to import.
try:
    from prediction_backend import TrafficPredictor
except ImportError:
    # Module file not present at all.
    print("Warning: prediction_backend.py not found. Prediction features will be disabled.")
    TrafficPredictor = None
except Exception as e:
    # Module present but raised during import (its own dependencies, etc.).
    print(f"Warning: Failed to import prediction_backend: {e}")
    TrafficPredictor = None
# ==========================================
# Flask Server
# ==========================================
app = Flask(__name__, static_folder='.')
CORS(app)  # allow the browser frontend to call the API cross-origin

# Resolve all paths relative to this file so the server works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Data directory path
DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, 'data'))
# File path configurations
JSON_PATH = os.path.join(DATA_DIR, 'base2info.json')  # per-station metadata ("Base_<id>" -> {"loc": ...})
TRAFFIC_PATH = os.path.join(DATA_DIR, 'bs_record_energy_normalized_sampled.npz')  # traffic time series
SPATIAL_PATH = os.path.join(DATA_DIR, 'spatial_features.npz')  # spatial features fed to the predictor
MODEL_PATH = os.path.join(BASE_DIR, 'best_corr_model.pt')  # trained model weights
# ==========================================
# Utility Functions
# ==========================================
def calculate_std_dev(records, avg):
    """Population standard deviation of *records* around the given mean *avg*.

    Returns 0 when there are fewer than two samples (no spread to measure).
    """
    if not records or len(records) < 2:
        return 0
    squared_deviations = [(value - avg) ** 2 for value in records]
    return math.sqrt(sum(squared_deviations) / len(squared_deviations))
def calculate_stats(data_list):
    """Calculate global statistics for frontend normalization"""
    print("Calculating statistical distribution (Avg & Std)...")
    avgs, stds = [], []
    for item in data_list:
        records = item.get('bs_record', [])
        if not records:
            avgs.append(0)
            stds.append(0)
            continue
        mean = sum(records) / len(records)
        avgs.append(mean)
        # Population std-dev, inlined; defined as 0 for fewer than 2 samples.
        if len(records) < 2:
            stds.append(0)
        else:
            stds.append(math.sqrt(sum((v - mean) ** 2 for v in records) / len(records)))

    def get_percentiles(values):
        """Percentile brackets (20/40/60/80) plus min/max for visualization."""
        if not values:
            return dict.fromkeys(('min', 'max', 't1', 't2', 't3', 't4'), 0)
        values.sort()
        n = len(values)
        return {
            "min": values[0],
            "max": values[-1],
            "t1": values[int(n * 0.2)],
            "t2": values[int(n * 0.4)],
            "t3": values[int(n * 0.6)],
            "t4": values[int(n * 0.8)],
        }

    # Heights come from per-station averages, colors from stability (std).
    return get_percentiles(avgs), get_percentiles(stds)
| def _convert_numpy_type(val): | |
| if isinstance(val, np.ndarray): return val.tolist() | |
| elif isinstance(val, (np.integer, np.int64, np.int32, np.int16)): return int(val) | |
| elif isinstance(val, (np.floating, np.float64, np.float32)): return float(val) | |
| elif isinstance(val, bytes): return val.decode('utf-8') | |
| else: return val | |
| def load_and_process_data(json_path, npz_path): | |
| print(f"[DataLoader] Loading basic data...") | |
| print(f" - JSON: {json_path}") | |
| print(f" - Traffic NPZ : {npz_path}") | |
| if not os.path.exists(json_path) or not os.path.exists(npz_path): | |
| print("[DataLoader] Error: Input files not found.") | |
| return [] | |
| try: | |
| npz_data = np.load(npz_path) | |
| with open(json_path, 'r', encoding='utf-8') as f: | |
| json_map = json.load(f) | |
| except Exception as e: | |
| print(f"[DataLoader] Read error: {e}") | |
| return [] | |
| # Handle binary strings if present in NPZ | |
| raw_bs_ids = npz_data['bs_id'] | |
| bs_ids = [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in raw_bs_ids] | |
| num_stations = len(bs_ids) | |
| # Identify available time-series attributes in NPZ | |
| station_attributes = [] | |
| for key in npz_data.files: | |
| if key == 'bs_id': continue | |
| if npz_data[key].shape[0] == num_stations: | |
| station_attributes.append(key) | |
| merged_data = [] | |
| match_count = 0 | |
| for i in range(num_stations): | |
| current_id = bs_ids[i] | |
| json_key = f"Base_{current_id}" | |
| if json_key in json_map: | |
| match_count += 1 | |
| entry = { | |
| "id": current_id, | |
| "npz_index": i, # Store original index for prediction lookups | |
| "loc": json_map[json_key]["loc"] | |
| } | |
| for attr in station_attributes: | |
| val = npz_data[attr][i] | |
| entry[attr] = _convert_numpy_type(val) | |
| merged_data.append(entry) | |
| print(f"[DataLoader] Merge complete! Matched: {match_count}/{num_stations}") | |
| return merged_data | |
# ==========================================
# Initialization Sequence
# ==========================================
print("Server Initializing...")
# 1. Load basic station data for frontend display
ALL_DATA = load_and_process_data(JSON_PATH, TRAFFIC_PATH)
# Global percentile brackets used by the frontend for normalization:
# heights from per-station averages, colors from std-dev (stability).
STATS_HEIGHT = {}
STATS_COLOR = {}
if ALL_DATA:
    STATS_HEIGHT, STATS_COLOR = calculate_stats(ALL_DATA)
else:
    print("⚠️ CRITICAL WARNING: Data list is empty!")
# 2. Initialize AI Predictor with Spatial Features
predictor = None  # stays None on failure; API routes must check before use
if TrafficPredictor:
    try:
        print(f"[AI] Initializing Predictor with model: {MODEL_PATH}")
        # Initialize the predictor using the model and spatial feature files
        predictor = TrafficPredictor(
            model_path=MODEL_PATH,
            spatial_path=SPATIAL_PATH,
            traffic_path=TRAFFIC_PATH
        )
        print("[AI] Predictor loaded successfully.")
    except Exception as e:
        # Best-effort: the map/stats API keeps working without predictions.
        print(f"[AI] Failed to load predictor: {e}")
| # ========================================== | |
| # API Routes | |
| # ========================================== | |
def index():
    """Serves the main dashboard page."""
    # NOTE(review): no @app.route decorator is visible in this chunk — it was
    # presumably lost in formatting; confirm this was registered at '/'.
    return send_from_directory('.', 'index.html')

def serve_static(path):
    """Serves static assets (JS, CSS, Images)."""
    # NOTE(review): likely registered at '/<path:path>' — decorator not
    # visible here; confirm against the original file.
    return send_from_directory('.', path)
def get_station_locations():
    """Return every station's coordinates plus avg/std summaries and the
    precomputed global percentile stats for frontend normalization.

    NOTE(review): no @app.route decorator is visible here — presumably lost
    in formatting; confirm the URL this was registered under.
    """
    stations = []
    for item in ALL_DATA:
        records = item.get('bs_record', [])
        mean = sum(records) / len(records) if records else 0
        spread = calculate_std_dev(records, mean) if records else 0
        stations.append({
            "id": item['id'],
            "loc": item['loc'],
            "val_h": mean,    # drives pillar height
            "val_c": spread,  # drives pillar color (stability)
            "vals": records,
        })
    return jsonify({
        "stats_height": STATS_HEIGHT,
        "stats_color": STATS_COLOR,
        "stations": stations,
    })
def get_station_detail(station_id):
    """Return full metadata plus avg/std stats for one station; 404 if unknown.

    NOTE(review): route decorator not visible in this chunk — confirm URL.
    """
    wanted = str(station_id)
    for item in ALL_DATA:
        if str(item['id']) != wanted:
            continue
        records = item.get('bs_record', [])
        mean = sum(records) / len(records) if records else 0
        payload = item.copy()
        payload['stats'] = {"avg": mean, "std": calculate_std_dev(records, mean)}
        return jsonify(payload)
    return jsonify({"error": "Station not found"}), 404
def predict_traffic(station_id):
    """Run the ML predictor for one station, mapping its public ID to the
    internal NPZ index first.

    Returns 503 when the predictor never loaded, 404 for unknown IDs, and
    500 when the backend reports or raises an error.
    NOTE(review): route decorator not visible in this chunk — confirm URL.
    """
    if not predictor:
        return jsonify({"error": "Prediction service not available"}), 503
    try:
        sid = str(station_id)
        # Map Station ID to its internal index in the NPZ file.
        target_idx = next(
            (item.get('npz_index', -1) for item in ALL_DATA if str(item['id']) == sid),
            -1,
        )
        if target_idx == -1:
            # Fallback: the caller may have passed a raw numerical index.
            if not sid.isdigit():
                return jsonify({"error": "Station ID not found in mapping"}), 404
            target_idx = int(station_id)
        # Execute prediction through the ML backend.
        result = predictor.predict(target_idx)
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)
    except Exception as e:
        print(f"Prediction Error: {e}")
        return jsonify({"error": str(e)}), 500
# FOR ONLINE
if __name__ == '__main__':
    # 0.0.0.0 exposes the server on all interfaces; 7860 is the conventional
    # Hugging Face Spaces port (consistent with the "Spaces" residue above).
    print(f"Monitoring Data Directory: {DATA_DIR}")
    print("Server running on port 7860...")
    app.run(host='0.0.0.0', port=7860)