|
|
""" |
|
|
Hugging Face Spaces 入口點 |
|
|
基於 main.py 但針對 Hugging Face Spaces 進行優化 |
|
|
""" |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Query, Request |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from pydantic import BaseModel |
|
|
from contextlib import asynccontextmanager |
|
|
import edge_tts |
|
|
import asyncio |
|
|
import os |
|
|
import uuid |
|
|
from typing import Optional, List |
|
|
import aiofiles |
|
|
import json |
|
|
from urllib.parse import urlparse |
|
|
import tempfile |
|
|
import shutil |
|
|
|
|
|
|
|
|
# Per-process scratch directory for generated MP3 files. tempfile.mkdtemp()
# creates the directory itself, so no separate os.makedirs() call is needed.
# lifespan() removes this directory again on application shutdown.
OUTPUT_DIR = tempfile.mkdtemp(prefix="edge_tts_")
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Logs startup, hands control to the running application, and on shutdown
    deletes ``OUTPUT_DIR`` together with every generated audio file in it.
    Cleanup failures are logged, never raised.
    """
    print("===== Application Startup =====")
    try:
        yield
    finally:
        # Run the cleanup in ``finally`` so the temporary directory is also
        # removed when the lifespan context is exited with an exception.
        print("===== Application Shutdown =====")
        try:
            if os.path.exists(OUTPUT_DIR):
                shutil.rmtree(OUTPUT_DIR)
                print("✅ 清理臨時文件完成")
        except Exception as e:
            # Best-effort cleanup: shutting down must not fail because of it.
            print(f"⚠️ 清理臨時文件時發生錯誤: {e}")
|
|
|
|
|
# The ASGI application; ``lifespan`` handles startup logging and temp-dir cleanup.
app = FastAPI(
    lifespan=lifespan,
    title="Edge TTS API",
    version="1.0.0",
    description="A web service for text-to-speech using Microsoft Edge TTS",
)

# CORS is fully open at the middleware level; per-request source checks are
# done by _is_origin_allowed() inside the TTS endpoints.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS setup — confirm that credentialed cross-origin requests are
# actually required.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
|
|
|
# Web origins that may call the TTS endpoints. _is_origin_allowed() checks the
# request's Origin and Referer headers against these entries.
ALLOWED_ORIGINS = [
    "https://www.dfes.ntpc.edu.tw",
    "https://demo.dfes.ntpc.edu.tw",
    "https://dfes.great-site.net",
    "https://10-241-216-5.dfes.direct.quickconnect.to",
]

# Google Apps Script project ids whose sandbox hosts
# ("n-<project_id>-<suffix>-script.googleusercontent.com") are accepted.
ALLOWED_GAS_PROJECT_IDS = [
    "m2azgihk7rtcqqaxywntv75at3thf3bgps4xdlq",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Expose the local "static" path under /static, but only when it exists, so
# the app still starts in deployments that ship without static assets.
if os.path.exists("static"):
    app.mount(
        "/static",
        StaticFiles(directory="static"),
        name="static",
    )
|
|
|
|
|
# Request body for POST /tts. Each field except ``text`` has a default; the
# values are passed unchanged to edge_tts.Communicate by text_to_speech().
class TTSRequest(BaseModel):
    # Text to synthesize (required).
    text: str
    # Edge TTS voice name; defaults to a zh-TW voice.
    voice: Optional[str] = "zh-TW-HsiaoChenNeural"
    # Speaking-rate adjustment string, e.g. "+0%".
    rate: Optional[str] = "+0%"
    # Volume adjustment string, e.g. "+0%".
    volume: Optional[str] = "+0%"
    # Pitch adjustment string, e.g. "+0Hz".
    pitch: Optional[str] = "+0Hz"
|
|
|
|
|
# Response model for POST /tts.
class TTSResponse(BaseModel):
    # True when audio was generated.
    success: bool
    # Human-readable status message.
    message: str
    # Relative URL of the generated MP3 (set on success).
    audio_url: Optional[str] = None
    # Error detail (set on failure, e.g. "forbidden" or the exception text).
    error: Optional[str] = None
|
|
|
|
|
|
|
|
def _url_matches_allowed_origin(url: str, allowed_origin: str) -> bool:
    """Return True when *url* has the same scheme and hostname as *allowed_origin*.

    This replaces a plain ``str.startswith`` prefix test, which also accepted
    look-alike hosts such as ``https://www.dfes.ntpc.edu.tw.evil.example`` or
    ``https://www.dfes.ntpc.edu.tw@evil.example``. A port, path, query or
    fragment on *url* is still accepted, as it was with the prefix test.
    """
    try:
        candidate = urlparse(url)
        allowed = urlparse(allowed_origin)
    except ValueError:
        # e.g. malformed IPv6 literal in the header value
        return False
    if not candidate.hostname:
        # e.g. the opaque Origin value "null"
        return False
    return candidate.scheme == allowed.scheme and candidate.hostname == allowed.hostname


def _is_origin_allowed(request: Request) -> bool:
    """Decide whether the request's source is allowed to use the TTS API.

    Checks, in order:
      1. the ``Origin`` header against ALLOWED_ORIGINS (scheme + host), then
         against the allowed Google Apps Script projects;
      2. the ``Referer`` header the same way;
      3. a ``huggingface`` marker in the ``User-Agent`` header.

    Returns:
        True if any check passes, otherwise False.

    NOTE(review): all three headers are client-supplied, so non-browser
    clients can set them freely; this is a browser-origin filter, not
    authentication.
    """
    origin = request.headers.get("origin")
    referer = request.headers.get("referer")
    user_agent = request.headers.get("user-agent", "")

    print(f"請求來源檢查 - Origin: {origin}, Referer: {referer}, User-Agent: {user_agent}")

    if origin:
        for allowed_origin in ALLOWED_ORIGINS:
            if _url_matches_allowed_origin(origin, allowed_origin):
                print(f"✅ Origin 匹配: {origin} 匹配 {allowed_origin}")
                return True
        if _is_gas_project_allowed(origin):
            return True

    if referer:
        for allowed_origin in ALLOWED_ORIGINS:
            if _url_matches_allowed_origin(referer, allowed_origin):
                print(f"✅ Referer 匹配: {referer} 匹配 {allowed_origin}")
                return True
        if _is_gas_project_allowed(referer):
            return True

    if "huggingface" in user_agent.lower():
        print("✅ 檢測到 Hugging Face Spaces 環境")
        return True

    print(f"❌ 請求被拒絕 - Origin: {origin}, Referer: {referer}")
    return False
|
|
|
|
|
def _is_gas_project_allowed(url: str) -> bool:
    """Return True if *url* is a Google Apps Script sandbox URL of an allowed project.

    The host must look like
    ``n-<project_id>-<suffix>-script.googleusercontent.com`` over https; the
    project id is everything between the ``n-`` prefix and the last ``-``
    before the ``-script.googleusercontent.com`` suffix. It must appear in
    ALLOWED_GAS_PROJECT_IDS.

    Only the scheme and hostname are inspected, so both a bare ``Origin`` and a
    ``Referer`` that carries a path or query are recognised.
    """
    if not url:
        return False

    try:
        parsed = urlparse(url)
    except ValueError as e:
        print(f"⚠️ 解析 GAS 網址時發生錯誤: {e}")
        return False

    host_prefix = "n-"
    host_suffix = "-script.googleusercontent.com"
    host = parsed.hostname or ""
    if parsed.scheme != "https" or not host.startswith(host_prefix) or not host.endswith(host_suffix):
        return False

    # "<project_id>-<suffix>" between the fixed prefix and suffix.
    project_part = host[len(host_prefix):-len(host_suffix)]
    last_dash_index = project_part.rfind("-")
    if last_dash_index == -1:
        return False

    project_id = project_part[:last_dash_index]
    if project_id in ALLOWED_GAS_PROJECT_IDS:
        print(f"✅ GAS 專案匹配: {project_id} 在允許清單中")
        return True

    print(f"❌ GAS 專案不匹配: {project_id} 不在允許清單中")
    return False
|
|
|
|
|
@app.get("/")
async def root():
    """Root path: describe the service and list the available endpoints."""
    endpoint_summary = {
        "GET /voices": "獲取所有可用語音",
        "POST /tts": "文字轉語音",
        "GET /tts": "文字轉語音 (GET 方法)",
        "GET /health": "健康檢查",
        "GET /allowed-origins": "查看允許的來源網址",
        "GET /debug-request": "調試請求信息",
    }
    service_info = {
        "message": "Edge TTS API Service - Hugging Face Spaces",
        "version": "1.0.0",
    }
    service_info["endpoints"] = endpoint_summary
    service_info["note"] = "此服務部署在 Hugging Face Spaces 上,僅允許特定來源網址訪問"
    return service_info
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health-check endpoint: always reports a static healthy status."""
    status_payload = dict(
        status="healthy",
        service="edge-tts-api",
        platform="huggingface-spaces",
    )
    return status_payload
