"""HTTP and WebSocket API for media transcription jobs.

Exposes routes to create jobs (file upload or YouTube URL), cancel them,
inspect their status, read transcript segments / timeline windows, chat
against a finished transcript, and stream job updates over a WebSocket.
"""

import asyncio
import json
import logging
import math
import re
import urllib.parse
import uuid
from datetime import datetime, timezone
from pathlib import Path

import requests
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import AsyncSessionLocal, get_session, init_db
from app.job_control import ensure_job_control, register_job_task, request_cancel
from app.job_limit import run_job_pipeline
from app.models import Job, JobStatus, TranscriptSegment
from app.progress_hub import broadcast_job, subscribe, unsubscribe
from app.schemas import (
    CancelJobResponse,
    ChatRequest,
    ChatResponse,
    ChatSourceSnippet,
    JobCreateResponse,
    JobStatusResponse,
    JobUpdateRequest,
    TimelineResponse,
    TimelineWindow,
    TranscriptSegmentOut,
    job_to_status_response,
)
from app.services.chroma_service import query_job
from app.services.embedding_service import embed_query
from app.services.gemini_service import answer_with_context
from app.timeline_util import segments_to_windows

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title=settings.app_name)
app.add_middleware(
    CORSMiddleware,
    allow_origins=[o.strip() for o in settings.cors_origins.split(",") if o.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fallback file name for uploads whose client-supplied name is unusable.
_DEFAULT_UPLOAD_NAME = "upload.bin"
# Uploads are streamed to disk in chunks of this many bytes.
_UPLOAD_CHUNK_BYTES = 1024 * 1024
# Only the most recent entries of a job's activity log are persisted.
_ACTIVITY_LOG_MAX_ENTRIES = 120
# Best retrieval distance above which a non-generic question is refused.
_OFF_TOPIC_DISTANCE = 0.7
# Substrings that mark a question as a generic "summarize the video" request.
_GENERIC_QUERY_MARKERS = (
    "main topic",
    "summary",
    "summarize",
    "key takeaway",
    "core concept",
    "explain",
    "overview",
)
_WHISPER_TASKS = ("transcribe", "translate")
_WHISPER_LANGUAGE_RE = re.compile(r"^[a-z]{2}(-[a-z0-9]{1,8})?$")
# Restricting the ID alphabet keeps user input from smuggling extra query
# parameters or path segments into URLs / storage paths built from it.
# NOTE(review): YouTube IDs are 11 characters today; the length bound is
# deliberately looser so a format change does not reject valid links.
_YOUTUBE_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
_YOUTUBE_SHORT_HOSTS = frozenset({"youtu.be", "www.youtu.be"})
_YOUTUBE_HOSTS = frozenset(
    {"youtube.com", "www.youtube.com", "m.youtube.com", "music.youtube.com"}
)
_YOUTUBE_PATH_PREFIXES = ("shorts", "embed", "live")
# Statuses from which a job can no longer be cancelled.
_TERMINAL_STATUSES = frozenset(
    {
        JobStatus.ready.value,
        JobStatus.failed.value,
        JobStatus.cancelled.value,
    }
)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same kind of value as the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    return _utcnow().isoformat() + "Z"


def _safe_name(name: str | None) -> str:
    """Return a filesystem-safe file name derived from *name*.

    Every character outside ``[a-zA-Z0-9._-]`` becomes ``_`` and only the last
    200 characters are kept. Missing, empty, or dots-only names (``"."``,
    ``".."``), which would refer to a directory rather than a file, fall back
    to ``upload.bin``.
    """
    base = re.sub(r"[^a-zA-Z0-9._-]", "_", name or _DEFAULT_UPLOAD_NAME)[-200:]
    if not base.strip("."):
        return _DEFAULT_UPLOAD_NAME
    return base


@app.on_event("startup")
async def on_startup() -> None:
    """Initialise the database when the application starts."""
    await init_db()


@app.get("/api/health")
async def health() -> dict:
    """Report liveness together with the current UTC time."""
    return {"status": "ok", "time": _utc_timestamp()}


def _normalize_whisper_task(raw: str) -> str:
    """Return ``"transcribe"`` or ``"translate"``; anything else becomes ``"transcribe"``."""
    task = (raw or "transcribe").strip().lower()
    return task if task in _WHISPER_TASKS else "transcribe"


def _normalize_whisper_language(raw: str | None) -> str | None:
    """Normalise a user-supplied Whisper language code.

    Returns:
        The lower-cased code, or ``None`` when the value is missing, blank, or
        one of ``auto`` / ``detect`` / ``none``.

    Raises:
        HTTPException: 400 when the value is not shaped like ``xx`` or
            ``xx-suffix``.
    """
    if not raw:
        return None
    code = raw.strip().lower()
    if not code or code in ("auto", "detect", "none"):
        return None
    if len(code) > 12 or not _WHISPER_LANGUAGE_RE.match(code):
        raise HTTPException(400, "Invalid whisper language code (e.g. hi, en, zh)")
    return code


def _extract_youtube_id(url: str) -> str | None:
    """Extract the video ID from a YouTube URL.

    Recognises ``youtu.be/<id>``, ``youtube.com/watch?v=<id>`` and
    ``youtube.com/{shorts,embed,live}/<id>``. Returns ``None`` when the URL is
    malformed, points at another host, or the candidate ID contains characters
    outside ``[A-Za-z0-9_-]``.
    """
    try:
        parsed = urllib.parse.urlparse((url or "").strip())
        host = parsed.hostname
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. a broken IPv6 literal).
        return None

    candidate: str | None = None
    if host in _YOUTUBE_SHORT_HOSTS:
        candidate = parsed.path.lstrip("/").split("/", 1)[0]
    elif host in _YOUTUBE_HOSTS:
        values = urllib.parse.parse_qs(parsed.query).get("v")
        if values:
            candidate = values[0]
        else:
            parts = [p for p in parsed.path.split("/") if p]
            if len(parts) >= 2 and parts[0] in _YOUTUBE_PATH_PREFIXES:
                candidate = parts[1]

    if candidate and _YOUTUBE_ID_RE.fullmatch(candidate):
        return candidate
    return None


def _fetch_youtube_metadata(yt_id: str) -> tuple[str, str | None]:
    """Look up a video's title and thumbnail URL via YouTube's oEmbed endpoint.

    This function blocks on network I/O; call it from a worker thread when
    inside the event loop. The lookup is best-effort: any failure is logged
    and the defaults ``("YouTube Video", None)`` are returned.
    """
    title = "YouTube Video"
    thumbnail: str | None = None
    try:
        resp = requests.get(
            "https://www.youtube.com/oembed",
            # Passing params lets requests URL-encode the nested watch URL.
            params={
                "url": f"https://www.youtube.com/watch?v={yt_id}",
                "format": "json",
            },
            timeout=5,
        )
        if resp.status_code == 200:
            data = resp.json()
            title = data.get("title") or title
            thumbnail = data.get("thumbnail_url")
    except Exception as e:  # deliberate: metadata must never fail job creation
        logger.warning("Could not fetch YouTube metadata: %s", e)
    return title, thumbnail


def _discard_partial_upload(dest: Path) -> None:
    """Best-effort removal of an incomplete upload and its per-job directory."""
    try:
        dest.unlink(missing_ok=True)
        dest.parent.rmdir()
    except OSError as exc:
        logger.warning("Could not clean up partial upload %s: %s", dest, exc)


async def _store_upload(file: UploadFile, job_id: str) -> Path:
    """Stream *file* to ``<data_dir>/uploads/<job_id>/`` and return its path.

    Raises:
        HTTPException: 413 when the declared or actually received size exceeds
            ``settings.max_upload_bytes``.
    """
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > settings.max_upload_bytes:
        raise HTTPException(413, "File exceeds configured maximum size")

    upload_dir = settings.data_dir / "uploads" / job_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    dest = upload_dir / _safe_name(file.filename)

    written = 0
    completed = False
    try:
        with open(dest, "wb") as out:
            while chunk := await file.read(_UPLOAD_CHUNK_BYTES):
                written += len(chunk)
                # The declared size can be absent or wrong, so enforce the
                # limit on the bytes actually received as well.
                if written > settings.max_upload_bytes:
                    raise HTTPException(413, "File exceeds configured maximum size")
                out.write(chunk)
        completed = True
    finally:
        # Remove the partial file on *any* failure (size limit, disk error,
        # cancellation), not just on HTTPException.
        if not completed:
            _discard_partial_upload(dest)
        await file.close()
    return dest


@app.post("/api/jobs", response_model=JobCreateResponse)
async def create_job(
    session: AsyncSession = Depends(get_session),
    file: UploadFile | None = File(None),
    youtube_url: str | None = Form(None),
    whisper_language: str | None = Form(None),
    whisper_task: str = Form("transcribe"),
) -> JobCreateResponse:
    """Create a transcription job from an uploaded file or a YouTube URL.

    When both are supplied the YouTube URL takes precedence. The job row is
    committed first, then the processing pipeline is started as a background
    task.
    """
    if not file and not youtube_url:
        raise HTTPException(400, "Must provide either file or youtube_url")

    job_id = str(uuid.uuid4())
    lang = _normalize_whisper_language(whisper_language)
    task_name = _normalize_whisper_task(whisper_task)

    if youtube_url:
        yt_id = _extract_youtube_id(youtube_url)
        if not yt_id:
            raise HTTPException(400, "Invalid YouTube URL")
        # requests is synchronous; run it in a worker thread so the event
        # loop keeps serving other requests while the lookup is in flight.
        title, thumbnail = await asyncio.to_thread(_fetch_youtube_metadata, yt_id)
        # NOTE(review): job control is only pre-registered on this branch,
        # matching the previous behaviour — confirm uploads do not need it.
        ensure_job_control(job_id)
        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename="YouTube URL",
            storage_path=f"youtube:{yt_id}",
            title=title,
            thumbnail=thumbnail,
            whisper_language=lang,
            whisper_task=task_name,
        )
    else:
        dest = await _store_upload(file, job_id)
        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename=file.filename or dest.name,
            storage_path=str(dest.resolve()),
            whisper_language=lang,
            whisper_task=task_name,
        )

    session.add(job)
    await session.commit()

    pipeline_task = asyncio.create_task(run_job_pipeline(job_id))
    register_job_task(job_id, pipeline_task)
    return JobCreateResponse(job_id=job_id)


