Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from app.analysis.sentence_split import split_sentences | |
| from app.core.config import Settings | |
| from app.core.runtime import load_model_bundle | |
| from app.core.runtime_pipeline import analyze_generation_result | |
| from app.core.schemas import AnalysisRequest, AnalysisResult, GenerationMetadata, GenerationResult | |
| from app.generation.service import generate_answer_and_trace | |
| from app.storage.repository import SessionRecord, SessionRepository | |
| from app.workers.jobs import JobRunner | |
class SessionLimitError(RuntimeError):
    """Raised when a queue-wide or per-owner job limit blocks creating a new session."""
class SessionAccessError(PermissionError):
    """Raised when a session is accessed by someone other than its owner."""
# eq=False keeps identity-based equality/hashing, so bound methods of the
# service (which are handed to the job runner) stay hashable.
@dataclass(eq=False)
class SessionService:
    """Create analysis sessions, enforce ownership, and drive their pipelines.

    The service checks queue limits, persists sessions through the
    repository, and submits the generation/analysis pipelines to the job
    runner. Pipeline failures are recorded on the session (status
    ``"failed"``) instead of being raised to the submitter.
    """

    # Service configuration: queue limits plus model/runtime defaults.
    settings: Settings
    # Storage for session records, statuses, generation and analysis results.
    repository: SessionRepository
    # Runner that the pipeline callables are submitted to.
    jobs: JobRunner

    def create_session(self, request: AnalysisRequest, *, owner_id: str, owner_name: str | None) -> SessionRecord:
        """Create a session for ``request`` and submit the full pipeline for it.

        Args:
            request: Analysis request; ``request.model_name`` falls back to
                ``settings.model_name`` when falsy.
            owner_id: Identifier of the user who will own the session.
            owner_name: Optional display name stored with the session.

        Returns:
            The newly created session record.

        Raises:
            SessionLimitError: If the number of incomplete sessions has
                reached ``settings.max_queued_jobs``, or the owner's
                incomplete sessions have reached
                ``settings.max_active_jobs_per_user``.
        """
        # The global queue limit is checked before the per-owner limit.
        if self.repository.count_incomplete_sessions() >= self.settings.max_queued_jobs:
            raise SessionLimitError("The service queue is full. Try again after a few minutes.")
        if self.repository.count_incomplete_sessions_for_owner(owner_id) >= self.settings.max_active_jobs_per_user:
            raise SessionLimitError("You already have the maximum number of active analysis jobs.")

        session = self.repository.create_session(
            question=request.question,
            model_name=request.model_name or self.settings.model_name,
            owner_id=owner_id,
            owner_name=owner_name,
        )
        self.jobs.submit(self._run_session_pipeline, session.id, request)
        return session

    def start_analysis(
        self,
        session_id: str,
        *,
        owner_id: str,
        request: AnalysisRequest | None = None,
    ) -> SessionRecord:
        """Submit an analysis-only job for an existing session.

        Args:
            session_id: Identifier of the session to analyze.
            owner_id: Caller identity; must match the session's owner.
            request: Optional request. When omitted, one is built from the
                session's stored question and model name.

        Returns:
            The session record as loaded before the job was submitted.

        Raises:
            SessionAccessError: If the session belongs to a different owner.
        """
        session = self.repository.get_session(session_id)
        self._assert_owner(session, owner_id)
        effective_request = request or AnalysisRequest(question=session.question, model_name=session.model_name)
        self.jobs.submit(self._run_analysis_only, session_id, effective_request)
        return session

    def get_session_payload(self, session_id: str, *, owner_id: str) -> dict:
        """Return the repository payload for a session after an ownership check.

        Raises:
            SessionAccessError: If the session belongs to a different owner.
        """
        session = self.repository.get_session(session_id)
        self._assert_owner(session, owner_id)
        return self.repository.list_session_payload(session_id)

    def list_sessions(self, owner_id: str, *, limit: int = 20) -> list[dict]:
        """Return the repository's session listing for ``owner_id``, passing ``limit`` through."""
        return self.repository.list_sessions_for_owner(owner_id, limit=limit)

    def get_analysis_result(self, session_id: str, *, owner_id: str) -> AnalysisResult:
        """Return the validated analysis result stored for a session.

        Raises:
            SessionAccessError: If the session belongs to a different owner.
            KeyError: If the payload has no ``"analysis"`` entry (or it is
                ``None``); the key is the session id.
        """
        payload = self.get_session_payload(session_id, owner_id=owner_id)
        analysis = payload.get("analysis")
        if analysis is None:
            raise KeyError(session_id)
        return AnalysisResult.model_validate(analysis)

    # Must be a staticmethod: callers invoke it as
    # ``self._assert_owner(session, owner_id)`` and the function takes no ``self``.
    @staticmethod
    def _assert_owner(session: SessionRecord, owner_id: str) -> None:
        """Raise ``SessionAccessError`` unless ``session.owner_id`` equals ``owner_id``."""
        if session.owner_id != owner_id:
            raise SessionAccessError("Session belongs to a different user.")

    def _runtime_options(self, request: AnalysisRequest) -> dict:
        """Resolve model-loading options, preferring ``request`` over ``settings``.

        String-like preferences fall back to settings when falsy; the two
        boolean flags fall back only when the request value is ``None``, so an
        explicit ``False`` on the request is honored.
        """
        settings = self.settings
        return {
            "device_preference": request.device_preference or settings.device_preference,
            "dtype_preference": request.dtype_preference or settings.dtype_preference,
            "attn_implementation": request.attn_implementation or settings.attn_implementation,
            "trust_remote_code": (
                settings.trust_remote_code if request.trust_remote_code is None else request.trust_remote_code
            ),
            "low_cpu_mem_usage": (
                settings.low_cpu_mem_usage if request.low_cpu_mem_usage is None else request.low_cpu_mem_usage
            ),
        }

    def _run_session_pipeline(self, session_id: str, request: AnalysisRequest) -> None:
        """Run generation followed by analysis for a session.

        Status moves through ``"generating"`` and ``"answer_ready"`` before the
        analysis step takes over. Any exception raised here is stored on the
        session as status ``"failed"`` with ``str(exc)`` and is not re-raised.
        """
        try:
            self.repository.update_status(session_id, status="generating")
            bundle = load_model_bundle(
                request.model_name or self.settings.model_name,
                **self._runtime_options(request),
            )
            generation = generate_answer_and_trace(
                question=request.question,
                model_name=bundle.model_name,
                model=bundle.model,
                tokenizer=bundle.tokenizer,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
            )
            sentences = [span.text for span in split_sentences(generation.normalized_trace_text)]
            self.repository.save_generation_result(session_id, generation, sentences)
            self.repository.update_status(session_id, status="answer_ready")
            # The analysis step records its own failures, so errors raised
            # inside it do not reach the handler below.
            self._run_analysis_only(session_id, request, generation=generation)
        except Exception as exc:
            # Job boundary: persist the failure instead of propagating it.
            self.repository.update_status(session_id, status="failed", error=str(exc))

    def _run_analysis_only(
        self,
        session_id: str,
        request: AnalysisRequest,
        *,
        generation=None,
    ) -> None:
        """Run the analysis step for a session and store its result.

        Args:
            session_id: Identifier of the session being analyzed.
            request: Request supplying the question and analysis options.
            generation: Generation result to analyze. When ``None``, it is
                rebuilt from the stored session payload if that payload has
                ``generation_metadata``; otherwise ``None`` is passed on.

        Any exception is stored on the session as status ``"failed"`` with
        ``str(exc)`` and is not re-raised.
        """
        try:
            self.repository.update_status(session_id, status="analyzing")
            if generation is None:
                payload = self.repository.list_session_payload(session_id)
                if payload.get("generation_metadata") is not None:
                    generation = GenerationResult(
                        question=payload["question"],
                        model_name=payload["model_name"],
                        answer=payload["answer"],
                        raw_generation_text=payload.get("raw_generation_text", ""),
                        raw_trace_text=payload["raw_trace_text"],
                        normalized_trace_text=payload["normalized_trace_text"],
                        generation_metadata=GenerationMetadata.model_validate(payload["generation_metadata"]),
                    )
            # NOTE(review): generation may still be None here; presumably
            # analyze_generation_result then generates on its own using the
            # sampling arguments below - confirm against its implementation.
            result = analyze_generation_result(
                question=request.question,
                generation=generation,
                model_name=request.model_name or self.settings.model_name,
                take_log=request.take_log,
                max_sentences=request.max_sentences,
                max_trace_tokens=request.max_trace_tokens,
                validate_top_k=request.validate_top_k,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                **self._runtime_options(request),
            )
            self.repository.save_analysis_result(session_id, result)
            self.repository.update_status(session_id, status="completed")
        except Exception as exc:
            # Job boundary: persist the failure instead of propagating it.
            self.repository.update_status(session_id, status="failed", error=str(exc))