CoreReader / backend /server.py
shreyas-joshi's picture
Fix WS recv race + session recycle 20 with async overlap
8a22550
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
import json
import asyncio
import logging
from scraper import NovelCoolScraper
from tts import TTSEngine
import traceback
from contextlib import asynccontextmanager
import time
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``app.state`` on startup and clear it again on shutdown.

    Startup builds the TTS engine (stored as ``None`` when construction
    fails, so the server still starts), a scraper instance and an empty
    chapter-index cache. Shutdown resets all three attributes to ``None``.
    """
    # --- startup ---
    engine = None
    try:
        logger.info("Initializing TTS Engine...")
        try:
            # Purely informational; a missing/broken onnxruntime import must
            # not prevent the engine from being constructed below.
            import onnxruntime as ort

            logger.info(f"ONNX Runtime providers: {ort.get_available_providers()}")
        except Exception:
            pass
        engine = TTSEngine()
        logger.info("TTS Engine initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize TTS Engine: {e}")
        engine = None
    app.state.tts = engine
    app.state.scraper = NovelCoolScraper()
    app.state.novel_index_cache = {}

    yield

    # --- shutdown ---
    for attr_name in ("tts", "scraper", "novel_index_cache"):
        setattr(app.state, attr_name, None)
# Application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
# CORS: every origin, method and header is accepted, with credentials
# explicitly disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# HuggingFace Space landing page & info endpoint
# ---------------------------------------------------------------------------
def _space_runtime_info(request: Request) -> dict:
    """Build the status/discovery payload shared by ``/`` and ``/info``.

    Reports whether the TTS engine is loaded, how many voices it lists, a few
    engine attributes, the available endpoints, and the WebSocket URLs derived
    from the request's (optionally proxy-forwarded) scheme and host.
    """
    engine = getattr(app.state, "tts", None)

    loaded_voices = []
    if engine:
        try:
            loaded_voices = engine.list_voices()
        except Exception:
            # A failing voice lookup is reported as "no voices".
            loaded_voices = []

    # Forwarded headers take precedence over what the request itself reports.
    headers = request.headers
    proto = headers.get("x-forwarded-proto", request.url.scheme) or "https"
    authority = headers.get(
        "x-forwarded-host", headers.get("host", request.url.netloc)
    )
    socket_scheme = "ws"
    if proto == "https":
        socket_scheme = "wss"

    def engine_attr(attr_name):
        # Engine attributes are reported only when an engine is loaded.
        if not engine:
            return None
        return getattr(engine, attr_name, None)

    endpoints = {
        "health": "/health",
        "voices": "/voices",
        "novel_index": "/novel_index?url=<novel_url>",
        "novel_details": "/novel_details?url=<novel_url>",
        "novel_meta": "/novel_meta?url=<novel_url>",
        "novel_chapter": "/novel_chapter?url=<novel_url>&n=<chapter_number>",
        "websocket": "/ws",
        "openapi_docs": "/docs",
        "openapi_json": "/openapi.json",
    }

    return {
        "name": "CoreReader Backend",
        "status": "running",
        "tts_ready": engine is not None,
        "voice_count": len(loaded_voices),
        "sample_rate": engine_attr("sample_rate"),
        "model_path": engine_attr("model_path"),
        "voices_path": engine_attr("voices_path"),
        "endpoints": endpoints,
        "frontend_base_url": f"{socket_scheme}://{authority}",
        "frontend_ws_url": f"{socket_scheme}://{authority}/ws",
    }
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the HTML landing page describing this backend.

    The page shows the values produced by ``_space_runtime_info``. Several of
    them originate from request headers (host / forwarded host) or from
    filesystem paths, so every value is HTML-escaped before being embedded in
    the markup.
    """
    # Imported locally so this block does not rely on a new file-level import.
    from html import escape

    # Escape every scalar shown on the page; the nested "endpoints" mapping is
    # not rendered from this dict (the list below is static markup).
    info = {
        key: escape(str(value))
        for key, value in _space_runtime_info(request).items()
        if key != "endpoints"
    }
    html = f"""
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>CoreReader Backend</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 2rem; line-height: 1.45; }}
code {{ background: #f5f5f5; padding: 0.15rem 0.35rem; border-radius: 4px; }}
.ok {{ color: #0b7a2a; }}
.card {{ border: 1px solid #ddd; border-radius: 10px; padding: 1rem; max-width: 880px; }}
</style>
</head>
<body>
<div class="card">
<h1>CoreReader Backend</h1>
<p class="ok"><strong>Status:</strong> {info["status"]}</p>
<p><strong>TTS Ready:</strong> {info["tts_ready"]}</p>
<p><strong>Voices Loaded:</strong> {info["voice_count"]}</p>
<p><strong>Sample Rate:</strong> {info["sample_rate"]}</p>
<p><strong>Model:</strong> <code>{info["model_path"]}</code></p>
<p><strong>Voices File:</strong> <code>{info["voices_path"]}</code></p>
<h3>Use this in Frontend</h3>
<p><strong>WebSocket base URL:</strong> <code>{info["frontend_base_url"]}</code></p>
<p><strong>WebSocket endpoint:</strong> <code>{info["frontend_ws_url"]}</code></p>
<h3>API Endpoints</h3>
<ul>
<li><code>/health</code></li>
<li><code>/voices</code></li>
<li><code>/novel_index?url=&lt;novel_url&gt;</code></li>
<li><code>/novel_details?url=&lt;novel_url&gt;</code></li>
<li><code>/novel_meta?url=&lt;novel_url&gt;</code></li>
<li><code>/novel_chapter?url=&lt;novel_url&gt;&amp;n=&lt;chapter_number&gt;</code></li>
<li><code>/ws</code> (WebSocket)</li>
<li><code>/docs</code> (interactive API docs)</li>
<li><code>/info</code> (JSON details)</li>
</ul>
</div>
</body>
</html>
"""
    return HTMLResponse(content=html)