|
|
|
|
|
@app.get("/allowed-origins")
async def get_allowed_origins():
    """Return the configured origin and GAS-project allowlists (for admins)."""
    total_entries = len(ALLOWED_ORIGINS) + len(ALLOWED_GAS_PROJECT_IDS)
    summary = {}
    summary["allowed_origins"] = ALLOWED_ORIGINS
    summary["allowed_gas_project_ids"] = ALLOWED_GAS_PROJECT_IDS
    summary["count"] = total_entries
    summary["description"] = "允許的來源網址清單,支援擴充匹配。GAS 專案 ID 會自動匹配所有使用者的網址變體。"
    return summary
|
|
|
|
|
@app.get("/debug-request")
async def debug_request(request: Request):
    """Debug endpoint: echo back details of the incoming request.

    The response includes all request headers, the Origin/Referer values, the
    result of the allowlist check, the client IP, and GAS parsing details for
    whichever of Origin/Referer is present.
    """
    header_map = dict(request.headers)
    origin_value = header_map.get("origin")
    referer_value = header_map.get("referer")

    # Only the headers that are actually present (non-empty) are analysed.
    gas_project_info = {
        label: _extract_gas_project_id(value)
        for label, value in (("origin", origin_value), ("referer", referer_value))
        if value
    }

    request_flags = {
        "has_origin": bool(origin_value),
        "has_referer": bool(referer_value),
        "is_huggingface": "huggingface" in header_map.get("user-agent", "").lower(),
        "api_key_provided": "x-api-key" in header_map,
    }

    return {
        "method": request.method,
        "url": str(request.url),
        "headers": header_map,
        "origin": origin_value,
        "referer": referer_value,
        "user_agent": header_map.get("user-agent"),
        "is_allowed": _is_origin_allowed(request),
        "client_ip": request.client.host if request.client else "unknown",
        "gas_project_info": gas_project_info,
        "allowed_gas_project_ids": ALLOWED_GAS_PROJECT_IDS,
        "request_info": request_flags,
    }
