from __future__ import annotations from collections import defaultdict from datetime import datetime, timezone from sqlalchemy.orm import Session from data.models import Artifact, ChatThread, Message, MessageCitation, Notebook, Source, User def get_or_create_user( db: Session, user_id: int = 1, email: str = "dev@example.com", display_name: str = "Dev User", ) -> User: user = db.get(User, user_id) if user: return user user = User(id=user_id, email=email, display_name=display_name) db.add(user) db.commit() db.refresh(user) return user def get_user_by_id(db: Session, user_id: int) -> User | None: return db.get(User, user_id) def get_or_create_user_by_email( db: Session, email: str, display_name: str | None = None, avatar_url: str | None = None, ) -> User: normalized_email = email.strip().lower() user = db.query(User).filter(User.email == normalized_email).first() if user: changed = False if display_name and user.display_name != display_name: user.display_name = display_name changed = True if avatar_url and user.avatar_url != avatar_url: user.avatar_url = avatar_url changed = True if changed: db.commit() db.refresh(user) return user user = User( email=normalized_email, display_name=display_name, avatar_url=avatar_url, ) db.add(user) db.commit() db.refresh(user) return user def create_notebook(db: Session, owner_user_id: int, title: str) -> Notebook: notebook = Notebook(owner_user_id=owner_user_id, title=title) db.add(notebook) db.commit() db.refresh(notebook) return notebook def list_notebooks(db: Session, owner_user_id: int) -> list[Notebook]: return ( db.query(Notebook) .filter(Notebook.owner_user_id == owner_user_id) .order_by(Notebook.created_at.desc()) .all() ) def get_notebook_for_user(db: Session, notebook_id: int, owner_user_id: int) -> Notebook | None: return ( db.query(Notebook) .filter(Notebook.id == notebook_id, Notebook.owner_user_id == owner_user_id) .first() ) def update_notebook_title(db: Session, notebook: Notebook, title: str) -> Notebook: notebook.title = title db.commit() db.refresh(notebook) return notebook def delete_notebook(db: Session, notebook: Notebook) -> None: db.delete(notebook) db.commit() def create_source( db: Session, notebook_id: int, source_type: str, title: str | None, original_name: str | None, url: str | None, storage_path: str | None, status: str = "pending", ) -> Source: source = Source( notebook_id=notebook_id, type=source_type, title=title, original_name=original_name, url=url, storage_path=storage_path, status=status, ) db.add(source) db.commit() db.refresh(source) return source def list_sources_for_notebook(db: Session, notebook_id: int) -> list[Source]: return ( db.query(Source) .filter(Source.notebook_id == notebook_id) .order_by(Source.id.desc()) .all() ) def update_source_status( db: Session, source_id: int, status: str, ingested_at: datetime | None = None, ) -> Source | None: source = db.get(Source, source_id) if source is None: return None source.status = status if ingested_at is not None: source.ingested_at = ingested_at db.commit() db.refresh(source) return source def create_chat_thread(db: Session, notebook_id: int, title: str | None = None) -> ChatThread: thread = ChatThread(notebook_id=notebook_id, title=title) db.add(thread) db.commit() db.refresh(thread) return thread def list_chat_threads(db: Session, notebook_id: int) -> list[ChatThread]: return ( db.query(ChatThread) .filter(ChatThread.notebook_id == notebook_id) .order_by(ChatThread.created_at.desc()) .all() ) def get_thread_for_notebook(db: Session, notebook_id: int, thread_id: int) -> ChatThread | None: return ( db.query(ChatThread) .filter(ChatThread.id == thread_id, ChatThread.notebook_id == notebook_id) .first() ) def create_message(db: Session, thread_id: int, role: str, content: str) -> Message: message = Message(thread_id=thread_id, role=role, content=content) db.add(message) db.commit() db.refresh(message) return message def list_messages_for_thread(db: Session, thread_id: int) -> list[Message]: return ( db.query(Message) .filter(Message.thread_id == thread_id) .order_by(Message.created_at.asc()) .all() ) def create_message_citations( db: Session, message_id: int, citations: list[dict[str, int | str | float | None]], ) -> list[MessageCitation]: rows: list[MessageCitation] = [] for item in citations: row = MessageCitation( message_id=message_id, source_id=int(item["source_id"]), chunk_ref=item.get("chunk_ref"), # type: ignore[arg-type] quote=item.get("quote"), # type: ignore[arg-type] score=float(item["score"]) if item.get("score") is not None else None, ) db.add(row) rows.append(row) db.commit() for row in rows: db.refresh(row) return rows def list_message_citations_for_thread( db: Session, thread_id: int ) -> dict[int, list[dict[str, int | str | float | None]]]: rows = ( db.query(MessageCitation, Source.title) .join(Source, Source.id == MessageCitation.source_id) .join(Message, Message.id == MessageCitation.message_id) .filter(Message.thread_id == thread_id) .order_by(MessageCitation.id.asc()) .all() ) citations_by_message: dict[int, list[dict[str, int | str | float | None]]] = defaultdict(list) for citation, source_title in rows: citations_by_message[int(citation.message_id)].append( { "source_id": int(citation.source_id), "source_title": source_title, "chunk_ref": citation.chunk_ref, "quote": citation.quote, "score": citation.score, } ) return dict(citations_by_message) def get_artifact(db: Session, artifact_id: int) -> Artifact | None: return db.get(Artifact, artifact_id) def create_artifact( db: Session, notebook_id: int, artifact_type: str, title: str | None = None, metadata: dict | None = None, ) -> Artifact: artifact = Artifact( notebook_id=notebook_id, type=artifact_type, title=title, artifact_metadata=metadata or {}, status="pending", ) db.add(artifact) db.commit() db.refresh(artifact) return artifact def list_artifacts(db: Session, notebook_id: int, artifact_type: str | None = None) -> list[Artifact]: query = db.query(Artifact).filter(Artifact.notebook_id == notebook_id) if artifact_type: query = query.filter(Artifact.type == artifact_type) return query.order_by(Artifact.created_at.desc()).all() def update_artifact( db: Session, artifact_id: int, status: str, content: str | None = None, file_path: str | None = None, error_message: str | None = None, metadata: dict | None = None, ) -> Artifact | None: artifact = db.get(Artifact, artifact_id) if not artifact: return None artifact.status = status if content is not None: artifact.content = content if file_path is not None: artifact.file_path = file_path if error_message is not None: artifact.error_message = error_message if metadata is not None: merged = dict(artifact.artifact_metadata or {}) merged.update(metadata) artifact.artifact_metadata = merged if status == "ready": artifact.generated_at = datetime.now(timezone.utc) db.commit() db.refresh(artifact) return artifact