@app.get("/info")
async def info(request: Request):
    """Return the runtime/discovery details as JSON."""
    runtime_details = _space_runtime_info(request)
    return runtime_details
@app.get("/health")
async def health():
    """Liveness probe that also reports whether a TTS engine is loaded."""
    engine_loaded = app.state.tts is not None
    return dict(ok=True, tts_ready=engine_loaded)
@app.get("/voices")
async def voices():
    """List the voices offered by the TTS engine.

    When no engine is loaded the list is empty and an ``error`` field
    explains why.
    """
    engine = app.state.tts
    if engine:
        return {"voices": engine.list_voices()}
    return {"voices": [], "error": "TTS Engine not initialized"}
@app.get("/novel_index")
async def novel_index(url: str):
    """Scrape and return the chapter list for the novel at ``url``.

    An empty ``url`` yields an empty list plus an ``error`` field instead of
    contacting the scraper.
    """
    if url:
        scraped = await app.state.scraper.scrape_novel_index(url)
        return {"chapters": scraped}
    return {"chapters": [], "error": "url is required"}
@app.get("/novel_details")
async def novel_details(url: str):
    """Return whatever the scraper reports as details for the novel at ``url``.

    An empty ``url`` yields a placeholder with null ``title``/``cover_url``
    and an ``error`` field.
    """
    if url:
        return await app.state.scraper.scrape_novel_details(url)
    return {"title": None, "cover_url": None, "error": "url is required"}
