from __future__ import annotations from dataclasses import dataclass from app.analysis.sentence_split import split_sentences from app.core.config import Settings from app.core.runtime import load_model_bundle from app.core.runtime_pipeline import analyze_generation_result from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult from app.generation.service import generate_answer_and_trace from app.storage.repository import SessionRecord, SessionRepository from app.workers.jobs import JobRunner class SessionLimitError(RuntimeError): pass class SessionAccessError(PermissionError): pass @dataclass(slots=True) class SessionService: settings: Settings repository: SessionRepository jobs: JobRunner def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord: if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs: raise SessionLimitError("The service queue is full. Try again after a few minutes.") if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user: raise SessionLimitError("You already have the maximum number of active analysis jobs.") model_name = request.model_name or self.settings.model_name session = self.repository.create_session( question=request.question, model_name=model_name, owner_id=owner_id, owner_name=owner_name, ) self.jobs.submit(self._run_session_pipeline, session.id, request) return session def start_analysis( self, session_id: str, *, owner_id: str, request: AnalysisRequest | None = None, ) -> SessionRecord: session = self.repository.get_session(session_id) self._assert_owner(session, owner_id) effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name) self.jobs.submit(self._run_analysis_only, session_id, effective_request) return session def get_session_payload(self, session_id: str, *, owner_id: str) -> dict: session = self.repository.get_session(session_id) self._assert_owner(session, owner_id) return self.repository.list_session_payload(session_id) def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]: return self.repository.list_sessions_for_owner(owner_id, limit=limit) def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult: payload = self.get_session_payload(session_id, owner_id=owner_id) analysis = payload.get("analysis") if analysis is None: raise KeyError(session_id) return AnalysisResult.model_validate(analysis) @staticmethod def _assert_owner(session: SessionRecord, owner_id: str) -> None: if session.owner_id != owner_id: raise SessionAccessError("Session belongs to a different user.") def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None: try: self.repository.update_status(session_id, status="generating") bundle = load_model_bundle( request.model_name or self.settings.model_name, device_preference=request.device_preference or self.settings.device_preference, dtype_preference=request.dtype_preference or self.settings.dtype_preference, attn_implementation=request.attn_implementation or self.settings.attn_implementation, trust_remote_code=( self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code ), low_cpu_mem_usage=( self.settings.low_cpu_mem_usage if request.low_cpu_mem_usage is None else request.low_cpu_mem_usage ), ) generation = generate_answer_and_trace( question=request.question, model_name=bundle.model_name, model=bundle.model, tokenizer=bundle.tokenizer, max_new_tokens=request.max_new_tokens, temperature=request.temperature, top_p=request.top_p, ) sentences = [span.text for span in split_sentences(generation.normalized_trace_text)] self.repository.save_generation_result(session_id, generation, sentences) self.repository.update_status(session_id, status="answer_ready") self._run_analysis_only(session_id, request, generation=generation) except Exception as exc: self.repository.update_status(session_id, status="failed", error=str(exc)) def _run_analysis_only( self, session_id: str, request: AnalysisRequest, *, generation=None, ) -> None: try: self.repository.update_status(session_id, status="analyzing") if generation is None: payload = self.repository.list_session_payload(session_id) if payload.get("generation_metadata") is not None: generation = GenerationResult( question=payload["question"], model_name=payload["model_name"], answer=payload["answer"], raw_generation_text=payload.get("raw_generation_text", ""), raw_trace_text=payload["raw_trace_text"], normalized_trace_text=payload["normalized_trace_text"], generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]), ) result = analyze_generation_result( question=request.question, generation=generation, model_name=request.model_name or self.settings.model_name, take_log=request.take_log, max_sentences=request.max_sentences, max_trace_tokens=request.max_trace_tokens, validate_top_k=request.validate_top_k, device_preference=request.device_preference or self.settings.device_preference, dtype_preference=request.dtype_preference or self.settings.dtype_preference, attn_implementation=request.attn_implementation or self.settings.attn_implementation, trust_remote_code=( self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code ), low_cpu_mem_usage=( self.settings.low_cpu_mem_usage if request.low_cpu_mem_usage is None else request.low_cpu_mem_usage ), max_new_tokens=request.max_new_tokens, temperature=request.temperature, top_p=request.top_p, ) self.repository.save_analysis_result(session_id, result) self.repository.update_status(session_id, status="completed") except Exception as exc: self.repository.update_status(session_id, status="failed", error=str(exc))