| import asyncio |
| import json |
| import logging |
| import re |
| import uuid |
| import urllib.parse |
| import requests |
| from datetime import datetime |
|
|
| from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
| from sqlalchemy import select, update |
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from app.config import settings |
| from app.database import AsyncSessionLocal, get_session, init_db |
| from app.models import Job, JobStatus, TranscriptSegment |
| from app.job_control import ensure_job_control, request_cancel |
| from app.job_limit import run_job_pipeline |
| from app.schemas import ( |
| CancelJobResponse, |
| ChatRequest, |
| ChatResponse, |
| ChatSourceSnippet, |
| JobCreateResponse, |
| JobStatusResponse, |
| JobUpdateRequest, |
| TimelineResponse, |
| TimelineWindow, |
| TranscriptSegmentOut, |
| job_to_status_response, |
| ) |
| from app.services.chroma_service import query_job |
| from app.services.embedding_service import embed_query |
| from app.services.gemini_service import answer_with_context |
| from app.progress_hub import broadcast_job, subscribe, unsubscribe |
| from app.timeline_util import segments_to_windows |
|
|
# Root logging is configured at INFO; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object every route below is registered on.
app = FastAPI(title=settings.app_name)

# settings.cors_origins is read as a comma-separated string; entries are
# trimmed and blank ones are dropped before being handed to the middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin
        for origin in (part.strip() for part in settings.cors_origins.split(","))
        if origin
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| def _safe_name(name: str | None) -> str: |
| base = name or "upload.bin" |
| base = re.sub(r"[^a-zA-Z0-9._-]", "_", base)[-200:] |
| return base or "upload.bin" |
|
|
|
|
@app.on_event("startup")
async def on_startup() -> None:
    """Startup hook: await ``init_db()`` before the app begins serving."""
    await init_db()
|
|
|
|
@app.get("/api/health")
async def health() -> dict:
    """Health probe.

    Returns:
        A dict with a constant ``"status": "ok"`` and the current UTC time as
        an ISO-8601 string suffixed with ``Z``.
    """
    timestamp = f"{datetime.utcnow().isoformat()}Z"
    return {"status": "ok", "time": timestamp}
|
|
|
|
| def _normalize_whisper_task(raw: str) -> str: |
| t = (raw or "transcribe").strip().lower() |
| return t if t in ("transcribe", "translate") else "transcribe" |
|
|
|
|
| def _normalize_whisper_language(raw: str | None) -> str | None: |
| if not raw: |
| return None |
| s = raw.strip().lower() |
| if not s or s in ("auto", "detect", "none"): |
| return None |
| if len(s) > 12 or not re.match(r"^[a-z]{2}(-[a-z0-9]{1,8})?$", s): |
| raise HTTPException(400, "Invalid whisper language code (e.g. hi, en, zh)") |
| return s |
|
|
|
|
| def _extract_youtube_id(url: str) -> str | None: |
| parsed = urllib.parse.urlparse(url) |
| if parsed.hostname in ('youtu.be', 'www.youtu.be'): |
| return parsed.path[1:] |
| if parsed.hostname in ('youtube.com', 'www.youtube.com'): |
| q = urllib.parse.parse_qs(parsed.query) |
| if 'v' in q: |
| return q['v'][0] |
| return None |
|
|
def _fetch_youtube_metadata(yt_id: str) -> tuple[str, str | None]:
    """Best-effort lookup of a video's title and thumbnail via YouTube oEmbed.

    This is a blocking HTTP call (``requests``) and is meant to be run in a
    worker thread. Any failure is logged and swallowed: metadata is optional.

    Returns:
        ``(title, thumbnail_url)``; the title defaults to ``"YouTube Video"``
        and the thumbnail to ``None`` when the lookup fails.
    """
    title = "YouTube Video"
    thumbnail = None
    try:
        # Let requests encode the query string so the id cannot alter it.
        resp = requests.get(
            "https://www.youtube.com/oembed",
            params={
                "url": f"https://www.youtube.com/watch?v={yt_id}",
                "format": "json",
            },
            timeout=5,
        )
        if resp.status_code == 200:
            data = resp.json()
            if isinstance(data, dict):
                title = data.get("title") or title
                thumbnail = data.get("thumbnail_url")
    except Exception as exc:  # deliberate: metadata must never fail the job
        logger.warning("Could not fetch YouTube metadata: %s", exc)
    return title, thumbnail