async def _get_cached_novel_index(novel_url: str):
    """Return the chapter list for ``novel_url``, scraping at most once per TTL.

    Results are kept in ``app.state.novel_index_cache`` for 30 minutes.

    Compared with a plain "store whatever came back" cache:
    - an empty result is never cached, so one scrape that yields nothing does
      not make the novel look empty for the next 30 minutes;
    - a falsy scraper result is normalised to ``[]`` on the first call too,
      matching what cached calls already return;
    - expired entries are dropped whenever a new one is stored, so the cache
      cannot grow without bound on caller-supplied URLs.

    Raises:
        HTTPException: 400 when ``novel_url`` is empty.
    """
    if not novel_url:
        raise HTTPException(status_code=400, detail="url is required")

    cache = app.state.novel_index_cache
    if cache is None:
        # The cache attribute is reset to None at shutdown; recreate lazily.
        cache = {}
        app.state.novel_index_cache = cache

    ttl_s = 30 * 60  # 30 minutes
    entry = cache.get(novel_url)
    if entry is not None:
        age = time.monotonic() - float(entry.get("ts", 0.0))
        if age < ttl_s:
            return entry.get("chapters") or []

    chapters = await app.state.scraper.scrape_novel_index(novel_url) or []

    # Stamp after the scrape finishes so the TTL measures the data's age.
    now = time.monotonic()
    expired_urls = [
        cached_url
        for cached_url, cached in cache.items()
        if now - float(cached.get("ts", 0.0)) >= ttl_s
    ]
    for cached_url in expired_urls:
        del cache[cached_url]

    if chapters:
        cache[novel_url] = {"ts": now, "chapters": chapters}
    else:
        # Nothing useful to remember; also forget any stale entry.
        cache.pop(novel_url, None)
    return chapters
@app.get("/novel_meta")
async def novel_meta(url: str):
    """Report how many chapters the novel at ``url`` has.

    The count is the highest integer ``n`` found among the dict entries of
    the cached chapter index; when no positive number is present it falls
    back to the number of entries.
    """
    chapter_list = await _get_cached_novel_index(url)
    raw_numbers = [
        entry.get("n") for entry in chapter_list if isinstance(entry, dict)
    ]
    highest = max(
        (number for number in raw_numbers if isinstance(number, int)),
        default=0,
    )
    if highest > 0:
        return {"count": highest}
    return {"count": len(chapter_list)}
@app.get("/novel_chapter")
async def novel_chapter(url: str, n: int):
    """Resolve chapter number ``n`` of the novel at ``url`` to title and URL.

    Entries are matched on their parsed ``n`` field first. If no entry carries
    that number, the entry at list position ``n - 1`` is used instead.

    Raises:
        HTTPException: 400 when ``url`` is empty (raised by the index helper)
            or ``n`` lies outside ``1..limit``; 404 when the index is empty or
            no usable entry exists for ``n``. Previously these 404 cases
            produced a 200 response with null ``title``/``url``, an
            ``AttributeError`` for non-dict entries, or a "between 1 and 0"
            range message.
    """
    chapters = (await _get_cached_novel_index(url)) or []

    # Prefer resolving by parsed chapter number, not list position.
    resolved = None
    max_n = 0
    for c in chapters:
        if not isinstance(c, dict):
            continue
        cn = c.get("n")
        if not isinstance(cn, int):
            continue
        if cn > max_n:
            max_n = cn
        if cn == n:
            # max_n is now >= n, so the range check below still holds even
            # though the scan stops early.
            resolved = c
            break

    limit = max_n if max_n > 0 else len(chapters)
    if limit == 0:
        raise HTTPException(status_code=404, detail="no chapters found for this novel")
    if n < 1 or n > limit:
        raise HTTPException(status_code=400, detail=f"chapter n must be between 1 and {limit}")

    if resolved is None:
        # Fallback: old positional behavior.
        candidate = chapters[n - 1] if (n - 1) < len(chapters) else None
        if not isinstance(candidate, dict):
            raise HTTPException(status_code=404, detail=f"chapter {n} not found")
        resolved = candidate

    return {"n": n, "title": resolved.get("title"), "url": resolved.get("url")}
def _ws_pick_voice(tts, requested):
    """Return ``requested`` if the loaded voice pack lists it, else a fallback.

    The fallback is the first voice reported by ``tts.list_voices()``. Lookup
    failures are deliberately ignored so an unreadable voice list never blocks
    synthesis; the requested name is then passed through unchanged.
    """
    try:
        available = tts.list_voices()
        if available and requested not in available:
            return available[0]
    except Exception:
        pass
    return requested


async def _ws_reap_control_task(task):
    """Retire a background ``receive_text`` task.

    Returns the text the task had already received (so the caller can decide
    what to do with it), or ``None`` if it was still waiting, was cancelled,
    or finished with an error such as a disconnect.
    """
    if task is None:
        return None
    if not task.done():
        task.cancel()
        try:
            await task
        except (asyncio.CancelledError, Exception):
            return None
    if task.cancelled() or task.exception() is not None:
        return None
    return task.result()


