import time
import json
import secrets
import asyncio
import httpx
from pathlib import Path
from datetime import datetime, timedelta
from fastapi import FastAPI, Request, Header, Depends, HTTPException, status
from fastapi.responses import JSONResponse, Response, FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel
from typing import Optional, List

from config import Config
from cache_manager import cache
from user_manager import user_manager, User, AVAILABLE_BADGES
from proxy_handler import (
    proxy_media,
    proxy_live_stream_direct,
    proxy_playback_stream,
    get_live_m3u8_url
)
from utils import get_auth, get_channels, get_jst_date, fetch_epg, get_all_epg


app = FastAPI(
    title=Config.APP_NAME,
    version=Config.APP_VERSION,
    description=Config.APP_DESCRIPTION
)

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Length", "Content-Range", "Accept-Ranges", "Content-Disposition"]
)

if Config.ENABLE_GZIP:
    app.add_middleware(GZipMiddleware, minimum_size=1000)

static_path = Path(__file__).parent / "static"
if static_path.exists():
    app.mount("/static", StaticFiles(directory=str(static_path)), name="static")

# In-process store of issued admin tokens: token string -> expiry datetime.
admin_tokens = {}


def create_admin_token() -> str:
    """Issue a new random admin token that is valid for 24 hours."""
    token = secrets.token_urlsafe(32)
    admin_tokens[token] = datetime.now() + timedelta(hours=24)
    return token


def verify_admin_token(token: str) -> bool:
    """Return True when *token* is a known, unexpired admin token.

    Every expired token is purged from ``admin_tokens`` as a side effect, so
    a token that survives the purge is by construction still valid.
    """
    if not token:
        return False

    now = datetime.now()
    # Collect first, then delete: the dict must not change size while iterated.
    for stale in [t for t, exp in admin_tokens.items() if exp < now]:
        del admin_tokens[stale]

    return token in admin_tokens


def get_admin_token(authorization: Optional[str]) -> Optional[str]:
    """Extract the token from an ``Authorization`` header value.

    A ``"Bearer "`` prefix is stripped when present; otherwise the raw header
    value is returned. Returns None when the header is missing or empty.
    """
    if not authorization:
        return None
    prefix = "Bearer "
    if authorization.startswith(prefix):
        return authorization[len(prefix):]
    return authorization


def get_current_admin_token(authorization: Optional[str] = Header(None)) -> str:
    """FastAPI dependency that enforces a valid admin token.

    Raises:
        HTTPException: 401 when no token is supplied or when the token is
            unknown or expired.
    """
    token = get_admin_token(authorization)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="No token provided"
        )
    if not verify_admin_token(token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token"
        )
    return token


def _matches_config_admin(username: str, password_hash: str) -> bool:
    """Check credentials against the admin account defined in ``Config``.

    Both fields are compared in constant time so that response timing does not
    reveal how much of a guess was correct. Non-string configuration values
    never match (the same outcome a plain ``==`` against a string would give).
    """
    expected_username = Config.ADMIN_USERNAME
    expected_hash = Config.ADMIN_PASSWORD_HASH
    if not isinstance(expected_username, str) or not isinstance(expected_hash, str):
        return False

    def _as_bytes(value: str) -> bytes:
        # compare_digest rejects non-ASCII str, so compare encoded bytes;
        # "surrogatepass" keeps odd JSON escapes from raising here.
        return value.encode("utf-8", "surrogatepass")

    username_ok = secrets.compare_digest(_as_bytes(username), _as_bytes(expected_username))
    hash_ok = secrets.compare_digest(_as_bytes(password_hash), _as_bytes(expected_hash))
    return username_ok and hash_ok


class PasswordVerify(BaseModel):
    """Request body for the unified user/admin login check."""
    username: str
    password_hash: str


class AdminLogin(BaseModel):
    """Request body for the admin-panel login."""
    username: str
    password_hash: str


class CreateUserRequest(BaseModel):
    """Request body for creating a user from the admin panel."""
    username: str
    password: Optional[str] = None
    expires_days: Optional[int] = None
    notes: str = ""
    badge: Optional[str] = None
    is_admin: bool = False


class ExtendExpiryRequest(BaseModel):
    """Request body for extending a user's expiry by a number of days."""
    days: int


class SetBadgeRequest(BaseModel):
    """Request body for setting (or clearing, with None) a user's badge."""
    badge: Optional[str] = None


class UserSettings(BaseModel):
    """Per-user settings payload; every field is optional."""
    favorite_channels: Optional[List[str]] = None
    playback_history: Optional[dict] = None
    program_reminders: Optional[List[dict]] = None
    download_concurrency: Optional[int] = None
    batch_download_concurrency: Optional[int] = None
    fab_position: Optional[dict] = None
    other_settings: Optional[dict] = None