async def _get_job_or_404(session: AsyncSession, job_id: str) -> Job:
    """Load the job with *job_id*, raising a 404 ``HTTPException`` if absent."""
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")
    return job


def _load_activity_log(raw: str | None) -> list:
    """Decode a stored activity log, yielding ``[]`` for invalid or non-list JSON."""
    try:
        entries = json.loads(raw or "[]")
    except json.JSONDecodeError:
        return []
    return entries if isinstance(entries, list) else []


@app.post("/api/jobs/{job_id}/cancel", response_model=CancelJobResponse)
async def cancel_job(
    job_id: str, session: AsyncSession = Depends(get_session)
) -> CancelJobResponse:
    """Request cancellation of a job and mark it cancelled.

    Responds 404 for an unknown job and 400 when the job is already ready,
    failed, or cancelled.
    """
    job = await _get_job_or_404(session, job_id)
    if job.status in _TERMINAL_STATUSES:
        raise HTTPException(400, "Job is not running")

    ensure_job_control(job_id)
    request_cancel(job_id)

    log = _load_activity_log(job.activity_log_json)
    log.append(
        {
            "t": _utc_timestamp(),
            "level": "info",
            "message": "User requested cancel — stopping immediately.",
        }
    )
    await session.execute(
        update(Job)
        .where(Job.id == job_id)
        .values(
            status=JobStatus.cancelled.value,
            progress_message="Cancelled",
            activity_log_json=json.dumps(log[-_ACTIVITY_LOG_MAX_ENTRIES:]),
            updated_at=_utcnow(),
        )
    )
    await session.commit()
    await broadcast_job(job_id)
    logger.info("cancel_api job_id=%s", job_id)
    return CancelJobResponse(
        job_id=job_id, status="cancelled", message="Cancellation registered."
    )