|
|
|
|
|
def _extract_gas_project_id(url: str) -> dict:
    """Describe how *url* parses as a Google Apps Script sandbox URL (debug helper).

    Only the scheme and hostname are inspected, so a ``Referer`` that carries a
    path or query is handled the same as a bare ``Origin``.

    Returns:
        ``{"is_gas": False}`` when *url* is empty or its host is not an https
        ``n-<project_id>-<suffix>-script.googleusercontent.com`` host.
        Otherwise ``is_gas`` is True and the dict also holds ``project_id``,
        ``user_id`` (the segment after the last dash) and ``is_allowed``
        (membership in ALLOWED_GAS_PROJECT_IDS), or an ``error`` entry when the
        host cannot be split.
    """
    if not url:
        return {"is_gas": False}

    try:
        parsed = urlparse(url)
    except ValueError as e:
        return {"is_gas": False, "error": str(e)}

    host_prefix = "n-"
    host_suffix = "-script.googleusercontent.com"
    host = parsed.hostname or ""
    if parsed.scheme != "https" or not host.startswith(host_prefix) or not host.endswith(host_suffix):
        return {"is_gas": False}

    # "<project_id>-<suffix>" between the fixed prefix and suffix.
    project_part = host[len(host_prefix):-len(host_suffix)]
    last_dash_index = project_part.rfind("-")
    if last_dash_index == -1:
        return {"is_gas": True, "project_id": None, "user_id": None, "error": "無法解析專案 ID"}

    project_id = project_part[:last_dash_index]
    return {
        "is_gas": True,
        "project_id": project_id,
        "user_id": project_part[last_dash_index + 1:],
        "is_allowed": project_id in ALLOWED_GAS_PROJECT_IDS,
    }
|
|
|
|
|
|
|
|
@app.get("/voices")
async def get_voices():
    """List every voice reported by ``edge_tts.list_voices()``.

    Raises:
        HTTPException: 500 if fetching the voice list fails.
    """
    try:
        voice_list = await edge_tts.list_voices()
        payload = {"success": True, "voices": voice_list, "count": len(voice_list)}
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"獲取語音列表失敗: {str(err)}")
    return payload
|
|
|
|
|
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, http_request: Request):
    """Text-to-speech API (POST).

    Synthesizes ``request.text`` into an MP3 inside OUTPUT_DIR and returns its
    relative URL (``/audio/<id>.mp3``).

    Errors are reported in the body rather than via the status code:
    ``success=False`` with ``error="forbidden"`` when the caller's origin is
    not allowed, or with the exception text when synthesis fails. On failure,
    any partially written audio file is deleted so it does not accumulate in
    OUTPUT_DIR.
    """
    output_file = None
    try:
        if not _is_origin_allowed(http_request):
            return TTSResponse(
                success=False,
                message="來源網站未被允許使用此 API",
                error="forbidden",
            )

        file_id = str(uuid.uuid4())
        output_file = os.path.join(OUTPUT_DIR, f"{file_id}.mp3")

        communicate = edge_tts.Communicate(
            text=request.text,
            voice=request.voice,
            rate=request.rate,
            volume=request.volume,
            pitch=request.pitch,
        )
        await communicate.save(output_file)

        if not os.path.exists(output_file):
            raise RuntimeError("語音文件生成失敗")

        return TTSResponse(
            success=True,
            message="語音生成成功",
            audio_url=f"/audio/{file_id}.mp3",
        )

    except Exception as e:
        if output_file is not None:
            try:
                os.remove(output_file)
            except OSError:
                # Best-effort: the file may never have been created.
                pass
        return TTSResponse(
            success=False,
            message="語音生成失敗",
            error=str(e),
        )