def _ws_requeue_candidate(raw):
    """Return ``raw`` if it should be re-dispatched by the main loop, else None.

    Text that is not a JSON object is dropped, and so are pause/resume/stop
    messages: they refer to a stream that has already ended.
    """
    if raw is None:
        return None
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        return None
    if not isinstance(parsed, dict):
        return None
    if parsed.get("command") in ("pause", "resume", "stop"):
        return None
    return raw


async def _ws_handle_scrape(websocket: WebSocket, message: dict) -> None:
    """Handle ``scrape``: fetch one chapter and send it as ``scrape_result``."""
    url = message.get("url")
    if not url:
        await websocket.send_json({"error": "URL is required"})
        return
    logger.info(f"Scraping URL: {url}")
    try:
        result = await app.state.scraper.scrape_chapter(url)
        await websocket.send_json({"type": "scrape_result", "data": result})
    except Exception as e:
        logger.error(f"Scrape error: {e}")
        await websocket.send_json({"type": "error", "message": str(e)})


async def _ws_handle_tts(websocket: WebSocket, message: dict, cancel_event: asyncio.Event) -> None:
    """Handle ``tts``: synthesize ``text`` and stream the audio as binary frames.

    Ends with a ``tts_complete`` message, or an ``error`` message on failure.
    """
    text = message.get("text")
    voice = message.get("voice", "af_bella")
    speed = message.get("speed", 1.0)
    if not text:
        await websocket.send_json({"error": "Text is required"})
        return
    logger.info(f"Streaming TTS for text length: {len(text)}")
    tts = app.state.tts
    if not tts:
        await websocket.send_json({"error": "TTS Engine not initialized"})
        return

    # Ensure voice is valid for the loaded voice pack.
    voice = _ws_pick_voice(tts, voice)

    # The event is shared by every request on this connection. Without this
    # reset, a "stop" issued during an earlier "play" would still be set and
    # would be handed to the generator below as an already-fired cancel.
    cancel_event.clear()

    try:
        async for _, audio_chunk in tts.generate_audio_stream(
            text,
            voice=voice,
            speed=float(speed),
            prefetch_sentences=3,
            frame_ms=200,
            cancel_event=cancel_event,
        ):
            await websocket.send_bytes(audio_chunk)
        await websocket.send_json({"type": "tts_complete"})
    except Exception as e:
        logger.error(f"TTS error: {e}")
        await websocket.send_json({"type": "error", "message": str(e)})