@app.get("/api/jobs", response_model=list[JobStatusResponse])
async def list_jobs(
    session: AsyncSession = Depends(get_session),
    limit: int = 25,
) -> list[JobStatusResponse]:
    """List the most recently created jobs (``limit`` is clamped to 1..100)."""
    lim = min(100, max(1, limit))
    result = await session.execute(
        select(Job).order_by(Job.created_at.desc()).limit(lim)
    )
    return [job_to_status_response(j) for j in result.scalars().all()]


@app.get("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job(
    job_id: str, session: AsyncSession = Depends(get_session)
) -> JobStatusResponse:
    """Return the status of a single job, or 404 if it does not exist."""
    return job_to_status_response(await _get_job_or_404(session, job_id))


@app.patch("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def update_job(
    job_id: str,
    body: JobUpdateRequest,
    session: AsyncSession = Depends(get_session),
) -> JobStatusResponse:
    """Update a job's title, subject and/or thumbnail.

    Fields left as ``None`` in the request body are not modified.
    """
    job = await _get_job_or_404(session, job_id)
    if body.title is not None:
        job.title = body.title
    if body.subject is not None:
        job.subject = body.subject
    if body.thumbnail is not None:
        job.thumbnail = body.thumbnail
    job.updated_at = _utcnow()
    await session.commit()
    await session.refresh(job)
    return job_to_status_response(job)