async def _save_upload(file: UploadFile, dest) -> None:
    """Stream *file* to *dest* in 1 MiB chunks, enforcing the upload size cap.

    Args:
        file: The uploaded file; it is always closed before returning.
        dest: Target path object (built from ``settings.data_dir``; presumably
            a ``pathlib.Path`` - it must offer ``parent``, ``unlink``, ``rmdir``).

    Raises:
        HTTPException: 413 once more than ``settings.max_upload_bytes`` bytes
            have been read.
    """
    chunk_size = 1024 * 1024
    written = 0
    try:
        dest.parent.mkdir(parents=True, exist_ok=True)
        with open(dest, "wb") as out:
            while chunk := await file.read(chunk_size):
                written += len(chunk)
                if written > settings.max_upload_bytes:
                    raise HTTPException(413, "File exceeds configured maximum size")
                out.write(chunk)
    except Exception:
        # Whatever went wrong, do not leave a partial file (or its per-job
        # directory) behind. Cleanup itself is best-effort.
        try:
            dest.unlink(missing_ok=True)
            dest.parent.rmdir()
        except OSError:
            pass
        raise
    finally:
        await file.close()


@app.post("/api/jobs", response_model=JobCreateResponse)
async def create_job(
    session: AsyncSession = Depends(get_session),
    file: UploadFile | None = File(None),
    youtube_url: str | None = Form(None),
    whisper_language: str | None = Form(None),
    whisper_task: str = Form("transcribe"),
) -> JobCreateResponse:
    """Create a job from a YouTube URL or an uploaded file and start its pipeline.

    When both inputs are supplied the YouTube URL takes precedence.

    Args:
        session: Database session used to persist the new ``Job`` row.
        file: Uploaded media file (optional).
        youtube_url: YouTube video URL (optional).
        whisper_language: Optional language code; validated before use.
        whisper_task: ``"transcribe"`` or ``"translate"``; anything else is
            treated as ``"transcribe"``.

    Returns:
        ``JobCreateResponse`` carrying the id of the new job.

    Raises:
        HTTPException: 400 when neither input is given, the language code is
            invalid or the YouTube URL is not recognised; 413 when the upload
            exceeds ``settings.max_upload_bytes``.
    """
    if not file and not youtube_url:
        raise HTTPException(400, "Must provide either file or youtube_url")

    job_id = str(uuid.uuid4())
    lang = _normalize_whisper_language(whisper_language)
    task_mode = _normalize_whisper_task(whisper_task)

    if youtube_url:
        yt_id = _extract_youtube_id(youtube_url)
        if not yt_id:
            raise HTTPException(400, "Invalid YouTube URL")

        # requests is synchronous; run it off the event loop so other
        # requests and websockets are not stalled for up to the timeout.
        title, thumbnail = await asyncio.to_thread(_fetch_youtube_metadata, yt_id)

        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename="YouTube URL",
            storage_path=f"youtube:{yt_id}",
            title=title,
            thumbnail=thumbnail,
            whisper_language=lang,
            whisper_task=task_mode,
        )
    else:
        # Cheap early rejection when the size is already known; the streaming
        # copy below enforces the limit regardless.
        declared_size = getattr(file, "size", None)
        if declared_size is not None and declared_size > settings.max_upload_bytes:
            raise HTTPException(413, "File exceeds configured maximum size")

        dest = settings.data_dir / "uploads" / job_id / _safe_name(file.filename)
        await _save_upload(file, dest)

        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename=file.filename or dest.name,
            storage_path=str(dest.resolve()),
            whisper_language=lang,
            whisper_task=task_mode,
        )

    # Done for both job kinds so uploads and YouTube jobs are set up alike.
    # NOTE(review): assumes ensure_job_control is idempotent, as its defensive
    # use in cancel_job suggests - confirm.
    ensure_job_control(job_id)
    session.add(job)
    await session.commit()

    pipeline_task = asyncio.create_task(run_job_pipeline(job_id))
    from app.job_control import register_job_task

    register_job_task(job_id, pipeline_task)
    return JobCreateResponse(job_id=job_id)
