Spaces:
Sleeping
Sleeping
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| import json | |
| import asyncio | |
| import logging | |
| from scraper import NovelCoolScraper | |
| from tts import TTSEngine | |
| import traceback | |
| from contextlib import asynccontextmanager | |
| import time | |
# Configure root logging at INFO (a no-op if the root logger already has
# handlers), then create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create and tear down the shared services stored on ``app.state``.

    Startup builds the TTS engine, the NovelCool scraper and an empty
    novel-index cache. If the TTS engine fails to construct,
    ``app.state.tts`` is set to ``None`` (the error is logged with its
    traceback) so the rest of the app still starts and handlers can report
    that TTS is unavailable. Shutdown resets the same attributes to ``None``.

    Decorated with ``asynccontextmanager`` because Starlette deprecates passing
    a bare async-generator function as ``lifespan``.
    """
    # Startup
    try:
        logger.info("Initializing TTS Engine...")
        try:
            import onnxruntime as ort

            logger.info(f"ONNX Runtime providers: {ort.get_available_providers()}")
        except Exception:
            # Provider listing is diagnostic only; failures are logged at
            # debug level and otherwise ignored.
            logger.debug("Could not query ONNX Runtime providers", exc_info=True)
        app.state.tts = TTSEngine()
        logger.info("TTS Engine initialized.")
    except Exception as e:
        # logger.exception keeps the traceback the plain error log dropped.
        logger.exception(f"Failed to initialize TTS Engine: {e}")
        app.state.tts = None
    app.state.scraper = NovelCoolScraper()
    app.state.novel_index_cache = {}
    try:
        yield
    finally:
        # Shutdown: run even if the server leaves the lifespan via an exception.
        app.state.tts = None
        app.state.scraper = None
        app.state.novel_index_cache = None
# Cross-origin policy for the HTTP API: any origin, method and header is
# accepted, but credentialed (cookie-bearing) cross-origin requests are not.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(lifespan=lifespan)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
| # --------------------------------------------------------------------------- | |
| # HuggingFace Space landing page & info endpoint | |
| # --------------------------------------------------------------------------- | |
def _space_runtime_info(request: Request) -> dict:
    """Collect status and connection details for the landing page and ``/info``.

    The advertised WebSocket URLs use the public scheme and host: the
    ``X-Forwarded-Proto`` / ``X-Forwarded-Host`` headers when present, else
    the ``Host`` header / request URL. Chained proxies may append to those
    headers as a comma-separated list, so only the first (client-facing) entry
    is used; otherwise e.g. ``"https,http"`` would not match ``"https"`` and
    the page would advertise ``ws://`` instead of ``wss://``.

    Note: ``frontend_base_url`` / ``frontend_ws_url`` embed client-controlled
    header values; escape them before rendering as HTML.
    """
    tts = getattr(app.state, "tts", None)
    voices = []
    if tts:
        try:
            # ``or []`` tolerates a None/falsy return so len() below is safe.
            voices = tts.list_voices() or []
        except Exception:
            # Voice listing is best-effort for this status report.
            voices = []

    def first_value(raw):
        # First entry of a possibly comma-separated forwarded header.
        return raw.split(",", 1)[0].strip() if raw else raw

    scheme = first_value(request.headers.get("x-forwarded-proto", request.url.scheme)) or "https"
    scheme = scheme.lower()
    host = (
        first_value(
            request.headers.get("x-forwarded-host", request.headers.get("host", request.url.netloc))
        )
        or request.url.netloc
    )
    ws_scheme = "wss" if scheme == "https" else "ws"
    return {
        "name": "CoreReader Backend",
        "status": "running",
        "tts_ready": tts is not None,
        "voice_count": len(voices),
        "sample_rate": getattr(tts, "sample_rate", None) if tts else None,
        "model_path": getattr(tts, "model_path", None) if tts else None,
        "voices_path": getattr(tts, "voices_path", None) if tts else None,
        "endpoints": {
            "health": "/health",
            "voices": "/voices",
            "novel_index": "/novel_index?url=<novel_url>",
            "novel_details": "/novel_details?url=<novel_url>",
            "novel_meta": "/novel_meta?url=<novel_url>",
            "novel_chapter": "/novel_chapter?url=<novel_url>&n=<chapter_number>",
            "websocket": "/ws",
            "openapi_docs": "/docs",
            "openapi_json": "/openapi.json",
        },
        "frontend_base_url": f"{ws_scheme}://{host}",
        "frontend_ws_url": f"{ws_scheme}://{host}/ws",
    }
async def root(request: Request):
    """Render the human-readable HTML status page for the Space.

    Every scalar value from ``_space_runtime_info`` is HTML-escaped before
    interpolation. The frontend URLs are built from client-supplied ``Host`` /
    ``X-Forwarded-Host`` headers, so inserting them raw would allow markup
    injection. The ``<novel_url>``-style placeholders are written as entities
    so browsers display them instead of parsing them as unknown tags.
    """
    import html as _html

    info = _space_runtime_info(request)
    esc = {
        key: _html.escape(str(value))
        for key, value in info.items()
        if not isinstance(value, dict)
    }
    page = f"""
