# well / app/main.py
# Uploaded by zarox ("Upload 23 files"), commit 1c167a4 (verified).
import asyncio
import json
import logging
import re
import uuid
import urllib.parse
import requests
from datetime import datetime
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import AsyncSessionLocal, get_session, init_db
from app.models import Job, JobStatus, TranscriptSegment
from app.job_control import ensure_job_control, request_cancel
from app.job_limit import run_job_pipeline
from app.schemas import (
CancelJobResponse,
ChatRequest,
ChatResponse,
ChatSourceSnippet,
JobCreateResponse,
JobStatusResponse,
JobUpdateRequest,
TimelineResponse,
TimelineWindow,
TranscriptSegmentOut,
job_to_status_response,
)
from app.services.chroma_service import query_job
from app.services.embedding_service import embed_query
from app.services.gemini_service import answer_with_context
from app.progress_hub import broadcast_job, subscribe, unsubscribe
from app.timeline_util import segments_to_windows
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title=settings.app_name)

# settings.cors_origins is a comma-separated list; entries are trimmed and blank
# entries (e.g. from a trailing comma) are discarded.
app.add_middleware(
    CORSMiddleware,
    allow_origins=list(filter(None, map(str.strip, settings.cors_origins.split(",")))),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _safe_name(name: str | None) -> str:
base = name or "upload.bin"
base = re.sub(r"[^a-zA-Z0-9._-]", "_", base)[-200:]
return base or "upload.bin"
@app.on_event("startup")
async def on_startup() -> None:
    """Run ``init_db`` once as the application starts up.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` handler passed to ``FastAPI(...)``.
    """
    await init_db()
@app.get("/api/health")
async def health() -> dict:
    """Liveness probe.

    Returns ``{"status": "ok", "time": <current UTC time, ISO-8601 with a 'Z' suffix>}``.
    """
    now_utc = datetime.utcnow()
    return {"status": "ok", "time": f"{now_utc.isoformat()}Z"}
def _normalize_whisper_task(raw: str) -> str:
t = (raw or "transcribe").strip().lower()
return t if t in ("transcribe", "translate") else "transcribe"
def _normalize_whisper_language(raw: str | None) -> str | None:
if not raw:
return None
s = raw.strip().lower()
if not s or s in ("auto", "detect", "none"):
return None
if len(s) > 12 or not re.match(r"^[a-z]{2}(-[a-z0-9]{1,8})?$", s):
raise HTTPException(400, "Invalid whisper language code (e.g. hi, en, zh)")
return s
def _extract_youtube_id(url: str) -> str | None:
parsed = urllib.parse.urlparse(url)
if parsed.hostname in ('youtu.be', 'www.youtu.be'):
return parsed.path[1:]
if parsed.hostname in ('youtube.com', 'www.youtube.com'):
q = urllib.parse.parse_qs(parsed.query)
if 'v' in q:
return q['v'][0]
return None
async def _fetch_youtube_metadata(yt_id: str) -> tuple[str, str | None]:
    """Best-effort lookup of ``(title, thumbnail_url)`` for a YouTube video via oEmbed.

    Returns ``("YouTube Video", None)`` if the request fails, returns a non-200
    status, or the payload cannot be read. Failures are logged, never raised.
    """
    title = "YouTube Video"
    thumbnail = None
    watch_url = f"https://www.youtube.com/watch?v={yt_id}"
    # urlencode so the nested watch URL (and any unexpected characters in
    # yt_id) cannot spill into the outer oEmbed query string.
    oembed_url = "https://www.youtube.com/oembed?" + urllib.parse.urlencode(
        {"url": watch_url, "format": "json"}
    )
    try:
        # requests is synchronous; run it in a worker thread so the event loop
        # keeps serving other requests during the (up to 5 s) lookup.
        resp = await asyncio.to_thread(requests.get, oembed_url, timeout=5)
        if resp.status_code == 200:
            data = resp.json()
            title = data.get("title", title)
            thumbnail = data.get("thumbnail_url")
    except Exception as e:
        logger.warning("Could not fetch YouTube metadata: %s", e)
    return title, thumbnail


async def _store_upload(file: UploadFile, job_id: str):
    """Stream *file* to ``<data_dir>/uploads/<job_id>/<sanitised name>`` and return that path.

    The size reported by ``UploadFile`` (when available) is checked first.
    The running byte count is then enforced while copying in 1 MiB chunks.
    The upload is closed in every case.

    Raises:
        HTTPException: 413 if the upload exceeds ``settings.max_upload_bytes``.
    """
    reported = getattr(file, "size", None)
    if reported is not None and reported > settings.max_upload_bytes:
        raise HTTPException(413, "File exceeds configured maximum size")
    upload_dir = settings.data_dir / "uploads" / job_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    dest = upload_dir / _safe_name(file.filename)
    written = 0
    chunk_size = 1024 * 1024
    try:
        with open(dest, "wb") as out:
            while chunk := await file.read(chunk_size):
                written += len(chunk)
                if written > settings.max_upload_bytes:
                    raise HTTPException(413, "File exceeds configured maximum size")
                out.write(chunk)
    except BaseException:
        # Remove the partial file and the now-empty per-job directory on *any*
        # failure (size limit, disk error, client disconnect, cancellation),
        # then let the original exception propagate.
        try:
            dest.unlink(missing_ok=True)
            upload_dir.rmdir()
        except OSError:
            pass
        raise
    finally:
        await file.close()
    return dest


@app.post("/api/jobs", response_model=JobCreateResponse)
async def create_job(
    session: AsyncSession = Depends(get_session),
    file: UploadFile | None = File(None),
    youtube_url: str | None = Form(None),
    whisper_language: str | None = Form(None),
    whisper_task: str = Form("transcribe"),
) -> JobCreateResponse:
    """Create a transcription job from a YouTube URL or an uploaded file and start its pipeline.

    If both are supplied, ``youtube_url`` takes precedence. The job row is
    committed as ``pending`` before ``run_job_pipeline`` is scheduled as a
    background task.

    Raises:
        HTTPException: 400 if neither source is given, the URL is not a
            recognised YouTube link, or the language code is invalid; 413 if
            the uploaded file is too large.
    """
    if not file and not youtube_url:
        raise HTTPException(400, "Must provide either file or youtube_url")
    job_id = str(uuid.uuid4())
    lang = _normalize_whisper_language(whisper_language)
    whisper_mode = _normalize_whisper_task(whisper_task)

    if youtube_url:
        yt_id = _extract_youtube_id(youtube_url)
        if not yt_id:
            raise HTTPException(400, "Invalid YouTube URL")
        title, thumbnail = await _fetch_youtube_metadata(yt_id)
        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename="YouTube URL",
            storage_path=f"youtube:{yt_id}",
            title=title,
            thumbnail=thumbnail,
            whisper_language=lang,
            whisper_task=whisper_mode,
        )
    else:
        # The same job_id names both the DB row and the upload directory.
        dest = await _store_upload(file, job_id)
        job = Job(
            id=job_id,
            status=JobStatus.pending.value,
            original_filename=file.filename or dest.name,
            storage_path=str(dest.resolve()),
            whisper_language=lang,
            whisper_task=whisper_mode,
        )

    # Set up job control for both sources so uploaded and YouTube jobs start
    # out in the same state.
    ensure_job_control(job_id)
    session.add(job)
    await session.commit()

    from app.job_control import register_job_task

    pipeline_task = asyncio.create_task(run_job_pipeline(job_id))
    register_job_task(job_id, pipeline_task)
    return JobCreateResponse(job_id=job_id)