@app.middleware("http")
async def protocol_middleware(request: Request, call_next):
    """Apply ``X-Forwarded-*`` headers to the request scope.

    NOTE(review): these headers are trusted from any client; confirm the app
    is only reachable through a proxy that sets or strips them.
    NOTE(review): the source formatting was lost, so whether the host handling
    was nested under the proto check cannot be confirmed; it is treated as
    independent here.
    """
    forwarded_proto = request.headers.get('X-Forwarded-Proto', '')
    forwarded_host = request.headers.get('X-Forwarded-Host', '')
    forwarded_port = request.headers.get('X-Forwarded-Port', '')

    if forwarded_proto:
        request.scope['scheme'] = forwarded_proto

    if forwarded_host:
        port = 443 if forwarded_proto == 'https' else 80
        if forwarded_port:
            try:
                port = int(forwarded_port)
            except ValueError:
                # Malformed port header: keep the scheme-derived default.
                pass
        request.scope['server'] = (forwarded_host, port)

    return await call_next(request)


@app.middleware("http")
async def performance_middleware(request: Request, call_next):
    """Add timing, static-cache and CORS headers to every response."""
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    response.headers["X-Response-Time"] = f"{elapsed_ms}ms"

    path = request.url.path
    if path.startswith('/static/'):
        response.headers['Cache-Control'] = 'public, max-age=86400'

    if path.startswith(('/api/', '/live/', '/vod/')):
        response.headers['Access-Control-Allow-Origin'] = '*'
        response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS, DELETE'
        response.headers['Access-Control-Allow-Headers'] = 'Authorization, Content-Type, Range'

    return response


def _serve_static_page(filename: str, missing_message: str):
    """Return the static HTML file, or a JSON message when it does not exist."""
    html_path = static_path / filename
    if html_path.exists():
        return FileResponse(html_path)
    return {"message": missing_message}


@app.get("/")
async def root():
    """Serve the single-page frontend."""
    return _serve_static_page("index.html", "Frontend not found")


@app.get("/channels")
async def channels_page():
    """Frontend route; served by the same page as ``/``."""
    return await root()


@app.get("/player")
async def player_page():
    """Frontend route; served by the same page as ``/``."""
    return await root()


@app.get("/epg")
async def epg_page():
    """Frontend route; served by the same page as ``/``."""
    return await root()


@app.get("/cache")
async def cache_page():
    """Frontend route; served by the same page as ``/``."""
    return await root()


@app.get("/api-test")
async def api_test_page():
    """Frontend route; served by the same page as ``/``."""
    return await root()


@app.get("/admin")
async def admin_page():
    """Serve the admin panel page."""
    return _serve_static_page("admin.html", "Admin page not found")


@app.get("/admin/login")
async def admin_login_page():
    """Serve the admin login page."""
    return _serve_static_page("admin-login.html", "Admin login page not found")


@app.post("/api/verify-password")
async def verify_password(data: PasswordVerify):
    """Unified login check for the config-file admin and stored users.

    Always answers 200 with ``success`` True/False for credential results;
    unexpected errors yield a 500 with the error text.
    """
    try:
        # First: the administrator defined in the configuration file.
        if _matches_config_admin(data.username, data.password_hash):
            return {
                "success": True,
                "message": "Admin login successful",
                "user": {
                    "username": data.username,
                    "is_admin": True,
                    "badge": None
                }
            }

        # Second: users held by the user manager.
        if data.username and user_manager.verify_user(data.username, data.password_hash):
            user = user_manager.get_user(data.username)
            if not user:
                return {"success": False, "message": "User not found"}

            user_data = user_manager.get_user_data(data.username)
            return {
                "success": True,
                "message": "User login successful",
                "user": {
                    "username": data.username,
                    # is_admin comes from the stored user record.
                    "is_admin": user.is_admin,
                    "badge": user.badge if user.badge else None
                },
                "user_data": user_data
            }

        return {"success": False, "message": "Invalid username or password"}
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": str(e)},
            status_code=500
        )


@app.get("/api/badges")
async def get_badges():
    """Return the list of available badges (public)."""
    return {
        "success": True,
        "badges": AVAILABLE_BADGES
    }


@app.post("/api/admin/login")
async def admin_login(data: AdminLogin):
    """Authenticate the config-file admin and issue a 24-hour token."""
    try:
        if _matches_config_admin(data.username, data.password_hash):
            return {
                "success": True,
                "token": create_admin_token(),
                "message": "Login successful"
            }
        return JSONResponse(
            content={"success": False, "message": "Invalid credentials"},
            status_code=401
        )
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": str(e)},
            status_code=500
        )


