Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from contextlib import asynccontextmanager | |
| import os | |
| import re | |
| import shutil | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit | |
| from uuid import uuid4 | |
| from fastapi.concurrency import run_in_threadpool | |
| from fastapi import APIRouter, BackgroundTasks, Depends, FastAPI, File, Form, HTTPException, Request, UploadFile, status | |
| from fastapi.responses import FileResponse, RedirectResponse | |
| from pydantic import BaseModel, Field | |
| from sqlalchemy import text | |
| from sqlalchemy.orm import Session | |
| from auth.oauth import ( | |
| HFOAuthError, | |
| build_hf_authorize_url, | |
| exchange_code_for_hf_user, | |
| generate_oauth_state, | |
| is_valid_oauth_state, | |
| ) | |
| from auth.session import ( | |
| AUTH_MODE_DEV, | |
| AUTH_MODE_HF, | |
| AuthBridgeTokenError, | |
| CurrentUser, | |
| clear_session_user, | |
| configure_session_middleware, | |
| decode_auth_bridge_token, | |
| generate_auth_bridge_token, | |
| get_auth_mode, | |
| get_session_user, | |
| require_current_user, | |
| set_session_user, | |
| ) | |
| from data import crud | |
| from data.db import get_db, init_db, SessionLocal | |
| from src.ingestion.extractors import URLValidationError, validate_ingestion_url | |
| from src.ingestion.service import ingest_source, query_notebook_chunks | |
| from src.artifacts.report_generator import ReportGenerator | |
| from src.artifacts.quiz_generator import QuizGenerator | |
| from src.artifacts.podcast_generator import PodcastGenerator | |
| from utils.llm_client import LLMConfigError, generate_chat_completion | |
| from utils.prompt_templates import build_rag_system_prompt, build_rag_user_prompt | |
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run one-time application startup work.

    Calls ``init_db()`` before the app starts serving requests; nothing runs
    after ``yield`` (no shutdown work).

    Decorated with ``asynccontextmanager`` (imported by this module) because
    FastAPI/Starlette expect ``lifespan`` to produce an async context manager;
    passing a bare async generator function is deprecated.
    """
    init_db()
    yield
# The ASGI application; ``lifespan`` runs ``init_db()`` at startup.
app = FastAPI(title="NotebookLM Clone API", version="0.1.0", lifespan=lifespan)
# Must be installed before any handler touches ``request.session`` (the OAuth
# flow stores its ``oauth_state`` there). Presumably wraps Starlette's
# SessionMiddleware — TODO confirm in auth.session.
configure_session_middleware(app)
# Feature routers, one per URL prefix. NOTE(review): the handler decorators and
# ``app.include_router(...)`` calls are not visible in this chunk; confirm they
# are attached elsewhere in the module.
auth_router = APIRouter(prefix="/auth", tags=["auth"])
notebooks_router = APIRouter(prefix="/notebooks", tags=["notebooks"])
sources_router = APIRouter(prefix="/sources", tags=["sources"])
threads_router = APIRouter(prefix="/threads", tags=["threads"])
# Request body for creating a notebook. Note: ``min_length=1`` alone does not
# reject a whitespace-only title.
class NotebookCreateRequest(BaseModel):
    title: str = Field(min_length=1, max_length=255)
# Request body for renaming a notebook; same title constraints as creation.
class NotebookUpdateRequest(BaseModel):
    title: str = Field(min_length=1, max_length=255)
# Notebook as returned by the notebook create/list/rename endpoints.
class NotebookResponse(BaseModel):
    id: int
    owner_user_id: int  # id of the user that owns the notebook
    title: str
# JSON request body for registering a source on a notebook.
class SourceCreateRequest(BaseModel):
    # Source kind; the create handler normalizes it with ``strip().lower()``
    # and ingests immediately only when it is "url".
    type: str = Field(min_length=1, max_length=50)
    title: str | None = Field(default=None, max_length=255)
    original_name: str | None = Field(default=None, max_length=1024)
    # Required (and validated) when ``type`` is "url".
    url: str | None = Field(default=None, max_length=2048)
    # NOTE(review): client-supplied filesystem path stored verbatim; confirm
    # nothing downstream reads files from it.
    storage_path: str | None = Field(default=None, max_length=1024)
    # Initial status stored with the source before any ingestion.
    status: str = Field(default="pending", max_length=50)
# Source as returned by the source create/upload/list endpoints.
class SourceResponse(BaseModel):
    id: int
    notebook_id: int
    type: str  # e.g. "url" or "file"
    title: str | None
    original_name: str | None  # uploaded filename, for file sources
    url: str | None
    storage_path: str | None  # server-side path of an uploaded file
    status: str  # values set in this module: "processing", "ready", "failed"
    ingested_at: datetime | None  # set when ingestion produced at least one chunk
# Request body for opening a chat thread; the title is optional.
class ThreadCreateRequest(BaseModel):
    title: str | None = Field(default=None, max_length=255)
# Chat thread as returned by the thread create/list endpoints.
class ThreadResponse(BaseModel):
    id: int
    notebook_id: int
    title: str | None
    created_at: datetime
# One retrieved chunk cited in support of an assistant answer.
class CitationResponse(BaseModel):
    source_title: str | None = None
    source_id: int
    # Built as "source_<source_id>_chunk_<chunk_index>" by the chat handler.
    chunk_ref: str | None = None
    # Leading excerpt (first 300 characters) of the cited chunk.
    quote: str | None = None
    # Retrieval score as reported by the vector store (direction not defined here).
    score: float | None = None
# A single chat message together with its citations.
class MessageResponse(BaseModel):
    id: int
    thread_id: int
    role: str  # "user" or "assistant" as written by the chat handler
    content: str
    created_at: datetime
    citations: list[CitationResponse] = Field(default_factory=list)
# Request body for asking a question in a thread.
class ChatRequest(BaseModel):
    question: str = Field(min_length=1)
    # Number of chunks requested from retrieval (1-12).
    top_k: int = Field(default=5, ge=1, le=12)
# Result of one chat turn: the stored question, the stored answer, and the
# answer's citations (also repeated inside ``assistant_message``).
class ChatResponse(BaseModel):
    user_message: MessageResponse
    assistant_message: MessageResponse
    citations: list[CitationResponse]
# Request body for generating a quiz artifact.
class QuizGenerateRequest(BaseModel):
    num_questions: int = Field(default=5, ge=1, le=20)
    # Free-form string; any validation happens in the handler — TODO confirm.
    difficulty: str = Field(default="medium")
    topic_focus: str | None = None
    title: str | None = None
# Request body for generating a report artifact.
class ReportGenerateRequest(BaseModel):
    # The report handler accepts "short", "medium" or "long" (case-insensitive).
    detail_level: str = Field(default="medium")
    topic_focus: str | None = None
    title: str | None = None
# Request body for generating a podcast artifact.
class PodcastGenerateRequest(BaseModel):
    # Presumably forwarded as PodcastGenerator's ``duration_target`` — confirm.
    duration: str = Field(default="5min")
    topic_focus: str | None = None
    title: str | None = None
# Request body for dev-mode login; omitted fields fall back to the
# AUTH_DEV_EMAIL / AUTH_DEV_DISPLAY_NAME environment variables.
class DevLoginRequest(BaseModel):
    email: str | None = None
    display_name: str | None = None
# Public view of the user stored in the current session.
class SessionUserResponse(BaseModel):
    id: int
    email: str
    display_name: str | None = None
    avatar_url: str | None = None
# Auth state reported by the status/login/logout endpoints.
class AuthStatusResponse(BaseModel):
    mode: str  # auth mode from get_auth_mode()
    authenticated: bool
    user: SessionUserResponse | None = None
    # "/auth/login" in HF OAuth mode, otherwise None.
    login_url: str | None = None
# Request body carrying the ``auth_bridge`` token appended to the OAuth
# success redirect.
class AuthBridgeExchangeRequest(BaseModel):
    token: str = Field(min_length=1)
# Confirmation returned after a notebook is deleted.
class NotebookDeleteResponse(BaseModel):
    status: str  # "deleted"
    notebook_id: int
# Generated artifact (e.g. a report or podcast) as returned by artifact endpoints.
class ArtifactResponse(BaseModel):
    id: int
    notebook_id: int
    type: str  # e.g. "report"; other values presumably "quiz"/"podcast" — confirm
    title: str | None
    status: str  # values set in this module: "processing", "ready", "failed"
    content: str | None  # generated text, e.g. a podcast transcript in markdown
    file_path: str | None  # on-disk output, e.g. a podcast's audio file
    metadata: dict | None  # mapped from the ORM's ``artifact_metadata`` attribute
    error_message: str | None  # failure reason when generation failed
    created_at: datetime
    generated_at: datetime | None
# Chat history sent to the LLM: at most this many trailing thread messages...
MAX_HISTORY_MESSAGES = 8
# ...each cut to this many characters (with "..." appended when cut).
MAX_HISTORY_CHARS_PER_MESSAGE = 1000
# Upper bound for sanitized upload filenames; 255 matches the usual
# per-component filename limit, presumably the reason for this value.
MAX_UPLOAD_FILENAME_LENGTH = 255
# Every run of characters outside [A-Za-z0-9._-] is replaced by "_".
SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+")
# Uploads live under <STORAGE_BASE_DIR>/uploads (default "data/uploads").
# Evaluated once at import time; a relative base resolves against the
# process's working directory.
UPLOADS_ROOT = Path(os.getenv("STORAGE_BASE_DIR", "data")) / "uploads"
def _build_conversation_history(
    thread_messages: list, max_messages: int = MAX_HISTORY_MESSAGES
) -> list[str]:
    """Render the tail of a thread as ``"role: content"`` lines for the LLM prompt.

    Args:
        thread_messages: Message objects exposing ``.role`` and ``.content``,
            presumably oldest first (the tail is treated as most recent).
        max_messages: How many trailing messages to consider. A value of
            zero or less yields an empty history.

    Returns:
        One line per message with non-empty content. The role is lower-cased.
        Content longer than ``MAX_HISTORY_CHARS_PER_MESSAGE`` is cut to that
        length and suffixed with ``"..."``.
    """
    # ``seq[-0:]`` is the *whole* sequence, so non-positive limits must be
    # handled explicitly instead of through slicing.
    if max_messages <= 0:
        return []
    rows: list[str] = []
    # Slicing past the start is safe, so no length check is required.
    for msg in thread_messages[-max_messages:]:
        role = str(msg.role).strip().lower()
        content = str(msg.content or "").strip()
        if not content:
            continue
        if len(content) > MAX_HISTORY_CHARS_PER_MESSAGE:
            content = content[:MAX_HISTORY_CHARS_PER_MESSAGE] + "..."
        rows.append(f"{role}: {content}")
    return rows
def _auth_callback_url(request: Request) -> str:
    """Return the OAuth redirect URI to register with Hugging Face.

    A non-blank ``HF_OAUTH_REDIRECT_URI`` (whitespace-trimmed) takes
    precedence; otherwise the URI is derived from the request's base URL.
    """
    override = os.getenv("HF_OAUTH_REDIRECT_URI", "").strip()
    if not override:
        override = "{}/auth/callback".format(str(request.base_url).rstrip("/"))
    return override
def _auth_status_payload(request: Request) -> AuthStatusResponse:
    """Build the auth status for ``request``: mode, session user, login URL.

    ``login_url`` is only advertised in HF OAuth mode; ``user`` is None when
    the session holds no user.
    """
    mode = get_auth_mode()
    session_user = get_session_user(request)
    if not session_user:
        user_payload = None
    else:
        user_payload = SessionUserResponse(
            id=session_user.id,
            email=session_user.email,
            display_name=session_user.display_name,
            avatar_url=session_user.avatar_url,
        )
    login_url = None
    if mode == AUTH_MODE_HF:
        login_url = "/auth/login"
    return AuthStatusResponse(
        mode=mode,
        authenticated=session_user is not None,
        user=user_payload,
        login_url=login_url,
    )
def _append_query_param(url: str, key: str, value: str) -> str:
    """Return ``url`` with query parameter ``key`` set to ``value``.

    Existing parameters are preserved, including repeated keys (the previous
    ``dict``-based version silently dropped all but the last occurrence). If
    ``key`` is already present, its first occurrence is replaced in place and
    any later duplicates of it are dropped; otherwise it is appended at the end.
    Blank values and the URL fragment are kept.
    """
    split = urlsplit(url)
    pairs = parse_qsl(split.query, keep_blank_values=True)
    updated: list[tuple[str, str]] = []
    replaced = False
    for existing_key, existing_value in pairs:
        if existing_key != key:
            updated.append((existing_key, existing_value))
        elif not replaced:
            updated.append((key, value))
            replaced = True
    if not replaced:
        updated.append((key, value))
    return urlunsplit(split._replace(query=urlencode(updated)))
def _sanitize_upload_filename(filename: str | None) -> str:
    """Turn a client-supplied filename into a safe, bounded basename.

    Directory components and NUL bytes are discarded, disallowed character
    runs become ``_``, and leading/trailing ``.``/``_``/``-`` are stripped.
    An empty result is replaced by a random ``upload_<hex>.bin`` name. Names
    longer than ``MAX_UPLOAD_FILENAME_LENGTH`` keep (at most 20 chars of)
    their extension and have the stem shortened to fit.
    """
    # Only the final path component is trusted; any client directories are ignored.
    base_name = Path(str(filename or "")).name
    base_name = base_name.replace("\x00", "").strip()
    cleaned = SAFE_FILENAME_RE.sub("_", base_name).strip("._-")
    if not cleaned:
        return f"upload_{uuid4().hex[:10]}.bin"
    if len(cleaned) <= MAX_UPLOAD_FILENAME_LENGTH:
        return cleaned
    as_path = Path(cleaned)
    extension = as_path.suffix[:20]
    stem_budget = max(1, MAX_UPLOAD_FILENAME_LENGTH - len(extension))
    return as_path.stem[:stem_budget] + extension
def _resolve_notebook_upload_path(notebook_id: int, filename: str | None) -> Path:
    """Return a fresh destination path for a file uploaded to a notebook.

    The file goes into ``UPLOADS_ROOT/notebook_<id>/`` (created on demand)
    under ``_sanitize_upload_filename(filename)``. If that name is taken, an
    ``_<8 hex digits>`` suffix is inserted before the extension, trimming the
    stem so the name still fits ``MAX_UPLOAD_FILENAME_LENGTH`` (a 255-char
    sanitized name plus the suffix would otherwise exceed the usual
    filesystem limit and make the write fail).

    Raises:
        HTTPException: 400 if the resolved path is not directly inside the
            notebook's upload directory.
    """
    upload_dir = UPLOADS_ROOT / f"notebook_{notebook_id}"
    upload_dir.mkdir(parents=True, exist_ok=True)
    upload_dir_resolved = upload_dir.resolve()
    safe_name = _sanitize_upload_filename(filename)
    destination = (upload_dir_resolved / safe_name).resolve()
    # Defense in depth: the sanitizer already removes separators, and resolving
    # also rejects a name that is an existing symlink pointing elsewhere.
    if destination.parent != upload_dir_resolved:
        raise HTTPException(status_code=400, detail="Invalid upload filename.")
    if destination.exists():
        unique = f"_{uuid4().hex[:8]}"
        # Cap the extension like the sanitizer does for over-long names.
        suffix = destination.suffix[:20]
        stem_limit = max(1, MAX_UPLOAD_FILENAME_LENGTH - len(unique) - len(suffix))
        destination = (
            upload_dir_resolved / f"{destination.stem[:stem_limit]}{unique}{suffix}"
        ).resolve()
    return destination
def _remove_tree_within_root(
    root: Path,
    target: Path,
    *,
    attempts: int = 3,
    retry_delay: float = 0.5,
) -> None:
    """Recursively delete ``target``, refusing anything not strictly inside ``root``.

    A missing ``target`` is a no-op. Entries that fail to delete are made
    owner-writable and retried once (read-only files on Windows). The whole
    removal is retried up to ``attempts`` times with a garbage-collection pass
    and a pause in between, giving lingering file handles (e.g. ChromaDB on
    Windows) a chance to close. A tree that is still present after an attempt
    counts as a failure, so partial deletions are never reported as success.

    Args:
        root: Directory bounding what may be deleted.
        target: Directory to delete; must resolve to a strict descendant of ``root``.
        attempts: Total removal attempts (at least one is always made).
        retry_delay: Seconds to wait between attempts.

    Raises:
        RuntimeError: If ``target`` resolves to ``root`` itself or outside it.
        OSError: If the tree cannot be fully removed after all attempts.
    """
    import gc
    import stat
    import sys
    import time

    if not target.exists():
        return
    root_resolved = root.resolve()
    target_resolved = target.resolve()
    if target_resolved == root_resolved or root_resolved not in target_resolved.parents:
        raise RuntimeError(f"Refusing to delete path outside root: {target_resolved}")

    def _force_remove(_func: object, path: str, _exc: object) -> None:
        # Make the entry owner-writable and retry once. Directories need
        # rmdir (os.remove cannot delete them). Failures are tolerated here
        # because leftovers are detected after rmtree returns.
        try:
            os.chmod(path, stat.S_IRWXU)
            if os.path.isdir(path) and not os.path.islink(path):
                os.rmdir(path)
            else:
                os.remove(path)
        except OSError:
            pass

    # ``onerror`` is deprecated since Python 3.12 in favour of ``onexc``; both
    # invoke the handler as (func, path, <error info>), so one handler fits both.
    handler_kwarg = "onexc" if sys.version_info >= (3, 12) else "onerror"
    total_attempts = max(1, attempts)
    for attempt in range(1, total_attempts + 1):
        try:
            shutil.rmtree(target_resolved, **{handler_kwarg: _force_remove})
            # The handler swallows per-entry errors, so verify the outcome.
            if target_resolved.exists():
                raise OSError(f"Could not fully remove {target_resolved}")
            return
        except OSError:
            if attempt == total_attempts:
                raise
            gc.collect()
            time.sleep(retry_delay)
def _release_notebook_chroma(chroma_dir: Path, owner_user_id: int, notebook_id: int) -> None:
    """Best-effort: drop the notebook's Chroma collection before its files are deleted.

    Failures are logged (previously they were silently ignored) and never
    raised; the directory removal that follows is what actually frees disk.
    """
    import gc
    import logging

    logger = logging.getLogger(__name__)
    collection_name = f"user_{owner_user_id}_notebook_{notebook_id}"
    try:
        from src.ingestion.vectorstore import ChromaAdapter

        store = ChromaAdapter(persist_directory=str(chroma_dir))
    except Exception:
        logger.warning("Could not open Chroma store at %s before deletion", chroma_dir, exc_info=True)
        return
    try:
        # NOTE(review): relies on ChromaAdapter's private ``_client`` attribute.
        store._client.delete_collection(collection_name)
    except Exception:
        logger.warning("Could not delete Chroma collection %s", collection_name, exc_info=True)
    finally:
        # Drop our reference and collect so the client can let go of its files
        # (file locks are a problem on Windows).
        del store
        gc.collect()


def _cleanup_notebook_storage(owner_user_id: int, notebook_id: int) -> None:
    """Delete all on-disk data belonging to a notebook.

    Removes ``<STORAGE_BASE_DIR>/users/<owner>/notebooks/<id>`` (releasing its
    ``chroma`` vector store first) and the notebook's directory under
    ``UPLOADS_ROOT``.

    Raises:
        RuntimeError, OSError: Propagated from ``_remove_tree_within_root``.
    """
    storage_base = Path(os.getenv("STORAGE_BASE_DIR", "data"))
    notebook_root = storage_base / "users" / str(owner_user_id) / "notebooks"
    notebook_path = notebook_root / str(notebook_id)
    chroma_dir = notebook_path / "chroma"
    if chroma_dir.exists():
        _release_notebook_chroma(chroma_dir, owner_user_id, notebook_id)
    _remove_tree_within_root(notebook_root, notebook_path)
    upload_path = UPLOADS_ROOT / f"notebook_{notebook_id}"
    _remove_tree_within_root(UPLOADS_ROOT, upload_path)
def health_check() -> dict[str, str]:
    """Liveness probe: report that the API process is up."""
    return dict(status="ok")
def root() -> dict[str, str]:
    """Landing payload that points clients at the health check and API docs."""
    links = {"health": "/health", "docs": "/docs"}
    return {"message": "NotebookLM Clone API", **links}
def health_db(db: Session = Depends(get_db)) -> dict[str, str]:
    """Readiness probe: run a trivial query to prove the database is reachable.

    Database errors are not caught here; they propagate to the caller.
    """
    probe = text("SELECT 1")
    db.execute(probe)
    return {"status": "ok", "database": "connected"}
def auth_status(request: Request) -> AuthStatusResponse:
    """Report the auth mode and the current session user, if any."""
    status_payload = _auth_status_payload(request)
    return status_payload
def auth_dev_login(
    payload: DevLoginRequest,
    request: Request,
    db: Session = Depends(get_db),
) -> AuthStatusResponse:
    """Sign in without external auth (dev mode only).

    Missing fields fall back to ``AUTH_DEV_EMAIL`` (default
    ``dev@example.com``) and ``AUTH_DEV_DISPLAY_NAME`` (default ``Dev User``).
    The email is trimmed and lower-cased; the matching user is fetched or
    created and stored in the session.

    Raises:
        HTTPException: 400 when not in dev mode or when the email is blank.
    """
    if get_auth_mode() != AUTH_MODE_DEV:
        raise HTTPException(status_code=400, detail="Dev login is disabled for this deployment.")

    fallback_email = os.getenv("AUTH_DEV_EMAIL", "dev@example.com")
    fallback_name = os.getenv("AUTH_DEV_DISPLAY_NAME", "Dev User")

    normalized_email = (payload.email or fallback_email).strip().lower()
    if not normalized_email:
        raise HTTPException(status_code=400, detail="A valid email is required.")
    normalized_name = (payload.display_name or fallback_name).strip() or None

    local_user = crud.get_or_create_user_by_email(
        db=db,
        email=normalized_email,
        display_name=normalized_name,
    )
    session_identity = CurrentUser(
        id=local_user.id,
        email=local_user.email,
        display_name=local_user.display_name,
        avatar_url=local_user.avatar_url,
    )
    set_session_user(request, session_identity)
    return _auth_status_payload(request)
def auth_logout(request: Request) -> AuthStatusResponse:
    """Clear the session user and return the resulting signed-out status."""
    clear_session_user(request)
    signed_out = _auth_status_payload(request)
    return signed_out
def auth_bridge_exchange(
    payload: AuthBridgeExchangeRequest,
    request: Request,
    db: Session = Depends(get_db),
) -> AuthStatusResponse:
    """Establish a session from an ``auth_bridge`` token (HF OAuth mode only).

    The token is decoded into an identity; the matching local user is
    fetched or created and stored in this request's session.

    Raises:
        HTTPException: 400 outside hf_oauth mode or when the token is rejected.
    """
    if get_auth_mode() != AUTH_MODE_HF:
        raise HTTPException(status_code=400, detail="Auth bridge is only available in hf_oauth mode.")
    try:
        identity = decode_auth_bridge_token(payload.token)
    except AuthBridgeTokenError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    local_user = crud.get_or_create_user_by_email(
        db=db,
        email=identity.email,
        display_name=identity.display_name,
        avatar_url=identity.avatar_url,
    )
    session_identity = CurrentUser(
        id=local_user.id,
        email=local_user.email,
        display_name=local_user.display_name,
        avatar_url=local_user.avatar_url,
    )
    set_session_user(request, session_identity)
    return _auth_status_payload(request)
def auth_login(request: Request) -> RedirectResponse:
    """Start the HF OAuth flow: remember a fresh ``state`` and redirect to HF.

    Raises:
        HTTPException: 400 when HF OAuth mode is not enabled.
    """
    if get_auth_mode() != AUTH_MODE_HF:
        raise HTTPException(status_code=400, detail="HF OAuth is not enabled.")
    oauth_state = generate_oauth_state()
    # Kept in the session so auth_callback can compare it with the ``state``
    # query parameter sent back by the provider.
    request.session["oauth_state"] = oauth_state
    authorize_url = build_hf_authorize_url(
        redirect_uri=_auth_callback_url(request),
        state=oauth_state,
    )
    return RedirectResponse(url=authorize_url, status_code=status.HTTP_302_FOUND)
async def auth_callback(request: Request, db: Session = Depends(get_db)) -> RedirectResponse:
    """Finish the HF OAuth flow and redirect to the app with a bridge token.

    The ``state`` query parameter is checked against the copy stored in the
    session by ``auth_login``. When the session has no copy, it is checked
    with ``is_valid_oauth_state`` instead (presumably a stateless check —
    confirm). After validation the stored state is consumed so it cannot be
    reused. The ``code`` is exchanged for the HF identity, and the local user
    is fetched or created and stored in the session. The response redirects
    to ``AUTH_SUCCESS_REDIRECT_URL`` (default ``"/"``) with an
    ``auth_bridge`` token that ``auth_bridge_exchange`` can decode.

    Raises:
        HTTPException: 400 when OAuth is disabled, the state is missing or
            invalid, the code is missing, the code exchange fails, or the HF
            identity carries no email address.
    """
    if get_auth_mode() != AUTH_MODE_HF:
        raise HTTPException(status_code=400, detail="HF OAuth is not enabled.")

    expected_state = request.session.get("oauth_state")
    state = request.query_params.get("state")
    code = request.query_params.get("code")
    if not state:
        raise HTTPException(status_code=400, detail="Invalid OAuth state.")
    if expected_state:
        if state != expected_state:
            raise HTTPException(status_code=400, detail="Invalid OAuth state.")
    elif not is_valid_oauth_state(state):
        raise HTTPException(status_code=400, detail="Invalid OAuth state.")
    if not code:
        raise HTTPException(status_code=400, detail="Missing OAuth code.")
    # The state is single-use: consume it as soon as it has been accepted.
    request.session.pop("oauth_state", None)

    try:
        identity = await exchange_code_for_hf_user(code=code, redirect_uri=_auth_callback_url(request))
    except HFOAuthError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    email = identity.get("email")
    if not email:
        # Users are keyed by email; report a client error instead of a KeyError/500.
        raise HTTPException(
            status_code=400,
            detail="Hugging Face account did not provide an email address.",
        )

    user = crud.get_or_create_user_by_email(
        db=db,
        email=email,
        display_name=identity.get("display_name"),
        avatar_url=identity.get("avatar_url"),
    )
    current_user = CurrentUser(
        id=user.id,
        email=user.email,
        display_name=user.display_name,
        avatar_url=user.avatar_url,
    )
    set_session_user(request, current_user)

    bridge_token = generate_auth_bridge_token(current_user)
    redirect_url = _append_query_param(
        os.getenv("AUTH_SUCCESS_REDIRECT_URL", "/"),
        "auth_bridge",
        bridge_token,
    )
    return RedirectResponse(
        url=redirect_url,
        status_code=status.HTTP_302_FOUND,
    )
def create_notebook(
    payload: NotebookCreateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> NotebookResponse:
    """Create a notebook owned by the current user.

    The title is whitespace-trimmed before saving, as the rename endpoint
    does. A title that is blank after trimming is rejected, because
    ``min_length=1`` on the request model lets whitespace-only titles through.

    Raises:
        HTTPException: 400 if the trimmed title is empty.
    """
    title = payload.title.strip()
    if not title:
        raise HTTPException(status_code=400, detail="Notebook title must not be blank.")
    notebook = crud.create_notebook(db=db, owner_user_id=current_user.id, title=title)
    return NotebookResponse(id=notebook.id, owner_user_id=notebook.owner_user_id, title=notebook.title)
def get_notebooks(
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> list[NotebookResponse]:
    """List every notebook owned by the current user."""
    owned = crud.list_notebooks(db=db, owner_user_id=current_user.id)
    results = []
    for notebook in owned:
        results.append(
            NotebookResponse(
                id=notebook.id,
                owner_user_id=notebook.owner_user_id,
                title=notebook.title,
            )
        )
    return results
def rename_notebook(
    notebook_id: int,
    payload: NotebookUpdateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> NotebookResponse:
    """Rename one of the current user's notebooks.

    The new title is whitespace-trimmed. A title that is blank after trimming
    is rejected instead of being stored as an empty string.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller; 400 if
            the trimmed title is empty.
    """
    notebook = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")
    new_title = payload.title.strip()
    if not new_title:
        raise HTTPException(status_code=400, detail="Notebook title must not be blank.")
    updated = crud.update_notebook_title(db=db, notebook=notebook, title=new_title)
    return NotebookResponse(
        id=updated.id,
        owner_user_id=updated.owner_user_id,
        title=updated.title,
    )
def delete_notebook(
    notebook_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> NotebookDeleteResponse:
    """Delete one of the caller's notebooks: on-disk storage first, then the row.

    If storage cleanup raises, the request fails with 500 and the database
    row is left untouched, so a later delete can try again.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller; 500 if
            storage cleanup fails.
    """
    target = crud.get_notebook_for_user(
        db=db,
        notebook_id=notebook_id,
        owner_user_id=current_user.id,
    )
    if target is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    try:
        _cleanup_notebook_storage(owner_user_id=current_user.id, notebook_id=notebook_id)
    except Exception as cleanup_error:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete notebook storage: {cleanup_error}",
        ) from cleanup_error

    crud.delete_notebook(db=db, notebook=target)
    return NotebookDeleteResponse(status="deleted", notebook_id=notebook_id)
async def create_source_for_notebook(
    notebook_id: int,
    payload: SourceCreateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> SourceResponse:
    """Register a source on a notebook; URL sources are ingested immediately.

    The source type is trimmed and lower-cased. For ``"url"`` sources, the
    URL is checked with ``validate_ingestion_url``. The source is then marked
    ``processing`` and ingested in a worker thread. Afterwards it is marked
    ``ready`` (at least one chunk indexed; ``ingested_at`` set) or ``failed``.
    Other source types are stored as given, without ingestion.

    NOTE(review): ``storage_path`` comes verbatim from the client; confirm
    nothing downstream reads files from it.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller; 400 for
            a missing or invalid URL; 500 if ingestion raises.
    """
    owned_notebook = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if owned_notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    kind = payload.type.strip().lower()
    is_url_source = kind == "url"
    resolved_url = payload.url
    if is_url_source:
        if not (resolved_url and resolved_url.strip()):
            raise HTTPException(status_code=400, detail="URL is required when source type is 'url'.")
        try:
            resolved_url = validate_ingestion_url(resolved_url)
        except URLValidationError as err:
            raise HTTPException(status_code=400, detail=str(err)) from err

    source = crud.create_source(
        db=db,
        notebook_id=notebook_id,
        source_type=kind,
        title=payload.title,
        original_name=payload.original_name,
        url=resolved_url,
        storage_path=payload.storage_path,
        status=payload.status,
    )

    if is_url_source:
        crud.update_source_status(db=db, source_id=source.id, status="processing")
        try:
            chunk_count = await run_in_threadpool(
                ingest_source, source=source, owner_user_id=current_user.id
            )
            succeeded = chunk_count > 0
            refreshed = crud.update_source_status(
                db=db,
                source_id=source.id,
                status="ready" if succeeded else "failed",
                ingested_at=datetime.now(timezone.utc) if succeeded else None,
            )
            # Keep the in-hand object if the update helper returned nothing.
            source = refreshed or source
        except Exception as err:
            crud.update_source_status(db=db, source_id=source.id, status="failed")
            raise HTTPException(status_code=500, detail=f"Ingestion failed: {err}") from err

    return SourceResponse(
        id=source.id,
        notebook_id=source.notebook_id,
        type=source.type,
        title=source.title,
        original_name=source.original_name,
        url=source.url,
        storage_path=source.storage_path,
        status=source.status,
        ingested_at=source.ingested_at,
    )
async def upload_source_for_notebook(
    notebook_id: int,
    title: str | None = Form(None),
    status: str = Form("pending"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> SourceResponse:
    """Store an uploaded file in the notebook's upload directory and ingest it.

    The upload is streamed to disk in a worker thread, in 1 MiB chunks, at a
    sanitized, collision-free path. Previously ``await file.read()`` buffered
    the whole file in memory and wrote it on the event loop. A ``file`` source
    is created with the form's ``status``, then marked ``processing`` and
    ingested in a worker thread. It ends up ``ready`` (at least one chunk
    indexed; ``ingested_at`` set) or ``failed``.

    Note: the ``status`` form field shadows the ``fastapi.status`` module
    inside this function; it keeps its name because it is part of the form API.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller; 400 for
            an invalid filename; 500 if ingestion raises.
    """
    notebook = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    destination = _resolve_notebook_upload_path(notebook_id=notebook_id, filename=file.filename)

    def _save_upload() -> None:
        # Copy the spooled upload in chunks rather than materializing it.
        try:
            with destination.open("wb") as out:
                shutil.copyfileobj(file.file, out, 1024 * 1024)
        except Exception:
            # Do not leave a truncated file behind on a failed write.
            destination.unlink(missing_ok=True)
            raise

    await run_in_threadpool(_save_upload)

    original_name = Path(str(file.filename or destination.name)).name
    source_title = title or original_name or destination.name
    source = crud.create_source(
        db=db,
        notebook_id=notebook_id,
        source_type="file",
        title=source_title,
        original_name=original_name,
        url=None,
        storage_path=str(destination),
        status=status,
    )
    crud.update_source_status(db=db, source_id=source.id, status="processing")
    try:
        ingested_chunk_count = await run_in_threadpool(
            ingest_source, source=source, owner_user_id=current_user.id
        )
        final_status = "ready" if ingested_chunk_count > 0 else "failed"
        source = crud.update_source_status(
            db=db,
            source_id=source.id,
            status=final_status,
            ingested_at=datetime.now(timezone.utc) if final_status == "ready" else None,
        ) or source
    except Exception as exc:
        crud.update_source_status(db=db, source_id=source.id, status="failed")
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {exc}") from exc
    return SourceResponse(
        id=source.id,
        notebook_id=source.notebook_id,
        type=source.type,
        title=source.title,
        original_name=source.original_name,
        url=source.url,
        storage_path=source.storage_path,
        status=source.status,
        ingested_at=source.ingested_at,
    )
def list_sources_for_notebook(
    notebook_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> list[SourceResponse]:
    """Return every source attached to one of the caller's notebooks.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller.
    """
    owned = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    def _as_response(src) -> SourceResponse:
        return SourceResponse(
            id=src.id,
            notebook_id=src.notebook_id,
            type=src.type,
            title=src.title,
            original_name=src.original_name,
            url=src.url,
            storage_path=src.storage_path,
            status=src.status,
            ingested_at=src.ingested_at,
        )

    stored_sources = crud.list_sources_for_notebook(db=db, notebook_id=notebook_id)
    return list(map(_as_response, stored_sources))
def list_sources_placeholder() -> dict[str, str]:
    """Point callers at the notebook-scoped sources endpoints."""
    hint = "Use /notebooks/{notebook_id}/sources endpoints."
    return {"message": hint}
def create_thread_for_notebook(
    notebook_id: int,
    payload: ThreadCreateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ThreadResponse:
    """Open a new chat thread in one of the caller's notebooks.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller.
    """
    owned = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")
    created = crud.create_chat_thread(db=db, notebook_id=notebook_id, title=payload.title)
    return ThreadResponse(
        id=created.id,
        notebook_id=created.notebook_id,
        title=created.title,
        created_at=created.created_at,
    )
def list_threads_for_notebook(
    notebook_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> list[ThreadResponse]:
    """List the chat threads of one of the caller's notebooks.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller.
    """
    owned = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")
    responses = []
    for thread in crud.list_chat_threads(db=db, notebook_id=notebook_id):
        responses.append(
            ThreadResponse(
                id=thread.id,
                notebook_id=thread.notebook_id,
                title=thread.title,
                created_at=thread.created_at,
            )
        )
    return responses
def _citation_from_entry(entry: dict) -> CitationResponse | None:
    """Convert one stored citation dict into a ``CitationResponse``.

    Returns None for entries without a usable positive ``source_id``
    (missing, null, non-numeric, or <= 0); those previously raised and turned
    the whole listing into a 500.
    """
    try:
        source_id = int(entry.get("source_id") or 0)
    except (TypeError, ValueError):
        return None
    if source_id <= 0:
        return None
    chunk_ref = entry.get("chunk_ref")
    quote = entry.get("quote")
    score = entry.get("score")
    return CitationResponse(
        source_title=entry.get("source_title"),
        source_id=source_id,
        chunk_ref=(str(chunk_ref) if chunk_ref else None),
        quote=(str(quote) if quote else None),
        score=(float(score) if score is not None else None),
    )


def list_messages_for_thread(
    thread_id: int,
    notebook_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> list[MessageResponse]:
    """Return a thread's messages, each with its stored citations.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller or the
            thread does not belong to that notebook.
    """
    notebook = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")
    thread = crud.get_thread_for_notebook(db=db, notebook_id=notebook_id, thread_id=thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found for this notebook.")
    messages = crud.list_messages_for_thread(db=db, thread_id=thread_id)
    # Mapping of message id -> list of citation dicts.
    citations_by_message = crud.list_message_citations_for_thread(db=db, thread_id=thread_id)
    return [
        MessageResponse(
            id=m.id,
            thread_id=m.thread_id,
            role=m.role,
            content=m.content,
            created_at=m.created_at,
            citations=[
                citation
                for entry in citations_by_message.get(m.id, [])
                if (citation := _citation_from_entry(entry)) is not None
            ],
        )
        for m in messages
    ]
def _build_retrieval_context(
    retrieval_rows: list,
) -> tuple[list[str], list[CitationResponse], list[dict[str, int | str | float | None]]]:
    """Turn retrieved chunks into LLM context blocks and citations.

    Every row becomes a context block labelled with its source title, id and
    chunk index. Only rows whose metadata has a positive ``source_id`` yield
    a citation. Those are the only citations that are persisted and that
    ``list_messages_for_thread`` returns, so the live chat response agrees
    with what a reload shows.

    Returns:
        ``(context_blocks, citations, citation_rows)``, where ``citation_rows``
        are the dicts passed to ``crud.create_message_citations``.
    """
    context_blocks: list[str] = []
    citations: list[CitationResponse] = []
    citation_rows: list[dict[str, int | str | float | None]] = []
    for row in retrieval_rows:
        doc = row.get("document", "")
        meta = row.get("metadata") if isinstance(row.get("metadata"), dict) else {}
        score = row.get("score")
        try:
            source_id = int(meta.get("source_id", 0))
        except (TypeError, ValueError):
            source_id = 0
        chunk_index = meta.get("chunk_index")
        source_title = str(meta.get("source_title")) if meta.get("source_title") else None
        chunk_ref = (
            f"source_{source_id}_chunk_{chunk_index}"
            if source_id and chunk_index is not None
            else None
        )
        # Every retrieved chunk is given to the LLM, even without a usable source id.
        context_blocks.append(
            f"[source_title={source_title or 'Unknown'}, source_id={source_id}, chunk_index={chunk_index}]\n{doc}"
        )
        if source_id <= 0:
            continue
        citation = CitationResponse(
            source_title=source_title,
            source_id=source_id,
            chunk_ref=chunk_ref,
            quote=(doc[:300] if isinstance(doc, str) else None),
            score=(float(score) if score is not None else None),
        )
        citations.append(citation)
        citation_rows.append(
            {
                "source_id": source_id,
                "chunk_ref": chunk_ref,
                "quote": citation.quote,
                "score": citation.score,
            }
        )
    return context_blocks, citations, citation_rows


def chat_on_thread(
    thread_id: int,
    payload: ChatRequest,
    notebook_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ChatResponse:
    """Answer ``payload.question`` in a thread with retrieval-augmented generation.

    Stores the question, retrieves notebook chunks (``top_k=payload.top_k``)
    and asks the LLM with those chunks plus recent thread history. It then
    stores the answer with its citations and returns both messages. When
    nothing is retrieved, a fixed "not enough context" answer is stored
    without calling the LLM.

    NOTE(review): the user message is stored before generation, so an LLM
    failure leaves it in the thread without an assistant reply.

    Raises:
        HTTPException: 404 if the notebook is not owned by the caller or the
            thread is not in it; 500 if the LLM is misconfigured or generation fails.
    """
    notebook = crud.get_notebook_for_user(
        db=db, notebook_id=notebook_id, owner_user_id=current_user.id
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")
    thread = crud.get_thread_for_notebook(db=db, notebook_id=notebook_id, thread_id=thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found for this notebook.")

    # Read before the new question is stored, so the history excludes it.
    prior_messages = crud.list_messages_for_thread(db=db, thread_id=thread_id)
    user_message = crud.create_message(db=db, thread_id=thread_id, role="user", content=payload.question)
    retrieval_rows = query_notebook_chunks(
        owner_user_id=current_user.id,
        notebook_id=notebook_id,
        query_text=payload.question,
        top_k=payload.top_k,
    )
    context_blocks, citations, citation_rows = _build_retrieval_context(retrieval_rows)

    if context_blocks:
        system_prompt = build_rag_system_prompt()
        conversation_history = _build_conversation_history(prior_messages)
        user_prompt = build_rag_user_prompt(
            question=payload.question,
            context_blocks=context_blocks,
            conversation_history=conversation_history,
        )
        try:
            answer = generate_chat_completion(system_prompt=system_prompt, user_prompt=user_prompt)
        except LLMConfigError as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"LLM generation failed: {exc}") from exc
    else:
        answer = "I do not have enough indexed context to answer this question yet."

    assistant_message = crud.create_message(
        db=db, thread_id=thread_id, role="assistant", content=answer
    )
    if citation_rows:
        crud.create_message_citations(db=db, message_id=assistant_message.id, citations=citation_rows)
    return ChatResponse(
        user_message=MessageResponse(
            id=user_message.id,
            thread_id=user_message.thread_id,
            role=user_message.role,
            content=user_message.content,
            created_at=user_message.created_at,
            citations=[],
        ),
        assistant_message=MessageResponse(
            id=assistant_message.id,
            thread_id=assistant_message.thread_id,
            role=assistant_message.role,
            content=assistant_message.content,
            created_at=assistant_message.created_at,
            citations=citations,
        ),
        citations=citations,
    )
def list_threads_placeholder() -> dict[str, str]:
    """Point callers at the notebook-scoped threads endpoints."""
    hint = "Use /notebooks/{notebook_id}/threads endpoints."
    return {"message": hint}
| # ── Artifact helpers ────────────────────────────────────────────────────────── | |
def _artifact_response(artifact) -> ArtifactResponse:
    """Map an artifact ORM object onto the public ``ArtifactResponse`` schema.

    The object's ``artifact_metadata`` attribute is exposed as ``metadata``.
    """
    fields = {
        "id": artifact.id,
        "notebook_id": artifact.notebook_id,
        "type": artifact.type,
        "title": artifact.title,
        "status": artifact.status,
        "content": artifact.content,
        "file_path": artifact.file_path,
        "metadata": artifact.artifact_metadata,
        "error_message": artifact.error_message,
        "created_at": artifact.created_at,
        "generated_at": artifact.generated_at,
    }
    return ArtifactResponse(**fields)
def _podcast_artifact_metadata(
    result: dict, audio_path: str | None, transcript_path: str | None
) -> dict:
    """Build podcast artifact metadata: output paths plus the generator's metadata.

    Keys from ``result["metadata"]`` (only when it is a dict) override the
    ``audio_path`` / ``transcript_path`` entries.
    """
    extra = result.get("metadata")
    return {
        "audio_path": audio_path,
        "transcript_path": transcript_path,
        **(extra if isinstance(extra, dict) else {}),
    }


def _run_podcast_background(
    artifact_id: int,
    user_id: int,
    notebook_id: int,
    duration: str,
    topic_focus: str | None,
) -> None:
    """Background task: generate a podcast and record the outcome on the artifact.

    Opens its own ``SessionLocal()`` session, which is always closed at the
    end. The artifact is marked ``processing`` and the podcast is generated.
    Results are stored as follows:

    - A result containing ``"error"``: status ``failed`` with that error. Any
      transcript already produced is kept as content and saved to disk.
    - Otherwise: status ``ready`` with the transcript markdown as content and
      the audio path (if any) as ``file_path``.

    An unexpected exception rolls the session back and then marks the
    artifact ``failed``. Without the rollback, a failed DB operation leaves
    the SQLAlchemy session unusable, and that final update would itself raise.
    """
    db = SessionLocal()
    try:
        crud.update_artifact(db, artifact_id, status="processing")
        generator = PodcastGenerator()
        result = generator.generate_podcast(
            user_id=str(user_id),
            notebook_id=str(notebook_id),
            duration_target=duration,
            topic_focus=topic_focus,
        )
        if "error" in result:
            transcript_markdown = ""
            transcript_path = None
            transcript = result.get("transcript")
            if isinstance(transcript, list) and transcript:
                transcript_markdown = generator.format_transcript_markdown(result)
                transcript_path = generator.save_transcript(result, str(user_id), str(notebook_id))
            crud.update_artifact(
                db,
                artifact_id,
                status="failed",
                error_message=result["error"],
                content=(transcript_markdown or None),
                metadata=_podcast_artifact_metadata(result, None, transcript_path),
            )
        else:
            transcript_markdown = generator.format_transcript_markdown(result)
            transcript_path = generator.save_transcript(result, str(user_id), str(notebook_id))
            audio_path = result.get("audio_path")
            audio_path_str = str(audio_path) if audio_path else None
            crud.update_artifact(
                db,
                artifact_id,
                status="ready",
                content=transcript_markdown,
                file_path=audio_path_str,
                metadata=_podcast_artifact_metadata(result, audio_path_str, transcript_path),
            )
    except Exception as exc:
        db.rollback()
        crud.update_artifact(db, artifact_id, status="failed", error_message=str(exc))
    finally:
        db.close()
| # ── Artifact endpoints ──────────────────────────────────────────────────────── | |
async def generate_report_for_notebook(
    notebook_id: int,
    payload: ReportGenerateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ArtifactResponse:
    """Generate a report artifact for one of the caller's notebooks.

    Generation and the report file write both run within the request, off the
    event loop via ``run_in_threadpool``.

    Generation, generator-reported and save failures are recorded on the
    artifact and returned with ``status="failed"`` instead of being raised. The
    artifact is never left in ``processing`` by those paths.

    Raises:
        HTTPException: 404 if the notebook does not belong to the caller;
            400 if ``detail_level`` is not one of short/medium/long.
    """
    notebook = crud.get_notebook_for_user(
        db=db,
        notebook_id=notebook_id,
        owner_user_id=current_user.id,
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    # Normalise before validating, so e.g. " Short " is accepted.
    detail_level = payload.detail_level.strip().lower()
    if detail_level not in {"short", "medium", "long"}:
        raise HTTPException(status_code=400, detail="detail_level must be one of: short, medium, long")

    artifact = crud.create_artifact(
        db=db,
        notebook_id=notebook_id,
        artifact_type="report",
        title=payload.title or f"Report – {detail_level}",
        metadata={
            "detail_level": detail_level,
            "topic_focus": payload.topic_focus,
        },
    )
    crud.update_artifact(db, artifact.id, status="processing")

    try:
        generator = ReportGenerator()
        result = await run_in_threadpool(
            generator.generate_report,
            user_id=str(current_user.id),
            notebook_id=str(notebook_id),
            detail_level=detail_level,
            topic_focus=payload.topic_focus,
        )
    except Exception as exc:
        artifact = crud.update_artifact(db, artifact.id, status="failed", error_message=str(exc))
        return _artifact_response(artifact)

    if "error" in result:
        artifact = crud.update_artifact(db, artifact.id, status="failed", error_message=result["error"])
        return _artifact_response(artifact)

    content = str(result.get("content", "")).strip()
    try:
        # Blocking file I/O: keep it off the event loop. A failed write must
        # mark the artifact failed rather than leave it in "processing".
        report_path = await run_in_threadpool(
            generator.save_report, content, str(current_user.id), str(notebook_id)
        )
    except Exception as exc:
        artifact = crud.update_artifact(
            db, artifact.id, status="failed", error_message=f"Failed to save report: {exc}"
        )
        return _artifact_response(artifact)

    artifact = crud.update_artifact(
        db,
        artifact.id,
        status="ready",
        content=content,
        file_path=report_path,
    )
    return _artifact_response(artifact)
async def generate_quiz_for_notebook(
    notebook_id: int,
    payload: QuizGenerateRequest,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ArtifactResponse:
    """Generate a quiz artifact for one of the caller's notebooks.

    Generation and the quiz file write run off the event loop via
    ``run_in_threadpool``.

    If the generator reports an error in its result, the failed artifact is
    returned. If an exception is raised while generating, formatting or saving,
    the artifact is marked failed and an HTTP 500 is raised. The artifact is
    never left in ``processing``.

    Raises:
        HTTPException: 404 if the notebook does not belong to the caller;
            500 if generating, formatting or saving the quiz raises.
    """
    notebook = crud.get_notebook_for_user(
        db=db,
        notebook_id=notebook_id,
        owner_user_id=current_user.id,
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    artifact = crud.create_artifact(
        db=db,
        notebook_id=notebook_id,
        artifact_type="quiz",
        title=payload.title or f"Quiz – {payload.difficulty} ({payload.num_questions}q)",
        metadata={
            "num_questions": payload.num_questions,
            "difficulty": payload.difficulty,
            "topic_focus": payload.topic_focus,
        },
    )
    crud.update_artifact(db, artifact.id, status="processing")

    try:
        generator = QuizGenerator()
        result = await run_in_threadpool(
            generator.generate_quiz,
            user_id=str(current_user.id),
            notebook_id=str(notebook_id),
            num_questions=payload.num_questions,
            difficulty=payload.difficulty,
            topic_focus=payload.topic_focus,
        )
    except Exception as exc:
        crud.update_artifact(db, artifact.id, status="failed", error_message=str(exc))
        raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc

    if "error" in result:
        artifact = crud.update_artifact(db, artifact.id, status="failed", error_message=result["error"])
        return _artifact_response(artifact)

    try:
        quiz_markdown = generator.format_quiz_markdown(result, title=payload.title or "Quiz")
        # Blocking file I/O: keep it off the event loop.
        quiz_path = await run_in_threadpool(
            generator.save_quiz, quiz_markdown, str(current_user.id), str(notebook_id)
        )
    except Exception as exc:
        # Same handling as a generation failure, so the artifact is not
        # left in "processing".
        crud.update_artifact(db, artifact.id, status="failed", error_message=str(exc))
        raise HTTPException(status_code=500, detail=f"Quiz generation failed: {exc}") from exc

    artifact = crud.update_artifact(
        db,
        artifact.id,
        status="ready",
        content=quiz_markdown,
        file_path=quiz_path,
    )
    return _artifact_response(artifact)
def generate_podcast_for_notebook(
    notebook_id: int,
    payload: PodcastGenerateRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ArtifactResponse:
    """Create a podcast artifact and queue its generation.

    The freshly created artifact is returned straight away.
    ``_run_podcast_background`` is registered on ``background_tasks`` to
    produce the audio and transcript and update that artifact.

    Raises:
        HTTPException: 404 when the notebook is not owned by the current user.
    """
    owned = crud.get_notebook_for_user(db=db, notebook_id=notebook_id, owner_user_id=current_user.id)
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    duration = payload.duration
    topic_focus = payload.topic_focus
    artifact = crud.create_artifact(
        db=db,
        notebook_id=notebook_id,
        artifact_type="podcast",
        title=payload.title or f"Podcast – {duration}",
        metadata={"duration": duration, "topic_focus": topic_focus},
    )

    job_kwargs = {
        "artifact_id": artifact.id,
        "user_id": current_user.id,
        "notebook_id": notebook_id,
        "duration": duration,
        "topic_focus": topic_focus,
    }
    background_tasks.add_task(_run_podcast_background, **job_kwargs)
    return _artifact_response(artifact)
def list_artifacts_for_notebook(
    notebook_id: int,
    artifact_type: str | None = None,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> list[ArtifactResponse]:
    """Return a notebook's artifacts, optionally narrowed to one ``artifact_type``.

    Raises:
        HTTPException: 404 when the notebook is not owned by the current user.
    """
    owned = crud.get_notebook_for_user(db=db, notebook_id=notebook_id, owner_user_id=current_user.id)
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    rows = crud.list_artifacts(db=db, notebook_id=notebook_id, artifact_type=artifact_type)
    return list(map(_artifact_response, rows))
def get_artifact_for_notebook(
    notebook_id: int,
    artifact_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> ArtifactResponse:
    """Fetch a single artifact, scoped to one of the caller's notebooks.

    Raises:
        HTTPException: 404 when the notebook is not the caller's, or when the
            artifact does not exist or belongs to a different notebook.
    """
    owned = crud.get_notebook_for_user(db=db, notebook_id=notebook_id, owner_user_id=current_user.id)
    if owned is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    artifact = crud.get_artifact(db=db, artifact_id=artifact_id)
    # Only accept the artifact if it is attached to the requested notebook.
    belongs_here = artifact is not None and artifact.notebook_id == notebook_id
    if not belongs_here:
        raise HTTPException(status_code=404, detail="Artifact not found.")
    return _artifact_response(artifact)
def download_podcast_audio(
    notebook_id: int,
    artifact_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(require_current_user),
) -> FileResponse:
    """Serve the audio file of a ready podcast artifact.

    The response content type is guessed from the file extension, falling
    back to ``audio/mpeg``.

    Raises:
        HTTPException:
            - 404 if the notebook is not the caller's.
            - 404 if the artifact is missing or belongs to another notebook.
            - 404 if the audio path is not an existing regular file.
            - 400 if the artifact is not a podcast.
            - 409 if the artifact is not yet ``ready``.
    """
    import mimetypes  # local import: only this endpoint needs it

    notebook = crud.get_notebook_for_user(
        db=db,
        notebook_id=notebook_id,
        owner_user_id=current_user.id,
    )
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found for this user.")

    artifact = crud.get_artifact(db=db, artifact_id=artifact_id)
    if artifact is None or artifact.notebook_id != notebook_id:
        raise HTTPException(status_code=404, detail="Artifact not found.")
    if artifact.type != "podcast":
        raise HTTPException(status_code=400, detail="Artifact is not a podcast.")
    if artifact.status != "ready":
        raise HTTPException(status_code=409, detail=f"Podcast not ready yet (status: {artifact.status}).")

    audio_file = Path(artifact.file_path) if artifact.file_path else None
    # Use is_file(), not exists(): a directory passes exists(), and
    # FileResponse then fails while sending because it is not a regular file.
    if audio_file is None or not audio_file.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found on disk.")

    # Derive the content type from the extension. Keep the previous MP3 default
    # when the guess is missing or not an audio type.
    media_type, _ = mimetypes.guess_type(audio_file.name)
    if not media_type or not media_type.startswith("audio/"):
        media_type = "audio/mpeg"

    return FileResponse(
        path=artifact.file_path,
        media_type=media_type,
        filename=audio_file.name,
    )
# Mount the feature routers on the application, preserving the original order.
for _router in (auth_router, notebooks_router, sources_router, threads_router):
    app.include_router(_router)
del _router  # keep the loop variable out of the module namespace