def _build_timeline_windows(
    job: Job, segments: list[TranscriptSegment], window_sec: float
) -> list[TimelineWindow]:
    """Group transcript segments into timeline windows of *window_sec* seconds.

    The grouping itself is delegated to ``segments_to_windows``; this wrapper
    flattens the segments to ``(start, end, text)`` tuples and converts the
    resulting dicts into ``TimelineWindow`` models.
    """
    duration = float(job.duration_seconds or 0)
    flat = [(s.start_sec, s.end_sec, s.text) for s in segments]
    return [
        TimelineWindow(
            window_index=w["window_index"],
            start_sec=w["start_sec"],
            end_sec=w["end_sec"],
            text=w["text"],
        )
        for w in segments_to_windows(flat, duration, window_sec)
    ]


@app.get("/api/jobs/{job_id}/timeline", response_model=TimelineResponse)
async def get_timeline(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    window_sec: float | None = None,
) -> TimelineResponse:
    """Return the job's transcript grouped into fixed-length time windows.

    ``window_sec`` defaults to ``settings.timeline_window_seconds`` when it is
    omitted or zero; negative, NaN and infinite values are rejected with 400.
    """
    job = await _get_job_or_404(session, job_id)

    w = float(window_sec or settings.timeline_window_seconds)
    if not math.isfinite(w) or w <= 0:
        raise HTTPException(400, "window_sec must be a positive number")

    seg_result = await session.execute(
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
    )
    segments = list(seg_result.scalars().all())
    windows = _build_timeline_windows(job, segments, w)
    return TimelineResponse(
        job_id=job_id,
        window_seconds=w,
        total_windows=len(windows),
        duration_seconds=job.duration_seconds,
        windows=windows,
    )


@app.get("/api/jobs/{job_id}/segments", response_model=list[TranscriptSegmentOut])
async def get_segments(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    offset: int = 0,
    limit: int = 500,
) -> list[TranscriptSegmentOut]:
    """Return a page of transcript segments ordered by start time.

    ``offset`` is floored at 0 and ``limit`` is clamped to 1..2000.
    """
    await _get_job_or_404(session, job_id)
    q = (
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
        .offset(max(0, offset))
        .limit(min(2000, max(1, limit)))
    )
    rows = (await session.execute(q)).scalars().all()
    return [TranscriptSegmentOut.model_validate(r) for r in rows]


def _snippet(s: str, n: int) -> str:
    """Collapse whitespace in *s* and cut it to at most *n* chars, ending in ``…``."""
    t = " ".join((s or "").split())
    if len(t) <= n:
        return t
    return t[: n - 1].rstrip() + "…"