@app.get("/api/admin/check")
async def admin_check(authorization: Optional[str] = Header(None)):
    """Report whether the supplied admin token is currently valid."""
    token = get_admin_token(authorization)
    if token and verify_admin_token(token):
        return {"authenticated": True}
    return JSONResponse(
        content={"authenticated": False},
        status_code=401
    )


@app.get("/api/admin/badges")
async def admin_get_badges(token: str = Depends(get_current_admin_token)):
    """Return the list of available badges (admin only)."""
    try:
        return {
            "success": True,
            "badges": AVAILABLE_BADGES
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/api/admin/stats")
async def admin_stats(token: str = Depends(get_current_admin_token)):
    """Return whatever ``user_manager.get_stats()`` reports (admin only)."""
    try:
        return user_manager.get_stats()
    except Exception as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=500
        )


@app.get("/api/admin/users")
async def admin_list_users(token: str = Depends(get_current_admin_token)):
    """List all users as dictionaries (admin only)."""
    try:
        users = user_manager.list_users()
        return {
            "success": True,
            "count": len(users),
            "users": [u.dict() for u in users]
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.post("/api/admin/users")
async def admin_create_user(data: CreateUserRequest, token: str = Depends(get_current_admin_token)):
    """Create a user (admin only).

    Responds 400 when the ``Config.MAX_USERS`` limit is reached or when the
    user manager rejects the request with ``ValueError``; 500 on other errors.
    The plain password returned by the user manager is echoed once here.
    """
    try:
        if len(user_manager.users) >= Config.MAX_USERS:
            return JSONResponse(
                content={"error": f"Maximum {Config.MAX_USERS} users allowed"},
                status_code=400
            )

        user, plain_password = user_manager.create_user(
            username=data.username,
            password=data.password,
            expires_days=data.expires_days,
            notes=data.notes,
            badge=data.badge,
            is_admin=data.is_admin
        )
        return {
            "success": True,
            "user": user.dict(),
            "password": plain_password
        }
    except ValueError as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=400
        )
    except Exception as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=500
        )


@app.delete("/api/admin/users/{username}")
async def admin_delete_user(username: str, token: str = Depends(get_current_admin_token)):
    """Delete a user and their stored settings (admin only); 404 if unknown."""
    if user_manager.delete_user(username):
        # Remove the user's settings together with the account.
        user_manager.delete_user_settings(username)
        return {"success": True, "message": f"User {username} deleted"}
    return JSONResponse(
        content={"error": "User not found"},
        status_code=404
    )


@app.post("/api/admin/users/{username}/activate")
async def admin_activate_user(username: str, token: str = Depends(get_current_admin_token)):
    """Activate a user (admin only); 404 if unknown."""
    if user_manager.activate_user(username):
        return {"success": True, "message": f"User {username} activated"}
    return JSONResponse(
        content={"error": "User not found"},
        status_code=404
    )


@app.post("/api/admin/users/{username}/deactivate")
async def admin_deactivate_user(username: str, token: str = Depends(get_current_admin_token)):
    """Deactivate a user (admin only); 404 if unknown."""
    if user_manager.deactivate_user(username):
        return {"success": True, "message": f"User {username} deactivated"}
    return JSONResponse(
        content={"error": "User not found"},
        status_code=404
    )
@app.post("/api/admin/users/{username}/extend")
async def admin_extend_expiry(username: str, data: ExtendExpiryRequest, token: str = Depends(get_current_admin_token)):
    """Extend a user's expiry by ``data.days`` days (admin only); 404 if unknown."""
    if user_manager.extend_expiry(username, data.days):
        return {
            "success": True,
            "message": f"Extended {username} expiry by {data.days} days"
        }
    return JSONResponse(
        content={"error": "User not found"},
        status_code=404
    )


@app.post("/api/admin/users/{username}/badge")
async def admin_set_badge(username: str, data: SetBadgeRequest, token: str = Depends(get_current_admin_token)):
    """Set or clear a user's badge (admin only).

    Responds 404 when the user is unknown, 400 when the user manager raises
    ``ValueError``, and 500 on other errors.
    """
    try:
        if user_manager.set_badge(username, data.badge):
            return {
                "success": True,
                "message": f"Badge updated for {username}"
            }
        return JSONResponse(
            content={"error": "User not found"},
            status_code=404
        )
    except ValueError as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=400
        )
    except Exception as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=500
        )


# ==================== User settings API ====================
# NOTE(review): the three /api/user/... routes below accept any username with
# no authentication; confirm they are only reachable by trusted callers.