async def _ws_handle_play(websocket: WebSocket, message: dict, cancel_event: asyncio.Event):
    """Handle ``play``: scrape a chapter, then stream it sentence by sentence.

    Sends ``chapter_info`` first, then for each sentence a ``sentence``
    metadata message followed by its PCM audio as binary frames, and finally
    ``chapter_complete``. With ``realtime`` false the PCM is additionally
    collected and sent once more as a single ``flac_data`` payload.

    While streaming, the client may send ``pause``, ``resume`` and ``stop``.
    Exactly one ``receive_text`` call is outstanding at any time, and it is
    always retired before this function returns, so it can never compete with
    the caller's own receive.

    Returns:
        Raw text of a non-control command that arrived while the stream was
        finishing and should be dispatched next, or ``None``.
    """
    # Single-shot: scrape the chapter, then stream it sentence-by-sentence.
    url = message.get("url")
    voice = message.get("voice", "af_bella")
    speed = float(message.get("speed", 1.0))
    prefetch = int(message.get("prefetch", 3))
    frame_ms = int(message.get("frame_ms", 200))
    start_paragraph = int(message.get("start_paragraph", 0) or 0)
    realtime = bool(message.get("realtime", True))

    if not url:
        await websocket.send_json({"type": "error", "message": "URL is required"})
        return None
    tts = app.state.tts
    if not tts:
        await websocket.send_json({"type": "error", "message": "TTS Engine not initialized"})
        return None

    cancel_event.clear()
    logger.info(f"Play request: url={url} voice={voice} speed={speed}")

    # Ensure voice is valid for the loaded voice pack.
    voice = _ws_pick_voice(tts, voice)

    try:
        chapter = await app.state.scraper.scrape_chapter(url)
    except Exception as e:
        await websocket.send_json({"type": "error", "message": str(e)})
        return None

    title = chapter.get("title")
    paragraphs = chapter.get("content") or []
    if start_paragraph < 0:
        start_paragraph = 0
    if start_paragraph > len(paragraphs):
        start_paragraph = max(0, len(paragraphs) - 1)
    paragraphs_slice = paragraphs[start_paragraph:] if start_paragraph else paragraphs

    # Provide total sentence count up-front for download/progress UIs.
    try:
        sentence_total = len(tts.split_paragraphs(paragraphs_slice))
    except Exception:
        sentence_total = None

    sample_rate = tts.sample_rate
    await websocket.send_json(
        {
            "type": "chapter_info",
            "title": title,
            "url": url,
            "voice": voice,
            "next_url": chapter.get("next_url"),
            "prev_url": chapter.get("prev_url"),
            "paragraphs": paragraphs,
            "start_paragraph": start_paragraph,
            "sentence_total": sentence_total,
            "audio": {
                "encoding": "pcm_s16le",
                "sample_rate": sample_rate,
                "channels": 1,
                # For backward-compatibility, keep frame_ms but note that
                # the stream is now sentence-chunked.
                "frame_ms": frame_ms,
                "chunking": "sentence",
            },
        }
    )

    paused = False
    last_key = None
    cumulative_samples = 0
    # For downloads, accumulate PCM to encode as FLAC at the end.
    download_pcm_chunks = []
    leftover = None

    def apply_control(payload):
        """Apply one in-stream control message; anything else is ignored."""
        nonlocal paused
        try:
            msg = json.loads(payload)
        except json.JSONDecodeError:
            return
        if not isinstance(msg, dict):
            # e.g. a bare JSON list/string: not a command, must not abort the stream.
            return
        cmd = msg.get("command")
        if cmd == "pause":
            paused = True
        elif cmd == "resume":
            paused = False
        elif cmd == "stop":
            cancel_event.set()

    try:
        control_task = asyncio.create_task(websocket.receive_text())
        stream_t0 = time.monotonic()
        try:
            async for p_idx, s_idx, sentence, audio_chunk, cs, ce in tts.generate_audio_stream_paragraphs_sentence_chunks(
                paragraphs_slice,
                voice=voice,
                speed=speed,
                prefetch_sentences=prefetch,
                cancel_event=cancel_event,
            ):
                # Consume a pending control message without a second receive.
                if control_task is not None and control_task.done():
                    try:
                        apply_control(control_task.result())
                    except WebSocketDisconnect:
                        cancel_event.set()
                        control_task = None
                    else:
                        control_task = asyncio.create_task(websocket.receive_text())

                if paused:
                    pause_started = time.monotonic()
                    # Block on the one outstanding receive until resumed/stopped.
                    while paused and not cancel_event.is_set():
                        try:
                            payload = await control_task
                        except WebSocketDisconnect:
                            cancel_event.set()
                            control_task = None
                            break
                        apply_control(payload)
                        control_task = asyncio.create_task(websocket.receive_text())
                    # Exclude paused time from pacing; otherwise the stream
                    # would look "behind" after a resume and pacing would be
                    # skipped for as long as the pause lasted.
                    stream_t0 += time.monotonic() - pause_started

                if cancel_event.is_set():
                    break

                key = (p_idx + start_paragraph, s_idx, sentence)
                if key != last_key:
                    last_key = key
                    ms_start = (cumulative_samples * 1000) // sample_rate
                    await websocket.send_json(
                        {
                            "type": "sentence",
                            "text": sentence,
                            "paragraph_index": int(p_idx + start_paragraph),
                            "sentence_index": int(s_idx),
                            "ms_start": ms_start,
                            "char_start": int(cs),
                            "char_end": int(ce),
                            # Size of the *next* binary message for this sentence in samples/bytes.
                            # Helps clients associate metadata with audio even if transport splits chunks.
                            "chunk_samples": int(len(audio_chunk) // 2),
                            "chunk_bytes": int(len(audio_chunk)),
                        }
                    )
                await websocket.send_bytes(audio_chunk)
                cumulative_samples += len(audio_chunk) // 2

                # Accumulate PCM for FLAC encoding (downloads only).
                if not realtime:
                    download_pcm_chunks.append(audio_chunk)

                # Optional realtime pacing.
                # - streaming: send roughly in-time to reduce client buffer bloat.
                # - downloads: realtime=false sends as fast as synthesis allows.
                if realtime:
                    expected_s = cumulative_samples / float(sample_rate)
                    elapsed_s = time.monotonic() - stream_t0
                    # Let the stream run slightly ahead to avoid stutter from
                    # small scheduling/network jitter.
                    ahead_s = 0.10
                    sleep_s = (expected_s - elapsed_s) - ahead_s
                    if sleep_s > 0:
                        await asyncio.sleep(min(sleep_s, 0.25))
        finally:
            # Runs on errors too: a receive task left pending here would race
            # with the main loop's receive and swallow the next client message.
            leftover = await _ws_reap_control_task(control_task)
            control_task = None

        # For downloads, encode accumulated PCM as FLAC and send.
        if not realtime and download_pcm_chunks and not cancel_event.is_set():
            try:
                all_pcm = b"".join(download_pcm_chunks)
                flac_data = tts.encode_pcm16_to_flac(
                    all_pcm, sample_rate=sample_rate
                )
                is_flac = flac_data[:4] == b"fLaC"
                await websocket.send_json({
                    "type": "flac_data",
                    "encoding": "flac" if is_flac else "pcm_s16le",
                    "size": len(flac_data),
                    "sample_rate": sample_rate,
                })
                await websocket.send_bytes(flac_data)
            except Exception as e:
                logger.warning(f"FLAC encoding failed, downloads saved as PCM: {e}")
            finally:
                download_pcm_chunks.clear()

        try:
            await websocket.send_json(
                {
                    "type": "chapter_complete",
                    "next_url": chapter.get("next_url"),
                    "prev_url": chapter.get("prev_url"),
                }
            )
        except Exception:
            pass  # Client already disconnected
    except Exception as e:
        logger.error(f"Play stream error: {e}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            pass  # Client already disconnected

    # A command the client sent while the stream was finishing was already
    # taken off the socket by the control task; hand it back for dispatch.
    return _ws_requeue_candidate(leftover)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve the command-driven WebSocket API.

    Each text frame must be a JSON object with a ``command`` field:

    - ``scrape``: fetch a chapter and reply with ``scrape_result``.
    - ``tts``: synthesize ``text`` and stream audio frames.
    - ``play``: scrape a chapter and stream it sentence by sentence; accepts
      ``pause`` / ``resume`` / ``stop`` while streaming.

    Per-message failures are reported to the client and the loop keeps
    running; the loop ends when the client disconnects or receiving fails.
    """
    await websocket.accept()
    cancel_event = asyncio.Event()
    # Raw command handed back by a finished "play" stream, dispatched before
    # the socket is read again.
    queued = None
    try:
        while True:
            if queued is not None:
                data, queued = queued, None
            else:
                data = await websocket.receive_text()
            try:
                message = json.loads(data)
                if not isinstance(message, dict):
                    # Valid JSON but not an object: there is no command to read.
                    await websocket.send_json({"error": "Invalid JSON"})
                    continue
                command = message.get("command")
                if command == "scrape":
                    await _ws_handle_scrape(websocket, message)
                elif command == "tts":
                    await _ws_handle_tts(websocket, message, cancel_event)
                elif command == "play":
                    queued = await _ws_handle_play(websocket, message, cancel_event)
                else:
                    await websocket.send_json({"error": "Unknown command"})
            except json.JSONDecodeError:
                try:
                    await websocket.send_json({"error": "Invalid JSON"})
                except Exception:
                    pass
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                traceback.print_exc()
                try:
                    await websocket.send_json({"error": "Internal server error"})
                except Exception:
                    pass
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)