@app.post("/api/jobs/{job_id}/cancel", response_model=CancelJobResponse)
async def cancel_job(job_id: str, session: AsyncSession = Depends(get_session)) -> CancelJobResponse:
    """Cancel a job that has not yet finished.

    Signals cancellation through ``request_cancel``, then marks the row as
    cancelled with progress message "Cancelled". It also appends an entry to
    the JSON activity log (keeping the last 120 entries), commits, and
    broadcasts the job via ``broadcast_job``.

    Raises:
        HTTPException: 404 if the job does not exist; 400 if it is already
            ready, failed or cancelled.
    """
    job = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")

    finished_states = (
        JobStatus.ready.value,
        JobStatus.failed.value,
        JobStatus.cancelled.value,
    )
    if job.status in finished_states:
        raise HTTPException(400, "Job is not running")

    ensure_job_control(job_id)
    request_cancel(job_id)

    # A missing, malformed, or non-list activity log is replaced by an empty list.
    try:
        activity = json.loads(job.activity_log_json or "[]")
    except json.JSONDecodeError:
        activity = []
    if not isinstance(activity, list):
        activity = []
    activity.append(
        {
            "t": datetime.utcnow().isoformat() + "Z",
            "level": "info",
            "message": "User requested cancel — stopping immediately.",
        }
    )

    # NOTE(review): the terminal-status check above and this UPDATE are not
    # atomic. If the pipeline marks the job ready/failed in between, the UPDATE
    # overwrites that with "cancelled". A conditional UPDATE would close the
    # gap, but whether that is safe depends on how run_job_pipeline reacts to
    # request_cancel (not visible here) — confirm before changing.
    stmt = (
        update(Job)
        .where(Job.id == job_id)
        .values(
            status=JobStatus.cancelled.value,
            progress_message="Cancelled",
            activity_log_json=json.dumps(activity[-120:]),
            updated_at=datetime.utcnow(),
        )
    )
    await session.execute(stmt)
    await session.commit()
    await broadcast_job(job_id)

    logger.info("cancel_api job_id=%s", job_id)
    return CancelJobResponse(job_id=job_id, status="cancelled", message="Cancellation registered.")