|
|
|
|
@app.post("/api/jobs/{job_id}/cancel", response_model=CancelJobResponse)
async def cancel_job(job_id: str, session: AsyncSession = Depends(get_session)) -> CancelJobResponse:
    """Request cancellation of a job and mark its row as cancelled.

    Calls ``request_cancel`` for the job, appends an entry to the job's
    activity log (only the newest 120 entries are stored), updates the row
    and broadcasts the change.

    Raises:
        HTTPException: 404 when the job does not exist, 400 when it is already
            ready, failed or cancelled.
    """
    job = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")

    finished_states = {
        JobStatus.ready.value,
        JobStatus.failed.value,
        JobStatus.cancelled.value,
    }
    if job.status in finished_states:
        raise HTTPException(400, "Job is not running")

    ensure_job_control(job_id)
    request_cancel(job_id)

    # A missing, malformed or non-list stored log is replaced by an empty one.
    try:
        entries = json.loads(job.activity_log_json or "[]")
    except json.JSONDecodeError:
        entries = []
    if not isinstance(entries, list):
        entries = []

    cancel_entry = {
        "t": datetime.utcnow().isoformat() + "Z",
        "level": "info",
        "message": "User requested cancel — stopping immediately.",
    }
    entries.append(cancel_entry)

    new_values = {
        "status": JobStatus.cancelled.value,
        "progress_message": "Cancelled",
        "activity_log_json": json.dumps(entries[-120:]),
        "updated_at": datetime.utcnow(),
    }
    await session.execute(update(Job).where(Job.id == job_id).values(**new_values))
    await session.commit()

    await broadcast_job(job_id)
    logger.info("cancel_api job_id=%s", job_id)
    return CancelJobResponse(job_id=job_id, status="cancelled", message="Cancellation registered.")
|
|
|
|
@app.get("/api/jobs", response_model=list[JobStatusResponse])
async def list_jobs(
    session: AsyncSession = Depends(get_session),
    limit: int = 25,
) -> list[JobStatusResponse]:
    """List jobs, newest first.

    Args:
        session: Database session.
        limit: Requested number of jobs; clamped to the range 1-100.
    """
    page_size = max(1, min(limit, 100))
    newest_first = select(Job).order_by(Job.created_at.desc()).limit(page_size)
    jobs = (await session.execute(newest_first)).scalars().all()
    return [job_to_status_response(job) for job in jobs]