@app.post("/api/jobs/{job_id}/chat", response_model=ChatResponse)
async def chat_job(
    job_id: str, body: ChatRequest, session: AsyncSession = Depends(get_session)
) -> ChatResponse:
    """Answer a question about a finished job using retrieved transcript windows.

    Responds 404 for an unknown job, 400 when the job is not ready, 503 when
    the answer service raises ``ValueError`` (reported as a missing API key),
    and 500 for other retrieval or generation failures.
    """
    job = await _get_job_or_404(session, job_id)
    if job.status != JobStatus.ready.value:
        raise HTTPException(400, "Transcription not ready yet")

    # NOTE(review): embed_query / query_job / answer_with_context are called
    # synchronously here; if they block on network or model work they stall
    # the event loop — consider offloading once their thread-safety is known.
    try:
        q_emb = embed_query(body.query)
        top_k = max(3, min(12, settings.chat_retrieval_top_k))
        docs, metas, distances = query_job(job_id, q_emb, n_results=top_k)
    except Exception as exc:
        logger.exception("Retrieval failed")
        raise HTTPException(500, f"Retrieval failed: {exc}") from exc

    min_dist = min(distances) if distances else 1.0
    query_lower = body.query.lower()
    is_generic_query = any(marker in query_lower for marker in _GENERIC_QUERY_MARKERS)

    # Refuse questions whose best match is too distant, unless they are
    # generic requests that apply to any video.
    if min_dist > _OFF_TOPIC_DISTANCE and not is_generic_query:
        return ChatResponse(
            answer="I can only answer questions related to the video content.",
            sources=[],
        )

    allow_fallback = is_generic_query or (
        min_dist > settings.chat_fallback_distance_threshold
    )

    context_blocks: list[str] = []
    sources: list[ChatSourceSnippet] = []
    clip_ctx = max(160, settings.chat_context_chars_per_chunk)
    for doc, meta in zip(docs, metas):
        if not meta:
            continue
        wi = int(meta.get("window_index", 0))
        start = float(meta.get("start_sec", 0))
        end = float(meta.get("end_sec", 0))
        body_txt = (doc or "").strip()
        context_blocks.append(
            f"[{start:.0f}s–{end:.0f}s] {_snippet(body_txt, clip_ctx)}"
        )
        sources.append(
            ChatSourceSnippet(
                window_index=wi,
                start_sec=start,
                end_sec=end,
                snippet=_snippet(body_txt, 200),
            )
        )

    if not context_blocks:
        context_blocks = [
            "(No highly similar transcript windows; reply that the video may not cover this.)"
        ]

    try:
        answer = answer_with_context(
            body.query, context_blocks, allow_fallback=allow_fallback
        )
    except ValueError as exc:
        raise HTTPException(
            503, "Gemini API key not configured. Set GEMINI_API_KEY."
        ) from exc
    except Exception as exc:
        logger.exception("Gemini failed")
        raise HTTPException(500, f"Gemini error: {exc}") from exc

    return ChatResponse(answer=answer, sources=sources[:4])


@app.websocket("/api/jobs/{job_id}/ws")
async def job_events_ws(websocket: WebSocket, job_id: str) -> None:
    """Stream job updates: one status snapshot, then every broadcast message.

    Sends an error message and closes with code 4404 when the job does not
    exist.
    """
    await websocket.accept()
    # Subscribe before reading the snapshot so an update published in between
    # is queued rather than lost; the client may see it twice, never zero times.
    q = subscribe(job_id)
    try:
        # The session is closed before the long-lived loop below so it is not
        # held for the lifetime of the socket.
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            snapshot = (
                job_to_status_response(job).model_dump(mode="json") if job else None
            )

        if snapshot is None:
            await websocket.send_json({"type": "error", "detail": "Job not found"})
            await websocket.close(code=4404)
            return

        await websocket.send_json({"type": "job_status", "data": snapshot})
        while True:
            msg = await q.get()
            await websocket.send_json(msg)
    except WebSocketDisconnect:
        pass
    finally:
        unsubscribe(job_id, q)