@app.get("/api/jobs", response_model=list[JobStatusResponse])
async def list_jobs(
    session: AsyncSession = Depends(get_session),
    limit: int = 25,
) -> list[JobStatusResponse]:
    """Return the most recently created jobs, newest first.

    ``limit`` is clamped to the range 1..100.
    """
    capped = max(1, min(limit, 100))
    stmt = select(Job).order_by(Job.created_at.desc()).limit(capped)
    jobs = (await session.execute(stmt)).scalars().all()
    return list(map(job_to_status_response, jobs))
@app.get("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job(job_id: str, session: AsyncSession = Depends(get_session)) -> JobStatusResponse:
    """Return the current status of a single job.

    Raises:
        HTTPException: 404 if no job has this id.
    """
    job = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if job is None:
        raise HTTPException(404, "Job not found")
    return job_to_status_response(job)
@app.patch("/api/jobs/{job_id}", response_model=JobStatusResponse)
async def update_job(job_id: str, body: JobUpdateRequest, session: AsyncSession = Depends(get_session)) -> JobStatusResponse:
    """Partially update a job's user-editable metadata (title, subject, thumbnail).

    Only fields that are present (not ``None``) in *body* are overwritten.
    ``updated_at`` is always bumped.

    Raises:
        HTTPException: 404 if no job has this id.
    """
    job = (await session.execute(select(Job).where(Job.id == job_id))).scalar_one_or_none()
    if job is None:
        raise HTTPException(404, "Job not found")
    for field_name in ("title", "subject", "thumbnail"):
        new_value = getattr(body, field_name)
        if new_value is not None:
            setattr(job, field_name, new_value)
    job.updated_at = datetime.utcnow()
    await session.commit()
    await session.refresh(job)
    return job_to_status_response(job)
def _build_timeline_windows(job: Job, segments: list[TranscriptSegment], window_sec: float) -> list[TimelineWindow]:
    """Group transcript *segments* into windows of *window_sec* seconds.

    The bucketing itself is delegated to ``segments_to_windows``. It receives
    ``(start_sec, end_sec, text)`` tuples and the job duration (0 when
    unknown). Each resulting dict is converted into a ``TimelineWindow``.
    """
    duration = float(job.duration_seconds or 0)
    triples = [(seg.start_sec, seg.end_sec, seg.text) for seg in segments]
    keys = ("window_index", "start_sec", "end_sec", "text")
    return [
        TimelineWindow(**{key: bucket[key] for key in keys})
        for bucket in segments_to_windows(triples, duration, window_sec)
    ]
@app.get("/api/jobs/{job_id}/timeline", response_model=TimelineResponse)
async def get_timeline(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    window_sec: float | None = None,
) -> TimelineResponse:
    """Return the job's transcript grouped into fixed-length time windows.

    ``window_sec`` overrides ``settings.timeline_window_seconds``. Omitting it
    (or passing 0) uses the configured default.

    Raises:
        HTTPException: 400 if the effective window length is negative, NaN or
            infinite; 404 if the job does not exist.
    """
    w = float(window_sec or settings.timeline_window_seconds)
    # A non-positive or non-finite window presumably cannot yield a sensible
    # (or finite) set of windows from segments_to_windows (not visible here).
    # The chained comparison is also False for NaN, so one check covers all
    # these cases.
    if not 0 < w < float("inf"):
        raise HTTPException(400, "window_sec must be a positive, finite number of seconds")

    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")
    seg_result = await session.execute(
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
    )
    segments = list(seg_result.scalars().all())
    windows = _build_timeline_windows(job, segments, w)
    return TimelineResponse(
        job_id=job_id,
        window_seconds=w,
        total_windows=len(windows),
        duration_seconds=job.duration_seconds,
        windows=windows,
    )
@app.get("/api/jobs/{job_id}/segments", response_model=list[TranscriptSegmentOut])
async def get_segments(
    job_id: str,
    session: AsyncSession = Depends(get_session),
    offset: int = 0,
    limit: int = 500,
) -> list[TranscriptSegmentOut]:
    """Return one page of a job's transcript segments, ordered by start time.

    ``offset`` is floored at 0 and ``limit`` is clamped to 1..2000.

    Raises:
        HTTPException: 404 if the job does not exist.
    """
    job_lookup = await session.execute(select(Job).where(Job.id == job_id))
    if job_lookup.scalar_one_or_none() is None:
        raise HTTPException(404, "Job not found")

    page_start = offset if offset > 0 else 0
    page_size = max(1, min(limit, 2000))
    stmt = (
        select(TranscriptSegment)
        .where(TranscriptSegment.job_id == job_id)
        .order_by(TranscriptSegment.start_sec)
        .limit(page_size)
        .offset(page_start)
    )
    segments = (await session.execute(stmt)).scalars().all()
    return list(map(TranscriptSegmentOut.model_validate, segments))
# Phrases marking a broad, whole-video question. Such queries bypass the
# relevance cut-off in chat_job and always allow the model's fallback mode.
_GENERIC_QUERY_HINTS = (
    "main topic", "summary", "summarize", "key takeaway", "core concept",
    "explain", "overview",
)


def _chat_snippet(s: str | None, n: int) -> str:
    """Collapse whitespace in *s* and clip it to at most *n* characters, ending in "…" when clipped."""
    t = " ".join((s or "").split())
    if len(t) <= n:
        return t
    return t[: n - 1].rstrip() + "…"


@app.post("/api/jobs/{job_id}/chat", response_model=ChatResponse)
async def chat_job(job_id: str, body: ChatRequest, session: AsyncSession = Depends(get_session)) -> ChatResponse:
    """Answer a question about a ready job using retrieval over its transcript windows.

    The query is embedded and the nearest transcript windows are fetched for
    this job. If the closest distance exceeds 0.7 and the query does not look
    like a general/summary request, a fixed refusal is returned. Otherwise the
    retrieved windows are passed to ``answer_with_context``. At most four
    source snippets are returned.

    Raises:
        HTTPException: 404 unknown job; 400 job not ready; 500 on retrieval or
            generation failure; 503 if ``answer_with_context`` raises
            ``ValueError`` (reported as a missing GEMINI_API_KEY).
    """
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(404, "Job not found")
    if job.status != JobStatus.ready.value:
        raise HTTPException(400, "Transcription not ready yet")

    try:
        # embed_query / query_job are called synchronously (not awaited), so run
        # them in a worker thread instead of stalling the event loop.
        q_emb = await asyncio.to_thread(embed_query, body.query)
        top_k = max(3, min(12, settings.chat_retrieval_top_k))
        docs, metas, distances = await asyncio.to_thread(query_job, job_id, q_emb, n_results=top_k)
    except Exception as exc:
        logger.exception("Retrieval failed")
        raise HTTPException(500, f"Retrieval failed: {exc}") from exc

    # With no hits, treat the query as maximally distant.
    min_dist = min(distances) if distances else 1.0
    query_lower = body.query.lower()
    is_generic_query = any(hint in query_lower for hint in _GENERIC_QUERY_HINTS)
    if min_dist > 0.7 and not is_generic_query:
        return ChatResponse(
            answer="I can only answer questions related to the video content.",
            sources=[],
        )
    allow_fallback = is_generic_query or (min_dist > settings.chat_fallback_distance_threshold)

    context_blocks: list[str] = []
    sources: list[ChatSourceSnippet] = []
    clip_ctx = max(160, settings.chat_context_chars_per_chunk)
    for doc, meta in zip(docs, metas):
        if not meta:
            continue
        wi = int(meta.get("window_index", 0))
        start = float(meta.get("start_sec", 0))
        end = float(meta.get("end_sec", 0))
        body_txt = (doc or "").strip()
        context_blocks.append(f"[{start:.0f}s–{end:.0f}s] {_chat_snippet(body_txt, clip_ctx)}")
        sources.append(
            ChatSourceSnippet(
                window_index=wi,
                start_sec=start,
                end_sec=end,
                snippet=_chat_snippet(body_txt, 200),
            )
        )
    if not context_blocks:
        context_blocks = ["(No highly similar transcript windows; reply that the video may not cover this.)"]

    try:
        # answer_with_context is a blocking network call; keep it off the event loop.
        answer = await asyncio.to_thread(
            answer_with_context, body.query, context_blocks, allow_fallback=allow_fallback
        )
    except ValueError as exc:
        raise HTTPException(503, "Gemini API key not configured. Set GEMINI_API_KEY.") from exc
    except Exception as exc:
        logger.exception("Gemini failed")
        raise HTTPException(500, f"Gemini error: {exc}") from exc
    return ChatResponse(answer=answer, sources=sources[:4])
async def _forward_job_events(websocket: WebSocket, q) -> None:
    """Relay messages from subscription queue *q* to *websocket* until the client disconnects.

    A companion task reads from the socket only to notice the disconnect frame.
    Without it, a client that leaves while no events are flowing (e.g. after
    the job has finished) is never detected. The handler and its subscription
    would then stay alive indefinitely. Other client messages are ignored.
    """

    async def _wait_for_disconnect() -> None:
        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    return
        except (WebSocketDisconnect, RuntimeError):
            # The socket is already closed or unusable; treat it as disconnected.
            return

    disconnect_waiter = asyncio.ensure_future(_wait_for_disconnect())
    next_msg = None
    try:
        while True:
            next_msg = asyncio.ensure_future(q.get())
            done, _ = await asyncio.wait(
                {next_msg, disconnect_waiter}, return_when=asyncio.FIRST_COMPLETED
            )
            if disconnect_waiter in done:
                return
            await websocket.send_json(next_msg.result())
    finally:
        disconnect_waiter.cancel()
        if next_msg is not None and not next_msg.done():
            next_msg.cancel()


@app.websocket("/api/jobs/{job_id}/ws")
async def job_events_ws(websocket: WebSocket, job_id: str) -> None:
    """Push live status for one job over a WebSocket.

    First sends ``{"type": "job_status", "data": <snapshot>}``, then forwards
    every message published for *job_id* until the client disconnects. For an
    unknown job it sends ``{"type": "error", ...}`` and closes with code 4404.
    """
    await websocket.accept()
    # Subscribe before reading the snapshot so a broadcast that lands between
    # the DB read and the subscription is not lost. At worst the client sees
    # one redundant update.
    q = subscribe(job_id)
    try:
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if not job:
                await websocket.send_json({"type": "error", "detail": "Job not found"})
                await websocket.close(code=4404)
                return
            snapshot = job_to_status_response(job).model_dump(mode="json")
        await websocket.send_json({"type": "job_status", "data": snapshot})
        await _forward_job_events(websocket, q)
    except WebSocketDisconnect:
        pass
    finally:
        unsubscribe(job_id, q)