# cot-anc / app/services/sessions.py
# BART-ender's picture
# Deploy Thought Anchors
# fda8fb3 verified
from __future__ import annotations

import logging
from dataclasses import dataclass

from app.analysis.sentence_split import split_sentences
from app.core.config import Settings
from app.core.runtime import load_model_bundle
from app.core.runtime_pipeline import analyze_generation_result
from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult
from app.generation.service import generate_answer_and_trace
from app.storage.repository import SessionRecord, SessionRepository
from app.workers.jobs import JobRunner
class SessionLimitError(RuntimeError):
    """Raised when creating a session would exceed the global queue size or a user's active-job quota."""
class SessionAccessError(PermissionError):
    """Raised when a caller touches a session whose stored owner differs from the caller's id."""
@dataclass(slots=True)
class SessionService:
settings: Settings
repository: SessionRepository
jobs: JobRunner
def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord:
if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs:
raise SessionLimitError("The service queue is full. Try again after a few minutes.")
if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user:
raise SessionLimitError("You already have the maximum number of active analysis jobs.")
model_name = request.model_name or self.settings.model_name
session = self.repository.create_session(
question=request.question,
model_name=model_name,
owner_id=owner_id,
owner_name=owner_name,
)
self.jobs.submit(self._run_session_pipeline, session.id, request)
return session
def start_analysis(
self,
session_id: str,
*,
owner_id: str,
request: AnalysisRequest | None = None,
) -> SessionRecord:
session = self.repository.get_session(session_id)
self._assert_owner(session, owner_id)
effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name)
self.jobs.submit(self._run_analysis_only, session_id, effective_request)
return session
def get_session_payload(self, session_id: str, *, owner_id: str) -> dict:
session = self.repository.get_session(session_id)
self._assert_owner(session, owner_id)
return self.repository.list_session_payload(session_id)
def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]:
return self.repository.list_sessions_for_owner(owner_id, limit=limit)
def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult:
payload = self.get_session_payload(session_id, owner_id=owner_id)
analysis = payload.get("analysis")
if analysis is None:
raise KeyError(session_id)
return AnalysisResult.model_validate(analysis)
@staticmethod
def _assert_owner(session: SessionRecord, owner_id: str) -> None:
if session.owner_id != owner_id:
raise SessionAccessError("Session belongs to a different user.")
def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None:
try:
self.repository.update_status(session_id, status="generating")
bundle = load_model_bundle(
request.model_name or self.settings.model_name,
device_preference=request.device_preference or self.settings.device_preference,
dtype_preference=request.dtype_preference or self.settings.dtype_preference,
attn_implementation=request.attn_implementation or self.settings.attn_implementation,
trust_remote_code=(
self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
),
low_cpu_mem_usage=(
self.settings.low_cpu_mem_usage
if request.low_cpu_mem_usage is None
else request.low_cpu_mem_usage
),
)
generation = generate_answer_and_trace(
question=request.question,
model_name=bundle.model_name,
model=bundle.model,
tokenizer=bundle.tokenizer,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
top_p=request.top_p,
)
sentences = [span.text for span in split_sentences(generation.normalized_trace_text)]
self.repository.save_generation_result(session_id, generation, sentences)
self.repository.update_status(session_id, status="answer_ready")
self._run_analysis_only(session_id, request, generation=generation)
except Exception as exc:
self.repository.update_status(session_id, status="failed", error=str(exc))
def _run_analysis_only(
self,
session_id: str,
request: AnalysisRequest,
*,
generation=None,
) -> None:
try:
self.repository.update_status(session_id, status="analyzing")
if generation is None:
payload = self.repository.list_session_payload(session_id)
if payload.get("generation_metadata") is not None:
generation = GenerationResult(
question=payload["question"],
model_name=payload["model_name"],
answer=payload["answer"],
raw_generation_text=payload.get("raw_generation_text", ""),
raw_trace_text=payload["raw_trace_text"],
normalized_trace_text=payload["normalized_trace_text"],
generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]),
)
result = analyze_generation_result(
question=request.question,
generation=generation,
model_name=request.model_name or self.settings.model_name,
take_log=request.take_log,
max_sentences=request.max_sentences,
max_trace_tokens=request.max_trace_tokens,
validate_top_k=request.validate_top_k,
device_preference=request.device_preference or self.settings.device_preference,
dtype_preference=request.dtype_preference or self.settings.dtype_preference,
attn_implementation=request.attn_implementation or self.settings.attn_implementation,
trust_remote_code=(
self.settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
),
low_cpu_mem_usage=(
self.settings.low_cpu_mem_usage
if request.low_cpu_mem_usage is None
else request.low_cpu_mem_usage
),
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
top_p=request.top_p,
)
self.repository.save_analysis_result(session_id, result)
self.repository.update_status(session_id, status="completed")
except Exception as exc:
self.repository.update_status(session_id, status="failed", error=str(exc))