|
|
|
|
|
@app.get("/tts")
async def text_to_speech_get(
    text: str = Query(..., description="要轉換的文字"),
    voice: str = Query("zh-TW-HsiaoChenNeural", description="語音名稱"),
    rate: str = Query("+0%", description="語速調整"),
    volume: str = Query("+0%", description="音量調整"),
    pitch: str = Query("+0Hz", description="音調調整"),
    http_request: Request = None
):
    """Text-to-speech API (GET): synthesize *text* and return the MP3 directly.

    Raises:
        HTTPException: 403 when the caller's origin is not allowed; 500 when
            synthesis fails or produces no file.
    """
    output_file = None
    try:
        print(f"TTS 請求: text={text}, voice={voice}, rate={rate}, volume={volume}, pitch={pitch}")

        if http_request and not _is_origin_allowed(http_request):
            raise HTTPException(
                status_code=403,
                detail="來源網站未被允許使用此 API",
            )

        file_id = str(uuid.uuid4())
        output_file = os.path.join(OUTPUT_DIR, f"{file_id}.mp3")
        print(f"輸出文件路徑: {output_file}")

        communicate = edge_tts.Communicate(
            text=text,
            voice=voice,
            rate=rate,
            volume=volume,
            pitch=pitch,
        )

        print("開始生成語音文件...")
        await communicate.save(output_file)
        print("語音文件生成完成")

        if not os.path.exists(output_file):
            print(f"文件不存在: {output_file}")
            raise HTTPException(status_code=500, detail="語音文件生成失敗")

        file_size = os.path.getsize(output_file)
        print(f"文件大小: {file_size} bytes")

        return FileResponse(
            output_file,
            media_type="audio/mpeg",
            filename=f"tts_{file_id}.mp3",
        )

    except HTTPException:
        # The 403/500 raised above already carry the intended status and
        # detail. Without this clause the generic handler below would catch
        # them and re-wrap them as a generic 500 (turning a 403 into a 500).
        raise
    except Exception as e:
        print(f"TTS 錯誤: {str(e)}")
        import traceback
        traceback.print_exc()
        if output_file is not None:
            try:
                os.remove(output_file)
            except OSError:
                # Best-effort: the file may never have been created.
                pass
        raise HTTPException(status_code=500, detail=f"語音生成失敗: {str(e)}") from e
|
|
|
|
|
@app.get("/audio/{file_id}.mp3")
async def get_audio_file(file_id: str):
    """Serve a previously generated MP3 from OUTPUT_DIR.

    Raises:
        HTTPException: 404 when *file_id* is not a canonical UUID string or no
            such file exists.
    """
    # Audio files are only ever written under str(uuid.uuid4()) names, so any
    # other id cannot name a real file. Rejecting it up front also keeps
    # arbitrary user text out of the filesystem path.
    try:
        is_valid_id = str(uuid.UUID(file_id)) == file_id
    except ValueError:
        is_valid_id = False

    file_path = os.path.join(OUTPUT_DIR, f"{file_id}.mp3")
    if not is_valid_id or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="音頻文件不存在")

    return FileResponse(
        file_path,
        media_type="audio/mpeg",
        filename=f"tts_{file_id}.mp3",
    )
|
|
|
|
|
@app.delete("/audio/{file_id}.mp3")
async def delete_audio_file(file_id: str):
    """Delete a previously generated MP3 from OUTPUT_DIR.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 404 when *file_id* is not a canonical UUID string or the
            file does not exist; 500 when removal fails for another reason.
    """
    # Only str(uuid.uuid4()) names are ever written, so anything else cannot
    # name a real file and must not reach the filesystem path.
    try:
        is_valid_id = str(uuid.UUID(file_id)) == file_id
    except ValueError:
        is_valid_id = False

    file_path = os.path.join(OUTPUT_DIR, f"{file_id}.mp3")
    if not is_valid_id or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="音頻文件不存在")

    try:
        os.remove(file_path)
    except FileNotFoundError:
        # The file vanished between the check above and the removal.
        raise HTTPException(status_code=404, detail="音頻文件不存在") from None
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"文件刪除失敗: {str(e)}") from e

    return {"success": True, "message": "文件刪除成功"}
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local entry point: serve the app on all interfaces.
    # NOTE(review): port 7860 presumably matches the port the Space exposes — confirm.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")