@app.get("/api/user/{username}/settings")
async def get_user_settings(username: str):
    """Return the stored settings for *username*."""
    print("\n" + "=" * 80)
    print(f"📥 [API] 收到读取请求")
    print(f" URL: /api/user/{username}/settings")
    print(f" 用户名: {username}")
    print("=" * 80)

    try:
        settings = user_manager.get_user_settings(username)
        print(f"📤 [API] 返回数据: {list(settings.keys())}")
        print("=" * 80 + "\n")
        return {
            "success": True,
            "settings": settings
        }
    except Exception as e:
        print(f"❌ [API] 异常: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 80 + "\n")
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


# ==================== User data sync endpoint (internal use) ====================

class UserDataSync(BaseModel):
    """Request body for pushing a user's data dictionary to storage."""
    username: str
    data: dict


@app.post("/api/user/data/sync")
async def sync_user_data(payload: UserDataSync):
    """Persist ``payload.data`` for the user via the user manager (internal).

    Responds 404 when the user manager reports failure, 500 on exceptions.
    """
    print(f"📡 [SYNC] 收到用户数据同步请求: {payload.username}")
    print(f" 数据字段: {list(payload.data.keys())}")

    try:
        if user_manager.update_user_data(payload.username, payload.data):
            print(f"✅ [SYNC] 用户 {payload.username} 数据同步成功")
            return {
                "success": True,
                "message": "数据已实时同步到Redis"
            }

        print(f"❌ [SYNC] 用户 {payload.username} 不存在")
        return JSONResponse(
            content={"success": False, "error": "用户不存在"},
            status_code=404
        )
    except Exception as e:
        print(f"❌ [SYNC] 同步失败: {e}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


# ==================== User behaviour tracking endpoint ====================

class UserBehaviorLog(BaseModel):
    """Request body describing a single user action."""
    username: str
    action: str  # 'play', 'download', 'favorite', 'search', 'setting_change', etc.
    data: dict  # action-specific payload


# Setting keys that a 'setting_change' action is allowed to update.
_TRACKABLE_SETTING_KEYS = ('download_concurrency', 'batch_download_concurrency', 'fab_position')


@app.post("/api/user/behavior/track")
async def track_user_behavior(payload: UserBehaviorLog):
    """Record a user action and persist the resulting data change.

    Handled actions: ``play`` (prepends to playback history, capped at 100
    entries), ``favorite``, ``setting_change`` (whitelisted keys only) and
    ``reminder``. Any other action, an empty update, or a failed save yields
    a 400; an unknown user yields a 404.
    """
    print(f"📊 [BEHAVIOR] 跟踪用户行为: {payload.username} - {payload.action}")

    try:
        user_data = user_manager.get_user_data(payload.username)
        if not user_data:
            return JSONResponse(
                content={"success": False, "error": "用户不存在"},
                status_code=404
            )

        update_data = {}
        action = payload.action

        if action == 'play':
            playback_history = user_data.get('playback_history', [])
            if playback_history is None:
                # A stored null history is treated the same as a missing one.
                playback_history = []
            playback_history.insert(0, {
                'timestamp': datetime.now().isoformat(),
                'channel_id': payload.data.get('channel_id'),
                'channel_name': payload.data.get('channel_name'),
                'duration': payload.data.get('duration', 0)
            })
            # Keep only the 100 most recent entries.
            update_data['playback_history'] = playback_history[:100]
        elif action == 'favorite':
            update_data['favorite_channels'] = payload.data.get('favorite_channels', [])
        elif action == 'setting_change':
            for key, value in payload.data.items():
                if key in _TRACKABLE_SETTING_KEYS:
                    update_data[key] = value
        elif action == 'reminder':
            update_data['program_reminders'] = payload.data.get('program_reminders', [])

        if update_data and user_manager.update_user_data(payload.username, update_data):
            print(f"✅ [BEHAVIOR] 用户 {payload.username} 行为数据已实时保存")
            return {
                "success": True,
                "message": f"用户行为 '{payload.action}' 已实时保存到Redis"
            }

        return JSONResponse(
            content={"success": False, "error": "无效的行为数据"},
            status_code=400
        )
    except Exception as e:
        print(f"❌ [BEHAVIOR] 行为跟踪失败: {e}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/health")
async def health_check():
    """Report application status, configuration validity and cache statistics."""
    stats = cache.get_stats()
    is_valid, missing = Config.validate()
    storage_type = stats['storage_type']
    user_count = len(user_manager.users)

    return {
        "name": Config.APP_NAME,
        "version": Config.APP_VERSION,
        "description": Config.APP_DESCRIPTION,
        "status": "running" if is_valid else "configuration_error",
        "config_valid": is_valid,
        "missing_config": missing if not is_valid else [],
        "password_protected": user_count > 0,
        "total_users": user_count,
        "cache": {
            "storage_type": storage_type,
            "cid": stats['cid'],
            "auth": stats['auth'],
            "channels": stats['channels'],
            "streams": stats['streams'],
            "epg": stats['epg'],
            "epg_detail": stats.get('epg_detail')
        },
        "features": {
            "streaming": True,
            "download": True,
            "live_recording": True,
            "recording_mode": "Frontend Sequential Recording",
            "user_management": True,
            "admin_features": True,
            "unified_login": True,
            "cache_persistence": storage_type in ['redis', 'disk'],
            "user_settings_sync": True,
            "auto_refresh": {
                "cid": "1 day (auto refresh on expire)",
                "auth": "3 hours (auto refresh on expire or 401/403)",
                "storage": storage_type.upper()
            }
        }
    }


@app.get("/api/refresh")
async def refresh_cache(type: str = "all"):
    """Clear the named cache and, for auth/all/cid, refresh it immediately.

    ``type`` shadows the builtin but is the public query-parameter name.
    NOTE(review): this is an unauthenticated GET that changes server state.
    """
    cache.clear_cache(type)

    if type in ['auth', 'all']:
        try:
            await get_auth(force=True)
            message = f"{type.capitalize()} cache cleared and refreshed"
        except Exception as e:
            message = f"{type.capitalize()} cache cleared, but refresh failed: {str(e)}"
    elif type == 'cid':
        try:
            from utils import get_cid
            await get_cid(force=True)
            message = "CID cache cleared and refreshed"
        except Exception as e:
            message = f"CID cache cleared, but refresh failed: {str(e)}"
    else:
        message = f"{type.capitalize()} cache cleared"

    return {
        "success": True,
        "message": message
    }


def _public_base_url(request: Request) -> str:
    """Build ``scheme://host`` from the request URL for rewriting stream links."""
    return f"{request.url.scheme}://{request.url.netloc}"


@app.get("/api/list")
async def list_channels(request: Request):
    """Return the channel list with ``playUrl`` rewritten to this gateway."""
    try:
        auth = await get_auth()
        channels = await get_channels(auth)
        worker_base = _public_base_url(request)

        rewritten_channels = [
            {**ch, "playUrl": f"{worker_base}/api/live/{ch['no']}"}
            for ch in channels
        ]

        return {
            "success": True,
            "count": len(rewritten_channels),
            "channels": rewritten_channels,
            "cached": cache.get_channels() is not None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/api/epg")
async def get_epg(vid: str, date: str):
    """Return one channel's EPG for one date via ``fetch_epg``."""
    try:
        if not vid or not date:
            return JSONResponse(
                content={"success": False, "error": "Missing vid or date"},
                status_code=400
            )

        auth = await get_auth()
        # Caching, if any, is handled inside fetch_epg.
        epg_data = await fetch_epg(vid, date, auth)

        return {
            "success": True,
            "vid": vid,
            "date": date,
            "count": len(epg_data),
            "epg": epg_data,
            "cached": cache.get_epg(vid, date) is not None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/api/epg/all")
async def get_all_epg_data():
    """Return the EPG of all channels via ``get_all_epg``."""
    try:
        auth = await get_auth()
        # Caching, if any, is handled inside get_all_epg.
        all_epg = await get_all_epg(auth, force=False)

        return {
            "success": True,
            "total_channels": len(all_epg),
            "total_programs": sum(len(programs) for programs in all_epg.values()),
            "data": all_epg,
            "cached": cache.get_epg('_all_', 'full') is not None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


# Strong references to fire-and-forget tasks: the event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


@app.get("/api/epg/search")
async def search_epg(keyword: str, days: int = 30):
    """Search programme titles for *keyword* over the last *days* days.

    Only cached EPG data is searched. When no full cache exists and the
    per-day cache is sparse, a background task is started to fetch the full
    EPG so later searches are faster. Results are sorted newest first.
    """
    try:
        if not keyword:
            return JSONResponse(
                content={"success": False, "error": "Missing keyword"},
                status_code=400
            )

        auth = await get_auth()
        channels_list = await get_channels(auth)
        channel_map = {ch['id']: ch for ch in channels_list}

        now = datetime.now()
        date_list = [get_jst_date(now - timedelta(days=i)) for i in range(days + 1)]
        wanted_dates = set(date_list)  # O(1) membership inside the inner loop

        results = []
        keyword_lower = keyword.lower()
        cache_hits = 0
        cache_misses = 0

        def _title_matches(program) -> bool:
            title = program.get('title') or program.get('name') or ''
            return keyword_lower in title.lower()

        full_cache = cache.get_epg('_all_', 'full')

        if full_cache:
            # Fast path: search the full cache directly.
            for channel_id, programs in full_cache.items():
                channel_info = channel_map.get(channel_id)
                if not channel_info:
                    continue

                for program in programs:
                    program_time = program.get('time') or 0
                    program_date = get_jst_date(datetime.fromtimestamp(program_time))
                    if program_date not in wanted_dates:
                        continue
                    if _title_matches(program):
                        results.append({
                            'channel_id': channel_id,
                            'channel_name': channel_info['name'],
                            'channel_no': channel_info['no'],
                            'program': program,
                            'date': program_date
                        })
            # NOTE(review): the source formatting was lost, so the original
            # nesting of this increment is unknown; counted once per full-cache
            # search here. It only affects the reported statistics.
            cache_hits += 1
        else:
            # No full cache: search whatever per-channel/per-day entries exist
            # and do not wait for missing data.
            for channel_id, channel_info in channel_map.items():
                for date_str in date_list:
                    cached_epg = cache.get_epg(channel_id, date_str)
                    if cached_epg is None:
                        cache_misses += 1
                        continue

                    cache_hits += 1
                    for program in cached_epg:
                        if _title_matches(program):
                            results.append({
                                'channel_id': channel_id,
                                'channel_name': channel_info['name'],
                                'channel_no': channel_info['no'],
                                'program': program,
                                'date': date_str
                            })

            # Sparse cache: populate the full EPG in the background.
            if cache_hits == 0 or cache_misses > cache_hits:
                task = asyncio.create_task(background_fetch_all_epg(auth))
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)

        # Newest first; programmes without a timestamp sort as 0 rather than
        # failing the whole request.
        results.sort(key=lambda item: item['program'].get('time') or 0, reverse=True)

        lookups = cache_hits + cache_misses
        hit_percent = cache_hits * 100 // lookups if lookups > 0 else 0

        return {
            "success": True,
            "keyword": keyword,
            "days": days,
            "total": len(results),
            "results": results,
            "cache_stats": {
                "hits": cache_hits,
                "misses": cache_misses,
                "strategy": "full_cache" if full_cache else "partial_cache",
                "hit_rate": f"{hit_percent}%"
            },
            "message": "后台正在缓存数据,下次搜索会更快" if not full_cache and cache_misses > 0 else None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


async def background_fetch_all_epg(auth: dict):
    """Background task: fetch the full EPG through ``get_all_epg``.

    Failures never propagate (the triggering request has already been
    answered) but are logged so they are not lost entirely.
    """
    try:
        await get_all_epg(auth, force=False)
    except Exception as e:
        print(f"⚠️ [EPG] background full-EPG fetch failed: {e}")


@app.get("/api/live/{chid}")
async def live_stream_info(chid: str, request: Request):
    """Return channel details plus proxied and upstream live stream URLs."""
    try:
        auth = await get_auth()
        channels = await get_channels(auth)

        channel = next((ch for ch in channels if str(ch['no']) == chid), None)
        if not channel:
            return JSONResponse(
                content={
                    "success": False,
                    "error": f"Channel {chid} not found"
                },
                status_code=404
            )

        worker_base = _public_base_url(request)
        upstream_m3u8 = await get_live_m3u8_url(chid, auth)

        return {
            "success": True,
            "channel": {
                "id": channel['id'],
                "no": channel['no'],
                "name": channel['name']
            },
            "stream": {
                "m3u8": f"{worker_base}/stream/live/{chid}.m3u8",
                "direct": upstream_m3u8
            },
            "info": {
                "protocol": request.url.scheme,
                "cached": cache.get_stream(f"live_{chid}") is not None
            }
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/stream/live/{chid}.m3u8")
async def live_stream_m3u8(chid: str, request: Request):
    """Proxy the live playlist for channel *chid*."""
    return await proxy_live_stream_direct(chid, request)


@app.get("/api/playback/{path:path}")
async def playback_stream_info(path: str, request: Request):
    """Return the proxied playlist URL for a playback path."""
    try:
        # Result unused, but the call is kept: an auth failure still yields 500.
        await get_auth()
        worker_base = _public_base_url(request)

        clean_path = path.strip('/')
        # A bare identifier (no slash) is assumed to live under "query/".
        if not clean_path.startswith('query/') and '/' not in clean_path:
            clean_path = f"query/{clean_path}"

        return {
            "success": True,
            "playback": {
                "path": f"/{clean_path}",
                "m3u8": f"{worker_base}/stream/playback/{clean_path}.m3u8",
                "original_path": path
            },
            "info": {
                "protocol": request.url.scheme,
                "type": "playback"
            }
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/stream/playback/{path:path}.m3u8")
async def playback_stream_m3u8(path: str, request: Request):
    """Proxy the playback playlist for *path*."""
    return await proxy_playback_stream(path, request)


def _normalize_playback_path(raw_path: str) -> str:
    """Reduce a playback path to its bare form.

    Strips surrounding whitespace, one leading slash, a ``query/`` prefix and
    a ``.m3u8`` suffix, so differently written paths compare equal.
    """
    cleaned = raw_path.strip()
    if cleaned.startswith('/'):
        cleaned = cleaned[1:]
    if cleaned.startswith('query/'):
        cleaned = cleaned[len('query/'):]
    if cleaned.endswith('.m3u8'):
        # Slice by the suffix's real length (5); a hard-coded 6 also removed
        # the last character of the path itself.
        cleaned = cleaned[:-len('.m3u8')]
    return cleaned


def _sanitize_filename_part(text, max_length: int = 150) -> str:
    """Make *text* safe for use inside a download filename.

    Replaces characters forbidden in common filesystems with ``_``, drops
    control characters, collapses runs of ``_`` and truncates to *max_length*,
    preferring to cut at a 】/【 bracket boundary. Returns "unknown" when
    nothing is left.
    """
    import re

    cleaned = re.sub(r'[<>:"/\\|?*]', '_', str(text).strip())
    cleaned = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', cleaned)
    cleaned = re.sub(r'_+', '_', cleaned)
    cleaned = cleaned.strip('_').strip()

    if len(cleaned) > max_length:
        head = cleaned[:max_length]
        if '】' in head:
            cleaned = cleaned[:head.rfind('】') + 1]
        elif '【' in head:
            cleaned = cleaned[:head.rfind('【')]
        else:
            cleaned = head

    return cleaned if cleaned else "unknown"


@app.get("/api/download/playback/")
async def download_playback_by_path(
    request: Request,
    path: str,
    channel: str
):
    """Download a recorded programme as a single ``.ts`` attachment.

    Looks up the programme in the last 30 days of EPG to build the filename,
    resolves the upstream playlist and streams its segments in order. Segments
    that fail to download are skipped. Errors raised before streaming starts
    yield a 500 JSON response.
    """
    from datetime import timezone
    from urllib.parse import quote
    from utils import extract_playlist_url

    try:
        auth = await get_auth()
        channels = await get_channels(auth)

        target_channel = next(
            (ch for ch in channels if str(ch['no']) == str(channel)),
            None
        )
        if not target_channel:
            raise ValueError(f"频道 {channel} 不存在")

        clean_path = _normalize_playback_path(path)

        program_title = "Unknown"
        program_time = None

        JST = timezone(timedelta(hours=9))
        now_jst = datetime.now(JST)

        # Walk back day by day until a programme with a matching path is found.
        for days_ago in range(0, 30):
            check_date = (now_jst - timedelta(days=days_ago)).strftime('%Y-%m-%d')
            try:
                epg_list = await fetch_epg(target_channel['id'], check_date, auth)
                if not epg_list:
                    continue

                for prog in epg_list:
                    if not prog.get('path'):
                        continue
                    if _normalize_playback_path(prog['path']) == clean_path:
                        program_title = prog.get('title') or prog.get('name') or 'Unknown'
                        program_time = datetime.fromtimestamp(prog['time'], tz=JST)
                        break

                if program_time:
                    break
            except Exception:
                # Best effort: a failed day must not abort the whole download.
                continue

        if not program_time:
            program_time = now_jst
            program_title = f"Playback_{target_channel['name']}"

        time_str = program_time.strftime('%Y%m%d_%H%M')
        channel_name = _sanitize_filename_part(target_channel['name'])
        program_name = _sanitize_filename_part(program_title)
        filename = f"{time_str}_{channel_name}_{program_name}.ts"

        # Built from the normalised path so a caller-supplied ".m3u8" suffix
        # is not duplicated in the upstream URL.
        playback_path = f"query/{clean_path}"

        vod_host = Config.UPSTREAM_HOSTS['vod']
        access_token = quote(auth['access_token'])
        upstream_m3u8 = f"{vod_host}/{playback_path}.m3u8?type=vod&__cross_domain_user={access_token}"

        headers = {
            'Referer': Config.REQUIRED_REFERER,
            'User-Agent': 'Mozilla/5.0'
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(upstream_m3u8, headers=headers)
            if resp.status_code != 200:
                raise Exception(f"M3U8获取失败: HTTP {resp.status_code}")
            m3u8_content = resp.text

            playlist_url = extract_playlist_url(m3u8_content, upstream_m3u8)
            if not playlist_url or playlist_url == upstream_m3u8:
                playlist_content = m3u8_content
                playlist_url = upstream_m3u8
            else:
                resp = await client.get(playlist_url, headers=headers)
                if resp.status_code != 200:
                    raise Exception(f"播放列表获取失败: HTTP {resp.status_code}")
                playlist_content = resp.text

        base_url = playlist_url.rsplit('/', 1)[0]
        ts_urls = []
        for line in playlist_content.split('\n'):
            line = line.strip()
            if line and not line.startswith('#'):
                ts_urls.append(line if line.startswith('http') else f"{base_url}/{line}")

        if not ts_urls:
            raise Exception("未找到TS分段")

        async def download_concurrent():
            """Yield segment bytes in playlist order, fetched 10 at a time."""
            batch_size = 10
            async with httpx.AsyncClient(
                timeout=60.0,
                limits=httpx.Limits(max_keepalive_connections=20, max_connections=30)
            ) as segment_client:
                for start in range(0, len(ts_urls), batch_size):
                    batch = ts_urls[start:start + batch_size]
                    responses = await asyncio.gather(
                        *(segment_client.get(url, headers=headers, timeout=60.0) for url in batch),
                        return_exceptions=True
                    )
                    # Batches run sequentially and gather preserves order, so
                    # yielding per batch keeps playlist order without holding
                    # the whole recording in memory.
                    for segment in responses:
                        if isinstance(segment, BaseException):
                            continue
                        if segment.status_code == 200 and segment.content:
                            yield segment.content

        encoded_filename = quote(filename)

        return StreamingResponse(
            download_concurrent(),
            media_type="video/mp2t",
            headers={
                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
                "Cache-Control": "no-cache",
            }
        )
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.options("/live/{path:path}")
@app.options("/vod/{path:path}")
@app.options("/query/{path:path}")
@app.options("/stream/{path:path}")
@app.options("/api/{path:path}")
async def options_handler():
    """Answer CORS preflight requests for the proxied and API routes."""
    return Response(
        status_code=200,
        headers={
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Methods': 'GET, POST, OPTIONS, DELETE',
            'Access-Control-Allow-Headers': 'Authorization, Content-Type, Range',
            'Access-Control-Max-Age': '3600'
        }
    )


@app.get("/live/{path:path}")
async def proxy_live_media(path: str, request: Request):
    """Proxy a media request under ``/live/``."""
    return await proxy_media(request, f"/live/{path}")


@app.get("/vod/{path:path}")
async def proxy_vod_media(path: str, request: Request):
    """Proxy a media request under ``/vod/``."""
    return await proxy_media(request, f"/vod/{path}")


@app.get("/query/{path:path}")
async def proxy_query_media(path: str, request: Request):
    """Proxy a media request under ``/query/``."""
    return await proxy_media(request, f"/query/{path}")


@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Return a JSON body for 404 responses."""
    return JSONResponse(
        content={"error": "Not Found", "path": request.url.path},
        status_code=404
    )


@app.exception_handler(500)
async def server_error_handler(request: Request, exc):
    """Return a generic JSON body for 500 responses (no internal detail)."""
    return JSONResponse(
        content={"error": "Internal Server Error", "detail": "An error occurred"},
        status_code=500
    )


@app.on_event("startup")
async def startup_event():
    """Print a status banner and warm the CID, auth and channel caches."""
    print("=" * 60)
    print("🚀 Media Gateway 启动")
    print("=" * 60)

    # Cache status
    stats = cache.get_stats()
    print(f"📦 存储类型: {stats['storage_type'].upper()}")

    if stats['storage_type'] == 'redis':
        print(" ✅ Redis 持久化已启用")
    elif stats['storage_type'] == 'disk':
        print(f" ✅ 磁盘缓存已启用: {cache.cache_dir}")
        print(f" 📊 EPG 缓存: {stats.get('epg', 0)} 条")
    else:
        print(" ⚠️ 仅使用内存缓存(重启后丢失)")

    # User storage status
    if user_manager.redis:
        print("👥 用户数据: Redis 持久化")
    else:
        print("👥 用户数据: 内存存储")

    # Configuration validation
    is_valid, missing = Config.validate()
    if is_valid:
        print("✅ 配置验证通过")
    else:
        print(f"⚠️ 缺少配置: {', '.join(missing)}")

    # Optional warm-up; a failure here must not prevent startup.
    try:
        print("🔄 预加载数据...")
        from utils import get_cid
        await get_cid()
        auth = await get_auth()
        channels = await get_channels(auth)
        print(f" ✅ 频道列表: {len(channels)} 个")
    except Exception as e:
        print(f" ⚠️ 预加载失败: {e}")

    print("=" * 60)
    print("✅ 启动完成!")
    print("=" * 60)


@app.on_event("shutdown")
async def shutdown_event():
    """Flush the disk cache (when in use) and print a shutdown banner."""
    print("\n" + "=" * 60)
    print("🛑 Media Gateway 关闭中...")
    print("=" * 60)

    # Persist the cache
    if cache.storage_type == 'disk':
        cache._save_to_disk(force=True)
        print(f"💾 磁盘缓存已保存 ({len(cache.epg)} 条 EPG)")

    # User data: only a message is printed here; no save call is made.
    if not user_manager.redis and hasattr(user_manager, 'users'):
        print(f"💾 用户数据已保存 ({len(user_manager.users)} 个用户)")

    print("✅ 关闭完成")
    print("=" * 60)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="error"
    )