<!doctype html>
<html>
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>CoreReader Backend</title>
    <style>
      body {{ font-family: Arial, sans-serif; margin: 2rem; line-height: 1.45; }}
      code {{ background: #f5f5f5; padding: 0.15rem 0.35rem; border-radius: 4px; }}
      .ok {{ color: #0b7a2a; }}
      .card {{ border: 1px solid #ddd; border-radius: 10px; padding: 1rem; max-width: 880px; }}
    </style>
  </head>
  <body>
    <div class="card">
      <h1>CoreReader Backend</h1>
      <p class="ok"><strong>Status:</strong> {esc['status']}</p>
      <p><strong>TTS Ready:</strong> {esc['tts_ready']}</p>
      <p><strong>Voices Loaded:</strong> {esc['voice_count']}</p>
      <p><strong>Sample Rate:</strong> {esc['sample_rate']}</p>
      <p><strong>Model:</strong> <code>{esc['model_path']}</code></p>
      <p><strong>Voices File:</strong> <code>{esc['voices_path']}</code></p>
      <h3>Use this in Frontend</h3>
      <p><strong>WebSocket base URL:</strong> <code>{esc['frontend_base_url']}</code></p>
      <p><strong>WebSocket endpoint:</strong> <code>{esc['frontend_ws_url']}</code></p>
      <h3>API Endpoints</h3>
      <ul>
        <li><code>/health</code></li>
        <li><code>/voices</code></li>
        <li><code>/novel_index?url=&lt;novel_url&gt;</code></li>
        <li><code>/novel_details?url=&lt;novel_url&gt;</code></li>
        <li><code>/novel_meta?url=&lt;novel_url&gt;</code></li>
        <li><code>/novel_chapter?url=&lt;novel_url&gt;&amp;n=&lt;chapter_number&gt;</code></li>
        <li><code>/ws</code> (WebSocket)</li>
        <li><code>/docs</code> (interactive API docs)</li>
        <li><code>/info</code> (JSON details)</li>
      </ul>
    </div>
  </body>
</html>
"""
    return HTMLResponse(content=page)
async def info(request: Request):
    """Return the same runtime details as the landing page, as JSON."""
    payload = _space_runtime_info(request)
    return payload
async def health():
    """Report liveness plus whether the TTS engine loaded.

    Uses ``getattr`` on ``app.state`` (as ``_space_runtime_info`` does) so the
    probe answers ``tts_ready: False`` instead of raising AttributeError when
    the lifespan has not populated ``app.state.tts``.
    """
    return {"ok": True, "tts_ready": getattr(app.state, "tts", None) is not None}
async def voices():
    """List the voices offered by the loaded TTS engine.

    Returns ``{"voices": []}`` plus an ``error`` message when no engine is
    available. This includes the case where ``app.state.tts`` was never set,
    which is read via ``getattr`` for consistency with
    ``_space_runtime_info``.
    """
    tts = getattr(app.state, "tts", None)
    if not tts:
        return {"voices": [], "error": "TTS Engine not initialized"}
    return {"voices": tts.list_voices()}
async def novel_index(url: str):
    """Scrape a novel's chapter list fresh (uncached) and return it.

    An empty ``url`` yields an empty list with an ``error`` message instead
    of a scrape.
    """
    if url:
        return {"chapters": await app.state.scraper.scrape_novel_index(url)}
    return {"chapters": [], "error": "url is required"}
async def novel_details(url: str):
    """Return whatever the scraper reports for a novel's details page.

    An empty ``url`` yields ``title``/``cover_url`` of ``None`` plus an
    ``error`` message instead of a scrape.
    """
    if url:
        return await app.state.scraper.scrape_novel_details(url)
    return {"title": None, "cover_url": None, "error": "url is required"}
async def _get_cached_novel_index(novel_url: str, *, ttl_s: float = 30 * 60, max_entries: int = 256):
    """Return the chapter list for ``novel_url``, scraping at most once per TTL.

    Entries live in ``app.state.novel_index_cache`` as
    ``{url: {"ts": <time.monotonic()>, "chapters": [...]}}``.

    Args:
        novel_url: Novel index page URL; must be non-empty.
        ttl_s: Seconds a cached index stays fresh (default 30 minutes).
        max_entries: Upper bound on cached novels. Whenever a new index is
            stored, expired entries are dropped, then the oldest ones if the
            cache is still over this size. This keeps client-supplied URLs
            from growing the cache without limit.

    Returns:
        The chapter list; ``[]`` if the scraper returned nothing. Empty
        results are not cached, so an empty (possibly failed) scrape is
        retried on the next call instead of being served for the whole TTL.

    Raises:
        HTTPException: 400 if ``novel_url`` is empty.
    """
    if not novel_url:
        raise HTTPException(status_code=400, detail="url is required")

    cache = getattr(app.state, "novel_index_cache", None)
    if cache is None:
        # Missing, or reset to None by the lifespan shutdown.
        cache = {}
        app.state.novel_index_cache = cache

    entry = cache.get(novel_url)
    if entry is not None and time.monotonic() - float(entry.get("ts", 0.0)) < ttl_s:
        return entry.get("chapters") or []

    chapters = await app.state.scraper.scrape_novel_index(novel_url) or []
    if not chapters:
        # Don't pin an empty result for the whole TTL; drop any stale entry too.
        cache.pop(novel_url, None)
        return chapters

    # Timestamp after the scrape so the TTL counts from when data was fetched.
    now = time.monotonic()
    cache[novel_url] = {"ts": now, "chapters": chapters}

    # Bound memory: evict expired entries, then the oldest if still too many.
    expired = [key for key, val in cache.items() if now - float(val.get("ts", 0.0)) >= ttl_s]
    for key in expired:
        del cache[key]
    overflow = len(cache) - max(1, max_entries)
    if overflow > 0:
        oldest = sorted(cache, key=lambda key: float(cache[key].get("ts", 0.0)))[:overflow]
        for key in oldest:
            del cache[key]
    return chapters
async def novel_meta(url: str):
    """Return ``{"count": N}`` for a novel, using the cached chapter index.

    ``N`` is the highest positive integer chapter number (``"n"``) found in
    the index; when no entry carries one, it is the length of the index.
    """
    chapters = await _get_cached_novel_index(url)
    numbers = [entry.get("n") for entry in chapters if isinstance(entry, dict)]
    highest = max((num for num in numbers if isinstance(num, int) and num > 0), default=0)
    return {"count": highest if highest > 0 else len(chapters)}
async def novel_chapter(url: str, n: int):
    """Resolve chapter ``n`` of a novel to its title and URL.

    The first index entry whose parsed chapter number equals ``n`` is used.
    If none matches, the entry at list position ``n - 1`` is used instead
    (the older positional behavior), or an empty mapping if that position
    does not exist.

    Raises:
        HTTPException: 400 if ``n`` is outside ``1..limit``. ``limit`` is the
            highest chapter number seen while scanning (up to the match), or
            the index length when no numbered entry was seen.
    """
    chapters = await _get_cached_novel_index(url)

    # Single pass: track the highest number seen so far, stop at the first match.
    highest = 0
    match = None
    for entry in chapters:
        if not isinstance(entry, dict):
            continue
        number = entry.get("n")
        if not isinstance(number, int):
            continue
        if number > highest:
            highest = number
        if number == n:
            match = entry
            break

    upper = highest if highest > 0 else len(chapters)
    if not 1 <= n <= upper:
        raise HTTPException(status_code=400, detail=f"chapter n must be between 1 and {upper}")

    if match is None:
        # Fallback: old positional behavior.
        match = chapters[n - 1] if n - 1 < len(chapters) else {}

    return {"n": n, "title": match.get("title"), "url": match.get("url")}
def _ws_peek_command(payload):
    """Parse a raw client text frame and return ``(is_object, command)``.

    ``is_object`` is False for invalid JSON or JSON that is not an object;
    ``command`` is the object's ``"command"`` value (``None`` if absent).
    """
    try:
        msg = json.loads(payload)
    except json.JSONDecodeError:
        return False, None
    if not isinstance(msg, dict):
        return False, None
    return True, msg.get("command")


def _ws_resolve_voice(voice):
    """Return ``voice`` if the loaded voice pack has it, else the pack's first voice.

    Best effort: if listing or checking voices fails, the requested voice is
    returned unchanged.
    """
    try:
        available = app.state.tts.list_voices()
        if available and voice not in available:
            return available[0]
    except Exception:
        logger.debug("Could not validate TTS voice", exc_info=True)
    return voice


async def _ws_handle_scrape(websocket: WebSocket, message: dict) -> None:
    """Handle ``scrape``: reply with ``scrape_result`` or an ``error`` message."""
    url = message.get("url")
    if not url:
        await websocket.send_json({"error": "URL is required"})
        return
    logger.info(f"Scraping URL: {url}")
    try:
        result = await app.state.scraper.scrape_chapter(url)
        await websocket.send_json({"type": "scrape_result", "data": result})
    except Exception as e:
        logger.error(f"Scrape error: {e}")
        await websocket.send_json({"type": "error", "message": str(e)})


async def _ws_handle_tts(websocket: WebSocket, message: dict, cancel_event: asyncio.Event) -> None:
    """Handle ``tts``: stream audio for ``text`` as binary frames, then ``tts_complete``."""
    text = message.get("text")
    voice = message.get("voice", "af_bella")
    speed = message.get("speed", 1.0)
    if not text:
        await websocket.send_json({"error": "Text is required"})
        return
    logger.info(f"Streaming TTS for text length: {len(text)}")
    if not app.state.tts:
        await websocket.send_json({"error": "TTS Engine not initialized"})
        return
    voice = _ws_resolve_voice(voice)
    # The event is shared with "play"; a "stop" there leaves it set, which
    # would otherwise cancel this stream before it produced any audio.
    cancel_event.clear()
    try:
        async for _, audio_chunk in app.state.tts.generate_audio_stream(
            text,
            voice=voice,
            speed=float(speed),
            prefetch_sentences=3,
            frame_ms=200,
            cancel_event=cancel_event,
        ):
            await websocket.send_bytes(audio_chunk)
        await websocket.send_json({"type": "tts_complete"})
    except Exception as e:
        logger.error(f"TTS error: {e}")
        await websocket.send_json({"type": "error", "message": str(e)})


async def _ws_handle_play(websocket: WebSocket, message: dict, cancel_event: asyncio.Event) -> list:
    """Handle ``play``: scrape a chapter and stream it sentence-chunked.

    Messages sent, in order: ``chapter_info``; then, for each audio chunk, a
    ``sentence`` JSON message (when the sentence changes) followed by the
    binary PCM chunk; for downloads (``realtime`` false) a ``flac_data``
    header plus one binary payload; finally ``chapter_complete``.

    While streaming, client ``pause`` / ``resume`` / ``stop`` frames are
    applied between chunks. Paused time is excluded from realtime pacing.

    Returns the raw text of any other client frames received while streaming,
    in arrival order, so the caller can dispatch them afterwards instead of
    dropping them. Returns ``[]`` if the client disconnected.
    """
    # Single-shot: scrape the chapter, then stream it sentence-by-sentence.
    url = message.get("url")
    voice = message.get("voice", "af_bella")
    speed = float(message.get("speed", 1.0))
    prefetch = int(message.get("prefetch", 3))
    frame_ms = int(message.get("frame_ms", 200))
    start_paragraph = int(message.get("start_paragraph", 0) or 0)
    realtime = bool(message.get("realtime", True))
    deferred: list[str] = []
    if not url:
        await websocket.send_json({"type": "error", "message": "URL is required"})
        return deferred
    if not app.state.tts:
        await websocket.send_json({"type": "error", "message": "TTS Engine not initialized"})
        return deferred
    cancel_event.clear()
    logger.info(f"Play request: url={url} voice={voice} speed={speed}")
    voice = _ws_resolve_voice(voice)
    try:
        chapter = await app.state.scraper.scrape_chapter(url)
    except Exception as e:
        await websocket.send_json({"type": "error", "message": str(e)})
        return deferred

    title = chapter.get("title")
    paragraphs = chapter.get("content") or []
    # Clamp the resume point into the chapter.
    if start_paragraph < 0:
        start_paragraph = 0
    if start_paragraph > len(paragraphs):
        start_paragraph = max(0, len(paragraphs) - 1)
    paragraphs_slice = paragraphs[start_paragraph:] if start_paragraph else paragraphs

    # Provide total sentence count up-front for download/progress UIs.
    try:
        sentence_total = len(app.state.tts.split_paragraphs(paragraphs_slice))
    except Exception:
        sentence_total = None

    await websocket.send_json(
        {
            "type": "chapter_info",
            "title": title,
            "url": url,
            "voice": voice,
            "next_url": chapter.get("next_url"),
            "prev_url": chapter.get("prev_url"),
            "paragraphs": paragraphs,
            "start_paragraph": start_paragraph,
            "sentence_total": sentence_total,
            "audio": {
                "encoding": "pcm_s16le",
                "sample_rate": app.state.tts.sample_rate,
                "channels": 1,
                # For backward-compatibility, keep frame_ms but note that
                # the stream is now sentence-chunked.
                "frame_ms": frame_ms,
                "chunking": "sentence",
            },
        }
    )

    sample_rate = app.state.tts.sample_rate
    paused = False
    disconnected = False
    last_key = None
    cumulative_samples = 0
    # For downloads, accumulate PCM to encode as FLAC at the end.
    download_pcm_chunks: list[bytes] = []
    # Background reader so control frames arrive while audio is sent; at most
    # one receive is ever outstanding on the socket.
    control_task: asyncio.Task[str] | None = None

    def on_frame(payload, live=True):
        # Apply pause/resume/stop (only while ``live``); ignore invalid or
        # non-object JSON; queue any other frame for the caller.
        nonlocal paused
        is_object, cmd = _ws_peek_command(payload)
        if not is_object:
            return
        if cmd == "pause":
            if live:
                paused = True
        elif cmd == "resume":
            if live:
                paused = False
        elif cmd == "stop":
            if live:
                cancel_event.set()
        else:
            deferred.append(payload)

    try:
        try:
            control_task = asyncio.create_task(websocket.receive_text())
            stream_t0 = time.monotonic()
            async for p_idx, s_idx, sentence, audio_chunk, cs, ce in app.state.tts.generate_audio_stream_paragraphs_sentence_chunks(
                paragraphs_slice,
                voice=voice,
                speed=speed,
                prefetch_sentences=prefetch,
                cancel_event=cancel_event,
            ):
                # Consume a completed control read without concurrent receives.
                if control_task is not None and control_task.done():
                    finished, control_task = control_task, None
                    try:
                        on_frame(finished.result())
                    except WebSocketDisconnect:
                        disconnected = True
                        cancel_event.set()

                # control_task is None here whenever paused, since paused only
                # changes inside on_frame after a completed read.
                if paused and not cancel_event.is_set():
                    pause_t0 = time.monotonic()
                    while paused and not cancel_event.is_set():
                        # Block until we get a control message.
                        try:
                            payload = await websocket.receive_text()
                        except WebSocketDisconnect:
                            disconnected = True
                            cancel_event.set()
                            break
                        on_frame(payload)
                    # Don't count paused time against realtime pacing, or the
                    # stream would burst unpaced after resume.
                    stream_t0 += time.monotonic() - pause_t0

                if cancel_event.is_set():
                    break
                if control_task is None:
                    control_task = asyncio.create_task(websocket.receive_text())

                key = (p_idx + start_paragraph, s_idx, sentence)
                if key != last_key:
                    last_key = key
                    ms_start = (cumulative_samples * 1000) // sample_rate
                    await websocket.send_json(
                        {
                            "type": "sentence",
                            "text": sentence,
                            "paragraph_index": int(p_idx + start_paragraph),
                            "sentence_index": int(s_idx),
                            "ms_start": ms_start,
                            "char_start": int(cs),
                            "char_end": int(ce),
                            # Size of the *next* binary message for this sentence in samples/bytes.
                            # Helps clients associate metadata with audio even if transport splits chunks.
                            "chunk_samples": int(len(audio_chunk) // 2),
                            "chunk_bytes": int(len(audio_chunk)),
                        }
                    )
                await websocket.send_bytes(audio_chunk)
                cumulative_samples += len(audio_chunk) // 2
                # Accumulate PCM for FLAC encoding (downloads only).
                if not realtime:
                    download_pcm_chunks.append(audio_chunk)
                # Optional realtime pacing.
                # - streaming: send roughly in-time to reduce client buffer bloat.
                # - downloads: realtime=false sends as fast as synthesis allows.
                if realtime:
                    expected_s = cumulative_samples / float(sample_rate)
                    elapsed_s = time.monotonic() - stream_t0
                    # Let the stream run slightly ahead to avoid stutter from
                    # small scheduling/network jitter.
                    ahead_s = 0.10
                    sleep_s = (expected_s - elapsed_s) - ahead_s
                    if sleep_s > 0:
                        await asyncio.sleep(min(sleep_s, 0.25))
        finally:
            # Never leave a receive pending, even if the stream raised: the
            # caller's loop receives next and must not race this task.
            if control_task is not None:
                if not control_task.done():
                    control_task.cancel()
                try:
                    late = await control_task
                except WebSocketDisconnect:
                    disconnected = True
                    late = None
                except (asyncio.CancelledError, Exception):
                    late = None
                control_task = None
                if late is not None:
                    # The stream is over, so only non-control frames matter now.
                    on_frame(late, live=False)

        # For downloads, encode accumulated PCM as FLAC and send.
        if not realtime and download_pcm_chunks and not cancel_event.is_set():
            try:
                all_pcm = b"".join(download_pcm_chunks)
                flac_data = app.state.tts.encode_pcm16_to_flac(all_pcm, sample_rate=sample_rate)
                is_flac = flac_data[:4] == b"fLaC"
                await websocket.send_json(
                    {
                        "type": "flac_data",
                        "encoding": "flac" if is_flac else "pcm_s16le",
                        "size": len(flac_data),
                        "sample_rate": sample_rate,
                    }
                )
                await websocket.send_bytes(flac_data)
            except Exception as e:
                logger.warning(f"FLAC encoding failed, downloads saved as PCM: {e}")
            finally:
                download_pcm_chunks.clear()
        try:
            await websocket.send_json(
                {
                    "type": "chapter_complete",
                    "next_url": chapter.get("next_url"),
                    "prev_url": chapter.get("prev_url"),
                }
            )
        except Exception:
            pass  # Client already disconnected
    except Exception as e:
        logger.error(f"Play stream error: {e}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            pass  # Client already disconnected

    # Nothing further can be answered on a closed socket.
    return [] if disconnected else deferred


async def websocket_endpoint(websocket: WebSocket):
    """Serve the CoreReader WebSocket protocol on one connection.

    Each text frame is a JSON object whose ``command`` is ``scrape``, ``tts``
    or ``play`` (see the ``_ws_handle_*`` helpers). Per-message failures are
    reported to the client and the loop continues. It ends when receiving from
    the socket fails, e.g. on client disconnect. Frames that arrived while a
    ``play`` stream was running are dispatched before new frames are read.
    """
    await websocket.accept()
    # Shared by the tts/play streams; "stop" during play sets it.
    cancel_event = asyncio.Event()
    backlog = []
    try:
        while True:
            data = backlog.pop(0) if backlog else await websocket.receive_text()
            try:
                message = json.loads(data)
                command = message.get("command")
                if command == "scrape":
                    await _ws_handle_scrape(websocket, message)
                elif command == "tts":
                    await _ws_handle_tts(websocket, message, cancel_event)
                elif command == "play":
                    backlog.extend(await _ws_handle_play(websocket, message, cancel_event))
                else:
                    await websocket.send_json({"error": "Unknown command"})
            except json.JSONDecodeError:
                try:
                    await websocket.send_json({"error": "Invalid JSON"})
                except Exception:
                    pass
            except Exception as e:
                # logger.exception records the traceback through logging.
                logger.exception(f"Error processing message: {e}")
                try:
                    await websocket.send_json({"error": "Internal server error"})
                except Exception:
                    pass
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
if __name__ == "__main__":
    # Direct invocation: serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")