|
|
|
|
@app.get("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job(job_id: str, session: AsyncSession = Depends(get_session)) -> JobStatusResponse:
    """Return the status of one job.

    Raises:
        HTTPException: 404 when no job has the given id.
    """
    lookup = select(Job).where(Job.id == job_id)
    job = (await session.execute(lookup)).scalar_one_or_none()
    if job:
        return job_to_status_response(job)
    raise HTTPException(404, "Job not found")
|
|
|
|
@app.patch("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def update_job(job_id: str, body: JobUpdateRequest, session: AsyncSession = Depends(get_session)) -> JobStatusResponse:
    """Partially update a job's title, subject and/or thumbnail.

    Fields left as ``None`` in the request body are not touched;
    ``updated_at`` is always refreshed.

    Raises:
        HTTPException: 404 when no job has the given id.
    """
    job = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")

    for field in ("title", "subject", "thumbnail"):
        new_value = getattr(body, field)
        if new_value is not None:
            setattr(job, field, new_value)

    job.updated_at = datetime.utcnow()
    await session.commit()
    await session.refresh(job)
    return job_to_status_response(job)
|
|
|
|
def _build_timeline_windows(job: Job, segments: list[TranscriptSegment], window_sec: float) -> list[TimelineWindow]:
    """Group transcript segments into timeline windows.

    The segments are flattened to ``(start_sec, end_sec, text)`` tuples and
    handed to ``segments_to_windows`` together with the job duration (0 when
    unset); each resulting mapping is converted into a ``TimelineWindow``.
    """
    total_seconds = float(job.duration_seconds or 0)
    triples = [(seg.start_sec, seg.end_sec, seg.text) for seg in segments]

    windows: list[TimelineWindow] = []
    for entry in segments_to_windows(triples, total_seconds, window_sec):
        windows.append(
            TimelineWindow(
                window_index=entry["window_index"],
                start_sec=entry["start_sec"],
                end_sec=entry["end_sec"],
                text=entry["text"],
            )
        )
    return windows
|
|
|
|
@app.get("/api/jobs/{job_id}/timeline", response_model=TimelineResponse)
async def get_timeline(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    window_sec: float | None = None,
) -> TimelineResponse:
    """Return a job's transcript grouped into fixed-length time windows.

    Args:
        job_id: Id of the job.
        session: Database session.
        window_sec: Window length in seconds. ``None`` or ``0`` selects
            ``settings.timeline_window_seconds``.

    Raises:
        HTTPException: 400 when ``window_sec`` is negative or not finite,
            404 when the job does not exist.
    """
    import math

    # Reject nonsensical client-supplied window sizes before they reach the
    # windowing helper. 0 / None keep meaning "use the default".
    if window_sec and (not math.isfinite(window_sec) or window_sec < 0):
        raise HTTPException(400, "window_sec must be a positive, finite number of seconds")

    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")

    seg_result = await session.execute(
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
    )
    segments = list(seg_result.scalars().all())

    w = float(window_sec or settings.timeline_window_seconds)
    windows = _build_timeline_windows(job, segments, w)
    return TimelineResponse(
        job_id=job_id,
        window_seconds=w,
        total_windows=len(windows),
        duration_seconds=job.duration_seconds,
        windows=windows,
    )
|
|
|
|
@app.get("/api/jobs/{job_id}/segments", response_model=list[TranscriptSegmentOut])
async def get_segments(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    offset: int = 0,
    limit: int = 500,
) -> list[TranscriptSegmentOut]:
    """Return one page of a job's transcript segments ordered by start time.

    Args:
        job_id: Id of the job.
        session: Database session.
        offset: Number of segments to skip; negative values count as 0.
        limit: Page size; clamped to the range 1-2000.

    Raises:
        HTTPException: 404 when the job does not exist.
    """
    existing = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if existing is None:
        raise HTTPException(404, "Job not found")

    skip = max(0, offset)
    page_size = min(2000, max(1, limit))
    page_query = (
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
        .offset(skip)
        .limit(page_size)
    )
    page = await session.execute(page_query)
    return [TranscriptSegmentOut.model_validate(row) for row in page.scalars().all()]
|
|
|
|
# Substrings that mark a question as being about the video as a whole; such
# questions are answered even when no retrieved window is a close match.
_GENERIC_QUERY_MARKERS = (
    "main topic", "summary", "summarize", "key takeaway", "core concept",
    "explain", "overview",
)

# Best-match distance above which a non-generic question is refused.
_OFF_TOPIC_DISTANCE = 0.7


def _clip_snippet(text: str, limit: int) -> str:
    """Collapse whitespace in *text* and cut it to *limit* chars, ending in "…" if cut."""
    flattened = " ".join((text or "").split())
    if len(flattened) <= limit:
        return flattened
    return flattened[: limit - 1].rstrip() + "…"


@app.post("/api/jobs/{job_id}/chat", response_model=ChatResponse)
async def chat_job(job_id: str, body: ChatRequest, session: AsyncSession = Depends(get_session)) -> ChatResponse:
    """Answer a question about a job's transcript using retrieved context.

    The query is embedded, the closest transcript windows are retrieved, and
    the answer is generated from those windows. Non-generic questions whose
    best match is farther than ``_OFF_TOPIC_DISTANCE`` are refused without
    calling the model.

    Args:
        job_id: Id of the job whose transcript is queried.
        body: Request carrying the user's ``query``.
        session: Database session.

    Returns:
        ``ChatResponse`` with the answer and at most four source snippets.

    Raises:
        HTTPException: 404 unknown job; 400 job not ready; 500 retrieval or
            generation failure; 503 when generation raises ``ValueError``
            (reported as a missing API key).
    """
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")
    if job.status != JobStatus.ready.value:
        raise HTTPException(400, "Transcription not ready yet")

    # The service helpers below are synchronous; running them in a worker
    # thread keeps the event loop free for other requests and websockets.
    # NOTE(review): assumes these helpers are safe to call from a worker
    # thread - confirm.
    try:
        q_emb = await asyncio.to_thread(embed_query, body.query)
        top_k = max(3, min(12, settings.chat_retrieval_top_k))
        docs, metas, distances = await asyncio.to_thread(
            query_job, job_id, q_emb, n_results=top_k
        )
    except Exception as exc:
        logger.exception("Retrieval failed")
        raise HTTPException(500, f"Retrieval failed: {exc}") from exc

    # With no hits at all, fall back to a distance that counts as "far".
    min_dist = min(distances) if distances else 1.0

    query_lower = body.query.lower()
    is_generic_query = any(marker in query_lower for marker in _GENERIC_QUERY_MARKERS)

    if min_dist > _OFF_TOPIC_DISTANCE and not is_generic_query:
        return ChatResponse(
            answer="I can only answer questions related to the video content.",
            sources=[],
        )

    allow_fallback = is_generic_query or (min_dist > settings.chat_fallback_distance_threshold)

    context_blocks: list[str] = []
    sources: list[ChatSourceSnippet] = []
    clip_ctx = max(160, settings.chat_context_chars_per_chunk)

    for doc, meta in zip(docs or [], metas or []):
        if not meta:
            continue
        wi = int(meta.get("window_index", 0))
        start = float(meta.get("start_sec", 0))
        end = float(meta.get("end_sec", 0))
        body_txt = (doc or "").strip()
        context_blocks.append(f"[{start:.0f}s–{end:.0f}s] {_clip_snippet(body_txt, clip_ctx)}")
        sources.append(
            ChatSourceSnippet(
                window_index=wi,
                start_sec=start,
                end_sec=end,
                snippet=_clip_snippet(body_txt, 200),
            )
        )

    if not context_blocks:
        context_blocks = ["(No highly similar transcript windows; reply that the video may not cover this.)"]

    try:
        answer = await asyncio.to_thread(
            answer_with_context, body.query, context_blocks, allow_fallback=allow_fallback
        )
    except ValueError as exc:
        raise HTTPException(503, "Gemini API key not configured. Set GEMINI_API_KEY.") from exc
    except Exception as exc:
        logger.exception("Gemini failed")
        raise HTTPException(500, f"Gemini error: {exc}") from exc

    return ChatResponse(answer=answer, sources=sources[:4])
|
|
|
|
async def _wait_for_ws_disconnect(websocket: WebSocket) -> None:
    """Consume incoming frames until the client disconnects, then return."""
    try:
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                return
    except (WebSocketDisconnect, RuntimeError):
        # receive() raises once the connection is already closed.
        return


@app.websocket("/api/jobs/{job_id}/ws")
async def job_events_ws(websocket: WebSocket, job_id: str) -> None:
    """Stream a job's status over a WebSocket.

    Sends one ``job_status`` snapshot, then forwards every message published
    for the job until the client disconnects. An unknown job gets an
    ``error`` message and the socket is closed with code 4404.
    """
    await websocket.accept()

    # Subscribe before reading the snapshot so an update published in between
    # is queued instead of lost.
    q = subscribe(job_id)
    disconnect_task = None
    next_msg = None
    try:
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if not job:
                await websocket.send_json({"type": "error", "detail": "Job not found"})
                await websocket.close(code=4404)
                return
            snapshot = job_to_status_response(job).model_dump(mode="json")
        await websocket.send_json({"type": "job_status", "data": snapshot})

        # Waiting only on the queue would never notice a client that went
        # away while the job is idle or finished, leaking the subscription.
        # Watch the socket for a disconnect alongside the queue.
        disconnect_task = asyncio.ensure_future(_wait_for_ws_disconnect(websocket))
        while True:
            next_msg = asyncio.ensure_future(q.get())
            await asyncio.wait(
                {next_msg, disconnect_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if disconnect_task.done():
                break
            await websocket.send_json(next_msg.result())
    except WebSocketDisconnect:
        pass
    finally:
        for pending in (next_msg, disconnect_task):
            if pending is not None and not pending.done():
                pending.cancel()
        unsubscribe(job_id, q)
|
|