"""Text generation with the Maris AI model."""

from __future__ import annotations

import asyncio
import copy
import inspect
import json
import logging
import math
import re
import threading
from contextlib import suppress
from dataclasses import dataclass
from queue import SimpleQueue
from time import monotonic, perf_counter
from typing import Any, Literal
from uuid import uuid4

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator

from maris_core.memory_context import memory_store
from maris_core.orchestrator.routing import build_system_prompt, resolve_text_model
from maris_core.personas import resolve_persona
from maris_core.text.tools import (
    MAX_TOOL_STEPS_CAP,
    MAX_TOOL_STEPS_DEFAULT,
    ToolTrace,
    build_tool_context_message,
    execute_tool_trace,
    plan_tool_use,
)
from maris_core.training.hf_compat import maris_hf_compatible_path
from maris_core.utils.emotional_context import analyze_emotional_context
from maris_core.utils.env import get_env_any, validate_hf_model
from maris_core.utils.hf_inference import create_hf_inference_client
from maris_core.utils.hf_integration import HFIntegration

logger = logging.getLogger(__name__)
router = APIRouter()

# Max characters for the request message and for each history entry's content.
TEXT_GENERATE_MESSAGE_MAX_CHARS = 12_000
# Max number of history items accepted in a single request.
TEXT_GENERATE_HISTORY_LIMIT = 32
# Defaults for GenerateRequest.max_new_tokens / GenerateRequest.temperature.
DEFAULT_MAX_NEW_TOKENS = 2_048
DEFAULT_TEMPERATURE = 0.7
# Model name reported in responses when the reply is the built-in, context-only
# fallback text (no model produced it).
FALLBACK_MODEL_NAME = "MarisUK/maris-assistant-runtime-fallback"
# Environment variables consulted (via get_env_any) for the remote HF fallback
# model when the request does not name one.
RUNTIME_FALLBACK_MODEL_ENV = "MARIS_RUNTIME_FALLBACK_MODEL"
HF_RUNTIME_FALLBACK_MODEL_ENV = "HF_RUNTIME_FALLBACK_MODEL"
# Target size (characters) of the pseudo-stream chunks emitted for non-streamed replies.
STREAM_CHUNK_CHARS = 48
# NOTE(review): presumably the queue poll interval of the streaming endpoint;
# its use is not visible in this chunk.
STREAM_QUEUE_POLL_SECONDS = 0.1
# Minimum wait after a failed model load before another load is attempted.
PIPELINE_RETRY_COOLDOWN_SECONDS = 30.0
# Approximate conversion from regex fragments to model tokens for fallback usage estimates.
TOKEN_ESTIMATION_EXPANSION_RATIO = 0.75
# Splits text into word runs and single punctuation characters; used only for
# rough token estimates.
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)


# One prior conversation turn supplied by the client in GenerateRequest.history.
# (Comments instead of a class docstring: pydantic publishes class docstrings as
# the JSON-schema / OpenAPI description.)
class ConversationMessage(BaseModel):
    model_config = ConfigDict()

    role: Literal["system", "user", "assistant"]
    content: str = Field(min_length=1, max_length=TEXT_GENERATE_MESSAGE_MAX_CHARS)

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: str) -> str:
        """Reject whitespace-only content; valid values are returned unstripped."""
        if not value.strip():
            raise ValueError("content nedrīkst būt tukšs.")
        return value


# Request body shared by the /generate and /generate/stream endpoints.
class GenerateRequest(BaseModel):
    model_config = ConfigDict()

    message: str = Field(min_length=1, max_length=TEXT_GENERATE_MESSAGE_MAX_CHARS)
    # Prior turns supplied by the client, oldest first (appended to the prompt in order).
    history: list[ConversationMessage] = Field(
        default_factory=list,
        max_length=TEXT_GENERATE_HISTORY_LIMIT,
    )
    # NOTE(review): only a lower bound is enforced, so a client can request an
    # arbitrarily large generation budget; consider adding `le=...`.
    max_new_tokens: int = Field(
        default=DEFAULT_MAX_NEW_TOKENS,
        ge=32,
    )
    # 0.0 selects greedy decoding (see _build_generation_config).
    temperature: float = Field(default=DEFAULT_TEMPERATURE, ge=0.0, le=1.5)
    max_tool_steps: int = Field(default=MAX_TOOL_STEPS_DEFAULT, ge=1, le=MAX_TOOL_STEPS_CAP)
    profile: str | None = Field(default=None, max_length=40)
    persona_id: str | None = Field(default=None, max_length=64)
    session_id: str | None = Field(default=None, max_length=120)
    # Explicit remote HF fallback model id; checked by validate_hf_model.
    fallback_model: str | None = Field(default=None, max_length=160)
    # Optional vision metadata; keys such as "summary", "source", "model",
    # "detections", "width", "height" are read when building prompts.
    vision_context: dict[str, Any] | None = None

    @field_validator("message")
    @classmethod
    def validate_message(cls, value: str) -> str:
        """Reject whitespace-only messages; valid values are returned unstripped."""
        if not value.strip():
            raise ValueError("message nedrīkst būt tukšs.")
        return value

    @field_validator("profile", "persona_id", "session_id")
    @classmethod
    def normalize_optional_text(cls, value: str | None) -> str | None:
        """Strip surrounding whitespace and map empty/blank values to None."""
        normalized = (value or "").strip()
        return normalized or None

    @field_validator("fallback_model")
    @classmethod
    def validate_fallback_model(cls, value: str | None) -> str | None:
        """Map blank values to None and validate the rest with validate_hf_model.

        RuntimeError from validate_hf_model is re-raised as ValueError so that
        pydantic reports it as a regular validation error.
        """
        normalized = (value or "").strip()
        if not normalized:
            return None
        try:
            return validate_hf_model(normalized, "fallback_model")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc


# Response body of /generate (also the payload of the stream's "complete" event).
class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_used: int = 0
    detected_emotion: str = "neutral"
    emotion_confidence: float = 0.0
    response_style: str = "clear_grounded"
    persona_id: str = "assistant"
    persona_title: str = "Core Assistant"
    persona_summary: str = ""
    request_id: str
    session_id: str
    latency_ms: int = 0
    # Number of chat messages sent to the model (system + history + user).
    prompt_messages: int = 0
    # Number of retrieved memory entries used in the prompt.
    memory_matches: int = 0
    tool_trace: ToolTrace | None = None


@dataclass(slots=True)
class GenerationContext:
    """Per-request state built by _prepare_generation_context.

    It is shared by the generation, persistence and response-building steps.
    """

    req: GenerateRequest
    request_id: str
    started_at: float  # perf_counter() value taken when the request started
    hf: HFIntegration
    emotional_context: Any
    persona: Any
    session_id: str
    history: list[dict[str, str]]
    memory_matches: list[Any]
    user_focus: list[str]
    active_threads: list[str]
    vision_context: dict[str, Any]
    session_summary: list[str]
    messages: list[dict[str, str]]  # final chat prompt sent to the model
    tool_trace: ToolTrace | None


def _build_pipeline() -> Any:
    """Build the Maris AI text-generation pipeline.

    Returns the transformers pipeline, or None when loading fails. On failure
    the error text is recorded in ``_pipeline_last_error``. The resolved model
    name and whether a compatibility-restored path was used are published under
    ``_pipeline_lock`` for readiness reporting.
    """
    global _pipeline_last_error  # noqa: PLW0603
    global _pipeline_runtime_model  # noqa: PLW0603
    global _pipeline_compatibility_restore_active  # noqa: PLW0603
    try:
        # Imported lazily so the module imports without transformers installed.
        from transformers import pipeline  # type: ignore

        model_name = resolve_text_model()
        logger.info("Ielādē teksta modeli: %s", model_name)
        with maris_hf_compatible_path(model_name) as compatible_model_path:
            # A different path means the loader substituted a compatibility copy.
            compatibility_restore_active = compatible_model_path != model_name
            if compatibility_restore_active:
                logger.info(
                    "Runtime loaderis aktivizēja compatibility restore modelim %s.",
                    model_name,
                )
            with _pipeline_lock:
                _pipeline_runtime_model = model_name
                _pipeline_compatibility_restore_active = compatibility_restore_active
            # Built inside the context manager so the (possibly temporary)
            # compatible path is still valid while weights are loaded.
            # trust_remote_code=True allows code shipped with the model repo to run.
            return pipeline(
                "text-generation",
                model=compatible_model_path,
                device_map="auto",
                trust_remote_code=True,
            )
    except Exception as exc:  # noqa: BLE001
        with _pipeline_lock:
            # Best effort: still report which model was meant to be loaded.
            if not _pipeline_runtime_model:
                with suppress(Exception):
                    _pipeline_runtime_model = resolve_text_model()
            _pipeline_last_error = str(exc)
        logger.error("Nevar ielādēt Maris AI teksta modeli: %s", exc)
        return None


# Shared loader state; read and written under _pipeline_lock.
_pipeline: Any = None  # the loaded pipeline, or None until a load succeeds
_pipeline_loading = False  # True while the background loader thread runs
_pipeline_last_failure_at = 0.0  # monotonic() time of the last failed load; 0.0 = none
_pipeline_last_error: str | None = None
_pipeline_runtime_model = "" _pipeline_compatibility_restore_active = False _pipeline_lock = threading.Lock() def _load_pipeline_in_background() -> None: global _pipeline # noqa: PLW0603 global _pipeline_loading # noqa: PLW0603 global _pipeline_last_failure_at # noqa: PLW0603 global _pipeline_last_error # noqa: PLW0603 pipeline = _build_pipeline() with _pipeline_lock: _pipeline = pipeline _pipeline_loading = False _pipeline_last_failure_at = monotonic() if pipeline is None else 0.0 if pipeline is not None: _pipeline_last_error = None elif not _pipeline_last_error: _pipeline_last_error = "pipeline_initialization_failed" def _should_retry_pipeline_load(now: float) -> bool: return ( _pipeline_last_failure_at <= 0 or now - _pipeline_last_failure_at >= PIPELINE_RETRY_COOLDOWN_SECONDS ) def _retry_after_seconds(now: float) -> int: remaining = PIPELINE_RETRY_COOLDOWN_SECONDS - (now - _pipeline_last_failure_at) return max(1, math.ceil(remaining)) def _start_pipeline_load_locked() -> None: global _pipeline_loading # noqa: PLW0603 _pipeline_loading = True logger.info("Sākam teksta modeļa ielādi fonā; pirmās atbildes izmantos fallback režīmu.") loader = threading.Thread( target=_load_pipeline_in_background, name="maris-text-pipeline-loader", daemon=True, ) loader.start() def _resolve_pipeline_readiness_locked(now: float, *, start_loading: bool) -> dict[str, Any]: model_name = _pipeline_runtime_model or _safe_runtime_model_name() payload: dict[str, Any] = { "model": model_name, "compatibility_restore_active": _pipeline_compatibility_restore_active, } if _pipeline_last_error: payload["last_error"] = _pipeline_last_error if _pipeline is not None: return {"ready": True, "state": "ready", **payload} if _pipeline_loading: return {"ready": False, "state": "warming_up", **payload} if not _should_retry_pipeline_load(now): retry_after_seconds = _retry_after_seconds(now) return { "ready": False, "state": "retry_cooldown", "retry_after_seconds": retry_after_seconds, **payload, } if 
start_loading: _start_pipeline_load_locked() return {"ready": False, "state": "warming_up", **payload} return {"ready": False, "state": "cold", **payload} def _safe_runtime_model_name() -> str: try: return resolve_text_model() except Exception: # noqa: BLE001 return "" def get_text_model_readiness(*, start_loading: bool = False) -> dict[str, Any]: now = monotonic() with _pipeline_lock: return _resolve_pipeline_readiness_locked(now, start_loading=start_loading) def get_pipeline() -> Any: global _pipeline # noqa: PLW0603 with _pipeline_lock: readiness = _resolve_pipeline_readiness_locked(monotonic(), start_loading=True) if readiness["ready"]: return _pipeline return None def warm_text_model_runtime() -> None: get_text_model_readiness(start_loading=True) def _build_generation_config( pipe: Any, *, max_new_tokens: int, temperature: float, ) -> Any | None: try: from transformers import GenerationConfig # type: ignore except ImportError: return None base_config = getattr(pipe, "generation_config", None) model = getattr(pipe, "model", None) if base_config is None: base_config = getattr(model, "generation_config", None) if base_config is None: model_config = getattr(model, "config", None) config = ( GenerationConfig.from_model_config(model_config) if model_config is not None else GenerationConfig() ) else: config = copy.deepcopy(base_config) config.max_new_tokens = max_new_tokens # Setting max_length to None leaves max_new_tokens as the sole output-length limit. 
config.max_length = None config.do_sample = temperature > 0 if temperature > 0: config.temperature = temperature return config def call_generation_pipeline( pipe: Any, inputs: Any, *, max_new_tokens: int, temperature: float, return_full_text: bool | None = None, ) -> Any: generation_config = _build_generation_config( pipe, max_new_tokens=max_new_tokens, temperature=temperature, ) if generation_config is not None and _supports_generation_config(pipe): generation_kwargs: dict[str, Any] = {"generation_config": generation_config} if return_full_text is not None: generation_kwargs["return_full_text"] = return_full_text try: return pipe(inputs, **generation_kwargs) except TypeError as exc: message = str(exc) if "generation_config" not in message and "unexpected keyword" not in message: raise logger.debug( "Pipeline rejected generation_config, retrying with explicit kwargs: %s", exc ) generation_kwargs = { "max_new_tokens": max_new_tokens, "temperature": temperature, } if return_full_text is not None: generation_kwargs["return_full_text"] = return_full_text return pipe(inputs, **generation_kwargs) def _supports_generation_config(pipe: Any) -> bool: try: signature = inspect.signature(pipe) except (TypeError, ValueError): return True if "generation_config" in signature.parameters: return True return any( param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() ) def _resolve_session_id(raw_session_id: str | None, request_id: str) -> str: normalized = (raw_session_id or "").strip() if normalized: return normalized return f"ephemeral-{request_id}" def _normalize_history(history: list[ConversationMessage]) -> list[dict[str, str]]: return [item.model_dump() for item in history] async def _run_pipeline( pipe: Any, messages: list[dict[str, str]], *, max_new_tokens: int, temperature: float, ) -> Any: try: return await asyncio.to_thread( call_generation_pipeline, pipe, messages, max_new_tokens=max_new_tokens, temperature=temperature, ) except 
(AttributeError, TypeError, ValueError) as exc: prompt = _render_generation_prompt(pipe, messages) logger.info( "Falling back to prompt-based text generation after chat payload failure: %s", exc, ) return await asyncio.to_thread( call_generation_pipeline, pipe, prompt, max_new_tokens=max_new_tokens, temperature=temperature, return_full_text=False, ) def _render_manual_prompt(messages: list[dict[str, str]]) -> str: lines: list[str] = [] for message in messages: raw_role = message.get("role", "user") role = raw_role.strip().lower() if isinstance(raw_role, str) else "user" if not role: role = "user" content = message.get("content", "").strip() if not content: continue lines.append(f"{role.title()}: {content}") lines.append("Assistant:") return "\n".join(lines) def _render_generation_prompt(pipe: Any | None, messages: list[dict[str, str]]) -> str: if pipe is not None: tokenizer = getattr(pipe, "tokenizer", None) apply_chat_template = getattr(tokenizer, "apply_chat_template", None) else: apply_chat_template = None if callable(apply_chat_template): try: rendered = apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) except (AttributeError, TypeError, ValueError) as exc: logger.debug("Chat template rendering failed, using manual prompt fallback: %s", exc) rendered = None if isinstance(rendered, str) and rendered.strip(): return rendered return _render_manual_prompt(messages) def _resolve_fallback_model(raw_fallback_model: str | None) -> str | None: requested = (raw_fallback_model or "").strip() if requested: return requested configured = get_env_any( RUNTIME_FALLBACK_MODEL_ENV, HF_RUNTIME_FALLBACK_MODEL_ENV, ) return (configured or "").strip() or None def _extract_inference_response_text(value: Any) -> str: if isinstance(value, str): return value.strip() if isinstance(value, list): for item in reversed(value): text = _extract_inference_response_text(item) if text: return text return "" if isinstance(value, dict): choices = value.get("choices") if 
isinstance(choices, list): text = _extract_inference_response_text(choices) if text: return text for key in ("generated_text", "text", "response", "content", "answer"): if key in value: text = _extract_inference_response_text(value[key]) if text: return text if "message" in value: text = _extract_inference_response_text(value["message"]) if text: return text choices = getattr(value, "choices", None) if choices is not None: text = _extract_inference_response_text(choices) if text: return text message = getattr(value, "message", None) if message is not None: text = _extract_inference_response_text(message) if text: return text content = getattr(value, "content", None) if content is not None: return _extract_inference_response_text(content) text = getattr(value, "text", None) if text is not None: return _extract_inference_response_text(text) return "" def complete_with_hf_fallback( messages: list[dict[str, str]], *, fallback_model: str | None, max_new_tokens: int, temperature: float, ) -> tuple[str, str] | None: resolved_model = _resolve_fallback_model(fallback_model) if not resolved_model: return None try: from huggingface_hub import InferenceClient # type: ignore from huggingface_hub.utils import HfHubHTTPError # type: ignore except ImportError: return None retryable_hf_errors: tuple[type[Exception], ...] 
= ( OSError, TypeError, ValueError, RuntimeError, HfHubHTTPError, ) client = create_hf_inference_client(InferenceClient) try: raw_response = client.chat_completion( model=resolved_model, messages=messages, max_tokens=max_new_tokens, temperature=temperature, ) response_text = _sanitize_response_text( _extract_inference_response_text(raw_response), messages ) if response_text: return resolved_model, response_text except AttributeError: pass except retryable_hf_errors as exc: logger.warning("HF fallback chat_completion failed for model %s: %s", resolved_model, exc) prompt = _render_generation_prompt(None, messages) try: raw_response = client.text_generation( prompt=prompt, model=resolved_model, max_new_tokens=max_new_tokens, temperature=temperature, return_full_text=False, ) except AttributeError: return None except retryable_hf_errors as exc: logger.warning("HF fallback text_generation failed for model %s: %s", resolved_model, exc) return None response_text = _sanitize_response_text( _extract_inference_response_text(raw_response), messages ) if not response_text: return None return resolved_model, response_text def _chunk_text_for_streaming(text: str) -> list[str]: chunks: list[str] = [] current: list[str] = [] last_boundary: int | None = None for char in text: current.append(char) if char.isspace(): last_boundary = len(current) if len(current) >= STREAM_CHUNK_CHARS: if last_boundary is not None and last_boundary < len(current): chunks.append("".join(current[:last_boundary])) current = current[last_boundary:] last_boundary = None else: chunks.append("".join(current)) current = [] last_boundary = None if current: chunks.append("".join(current)) return chunks def _encode_stream_event(event_name: str, payload: Any) -> str: return f"event: {event_name}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n" def _prepare_stream_model_inputs( pipe: Any, messages: list[dict[str, str]] ) -> tuple[Any, Any, dict[str, Any]]: tokenizer = getattr(pipe, "tokenizer", None) model = 
getattr(pipe, "model", None) if tokenizer is None or model is None: raise TypeError("Teksta pipeline neatbalsta tiešo streaming ģenerēšanu.") prompt = _render_generation_prompt(pipe, messages) encoded = tokenizer(prompt, return_tensors="pt") if not isinstance(encoded, dict): encoded = dict(encoded) model_device = getattr(model, "device", None) if model_device is not None: encoded = { key: value.to(model_device) if hasattr(value, "to") else value for key, value in encoded.items() } return tokenizer, model, encoded async def _prepare_generation_context( req: GenerateRequest, *, request_id: str, started_at: float, ) -> GenerationContext: hf = HFIntegration() emotional_context = analyze_emotional_context(req.message) persona = resolve_persona(req.persona_id) session_id = _resolve_session_id(req.session_id, request_id) history = _normalize_history(req.history) memory_store.seed_history(session_id, history) memory_matches = memory_store.retrieve_relevant_context(session_id, req.message) user_focus = memory_store.summarize_user_focus(session_id, query=req.message) active_threads = memory_store.summarize_active_threads(session_id, query=req.message) vision_context = req.vision_context or {} session_summary = memory_store.summarize_session(session_id) planned_tool_trace = plan_tool_use(req.message) tool_trace = ( await execute_tool_trace( planned_tool_trace, message=req.message, max_steps=req.max_tool_steps, ) if planned_tool_trace is not None else None ) messages = _build_generation_messages( req, persona=persona, emotional_context=emotional_context, history=history, session_summary=session_summary, memory_matches=memory_matches, user_focus=user_focus, active_threads=active_threads, vision_context=vision_context, tool_trace=tool_trace, ) logger.info( "Teksta ģenerēšana sākta request_id=%s session_id=%s persona=%s profile=%s history=%s memory=%s", request_id, session_id, persona.id, req.profile or "general", len(history), len(memory_matches), ) return GenerationContext( 
req=req, request_id=request_id, started_at=started_at, hf=hf, emotional_context=emotional_context, persona=persona, session_id=session_id, history=history, memory_matches=memory_matches, user_focus=user_focus, active_threads=active_threads, vision_context=vision_context, session_summary=session_summary, messages=messages, tool_trace=tool_trace, ) async def _persist_generation_result( context: GenerationContext, *, response_text: str, tokens_used: int, latency_ms: int, fallback_used: bool, ) -> None: memory_store.remember_message(context.session_id, "user", context.req.message) if context.vision_context.get("summary"): memory_store.remember_message( context.session_id, "assistant", str(context.vision_context["summary"]), source="vision_context", ) memory_store.remember_message(context.session_id, "assistant", response_text) await context.hf.save_conversation( context.req.message, response_text, metadata={ "request_id": context.request_id, "session_id": context.session_id, "persona_id": context.persona.id, "profile": context.req.profile or "general", "detected_emotion": context.emotional_context.emotion, "response_style": context.emotional_context.response_style, "tokens_used": tokens_used, "latency_ms": latency_ms, "history_messages": len(context.history), "memory_matches": len(context.memory_matches), "user_focus_items": len(context.user_focus), "active_thread_items": len(context.active_threads), "vision_context": bool(context.vision_context.get("summary")), "fallback_used": fallback_used, "session_summary_items": len(context.session_summary), "tool_mode": context.tool_trace.mode if context.tool_trace is not None else "direct", "tool_steps": len(context.tool_trace.steps) if context.tool_trace is not None else 0, "tool_grounding_sources": ( len(context.tool_trace.grounding_sources) if context.tool_trace is not None else 0 ), "requested_fallback_model": context.req.fallback_model, }, ) def _build_generate_response( context: GenerationContext, *, response_text: str, 
model_name: str, tokens_used: int, latency_ms: int, ) -> GenerateResponse: logger.info( "Teksta ģenerēšana pabeigta request_id=%s session_id=%s latency_ms=%s tokens_used=%s", context.request_id, context.session_id, latency_ms, tokens_used, ) return GenerateResponse( response=response_text, model=model_name, tokens_used=tokens_used, detected_emotion=context.emotional_context.emotion, emotion_confidence=context.emotional_context.confidence, response_style=context.emotional_context.response_style, persona_id=context.persona.id, persona_title=context.persona.title, persona_summary=context.persona.summary, request_id=context.request_id, session_id=context.session_id, latency_ms=latency_ms, prompt_messages=len(context.messages), memory_matches=len(context.memory_matches), tool_trace=context.tool_trace, ) def _extract_response_text(payload: Any) -> str: candidate = payload if isinstance(payload, list) and payload: candidate = payload[0] text = _extract_text_candidate(candidate) if text: return text raise ValueError("Neizdevās nolasīt ģenerēto atbildi no modeļa rezultāta.") def _sanitize_response_text(response_text: str, messages: list[dict[str, str]]) -> str: cleaned = response_text.strip() if not cleaned: return "" manual_prompt = _render_manual_prompt(messages) if manual_prompt and manual_prompt in cleaned: trimmed = cleaned.rsplit(manual_prompt, maxsplit=1)[-1].strip() if trimmed: cleaned = trimmed for marker in ("Assistant:", "assistant:", "Maris:", "maris:", "<|assistant|>", ""): while cleaned.startswith(marker): cleaned = cleaned[len(marker) :].lstrip() return cleaned.strip() def _extract_text_candidate(value: Any) -> str: if isinstance(value, str): return value.strip() if isinstance(value, list): for item in reversed(value): text = _extract_text_candidate(item) if text: return text return "" if isinstance(value, dict): for key in ("generated_text", "text", "response", "content", "answer"): if key in value: text = _extract_text_candidate(value[key]) if text: return 
text if "message" in value: text = _extract_text_candidate(value["message"]) if text: return text return "" def _extract_usage_tokens(payload: Any) -> int | None: if isinstance(payload, dict): usage = payload.get("usage") if isinstance(usage, dict): for key in ("total_tokens", "generated_tokens", "completion_tokens"): token_value = usage.get(key) if isinstance(token_value, int) and token_value >= 0: return token_value for key in ("total_tokens", "generated_tokens", "token_count", "tokens_used"): token_value = payload.get(key) if isinstance(token_value, int) and token_value >= 0: return token_value for nested_key in ("generated_text", "message", "details"): if nested_key in payload: tokens = _extract_usage_tokens(payload[nested_key]) if tokens is not None: return tokens if isinstance(payload, list): for item in payload: tokens = _extract_usage_tokens(item) if tokens is not None: return tokens return None def _estimate_tokens_used(messages: list[dict[str, str]], response_text: str) -> int: """Return a rough token estimate when the runtime does not expose real token usage.""" prompt_text = "\n".join(item.get("content", "") for item in messages) prompt_tokens = max( 1, math.ceil(len(_TOKEN_PATTERN.findall(prompt_text)) / TOKEN_ESTIMATION_EXPANSION_RATIO), ) response_tokens = max( 1, math.ceil(len(_TOKEN_PATTERN.findall(response_text)) / TOKEN_ESTIMATION_EXPANSION_RATIO), ) return prompt_tokens + response_tokens def _build_assistant_runtime_contract( *, user_message: str, profile: str | None, persona_title: str, persona_summary: str, communication_style: str, session_summary: list[str], memory_matches: list[Any], user_focus: list[str], active_threads: list[str], vision_context: dict[str, Any], tool_trace: ToolTrace | None, ) -> str: instructions = [ "Assistant runtime contract:", "- Uzturi vienotu Maris identitāti visā atbildē.", f"- Turpini aktīvo personu konsekventi: {persona_title} — {persona_summary}", f"- Runā stilā: {communication_style}.", "- Balsti atbildi 
lietotāja ziņā, vēsturē, atmiņā un vizuālajā kontekstā, ja tāds ir dots.", "- Ja informācijas nepietiek, skaidri nosauc trūkstošo un uzdod vienu precizējošu jautājumu, nevis izdomā faktus.", "- Iekšējo domāšanu nerādi; lietotājam dod tikai secinājumus, pamatojumu un nākamo soli.", "- Ja pieprasījums ir riskants vai faktus vajag pārbaudīt ārpus dotā konteksta, pasaki to tieši.", "- Ja atmiņa, tool grounding vai vizuālais konteksts konfliktē ar jaunāko lietotāja ziņu, prioritizē jaunāko ziņu un skaidri nosauc konfliktu.", "- Neatsaucies uz ārējiem avotiem, tool rezultātiem vai sesijas faktiem, ja tie tiešām nav dotajā kontekstā.", ] if session_summary: instructions.append("- Sesijas mugurkauls, kas jāsaglabā:") instructions.extend(f" {item}" for item in session_summary) if user_focus: instructions.append("- Lietotāja ilgtermiņa fokuss, preferences un mērķi:") instructions.extend(f" {item}" for item in user_focus) if active_threads: instructions.append("- Aktīvie pavedieni, kas jāturpina bez konteksta zaudēšanas:") instructions.extend(f" {item}" for item in active_threads) if memory_matches: instructions.append("- Prioritizē atmiņas faktus, kas tieši atkārtojas šajā pieprasījumā.") if vision_context.get("summary"): instructions.append( "- Piesien secinājumus pie vizuālā novērojuma, nevis tikai vispārīgām frāzēm." ) if tool_trace is not None and tool_trace.steps: instructions.append( "- Ja ir tool grounding rezultāti, izmanto tos kā primāro pamatojuma avotu un nenoklusē nenoteiktību." 
) if _is_coding_request(user_message, profile): instructions.extend( [ "Coding response contract:", "- Dod tieši izpildāmu vai ļoti precīzi pielāgojamu risinājumu, nevis miglainu pseidokodu.", "- Pirms vai pēc koda īsi nosauc pieņēmumus, ja prasībā kaut kas nav pilnīgi skaidrs.", "- Pievieno būtiskos edge cases, validāciju, kļūdu apstrādi un drošības apsvērumus, ja tie ir svarīgi.", "- Ja lietotājs prasa izmaiņas esošā kodā, saglabā risinājumu minimālu un saderīgu ar esošo kontekstu.", "- Ja dod kodu, piemini kā to pārbaudīt ar testu, manuālu pārbaudi vai izsaukuma piemēru.", ] ) return "\n".join(instructions) def _is_coding_request(message: str, profile: str | None) -> bool: normalized_profile = (profile or "").strip().lower() if normalized_profile == "coder": return True lowered = message.lower() return any( token in lowered for token in ( "kod", "python", "rust", "typescript", "javascript", "sql", "regex", "api", "bug", "test", "refactor", "stack trace", ) ) def _build_memory_summary_message(session_summary: list[str]) -> str | None: if not session_summary: return None return "Sesijas kopsavilkums ilgākai konsekvencei:\n" + "\n".join( f"- {item}" for item in session_summary ) def _build_memory_match_message(memory_matches: list[Any]) -> str | None: if not memory_matches: return None memory_lines = [ f"- {match.role} [{match.source}; {match.score:.2f}]: {match.content}" for match in memory_matches ] return "Saistītā atmiņa no iepriekšējām sesijām vai sarunas vēstures:\n" + "\n".join( memory_lines ) def _build_user_focus_message(user_focus: list[str]) -> str | None: if not user_focus: return None return "Lietotāja ilgtermiņa fokuss šai sesijai:\n" + "\n".join( f"- {item}" for item in user_focus ) def _build_active_threads_message(active_threads: list[str]) -> str | None: if not active_threads: return None return "Aktīvie pavedieni šai sesijai:\n" + "\n".join(f"- {item}" for item in active_threads) def _build_vision_message(vision_context: dict[str, Any]) -> str | 
None: if not vision_context.get("summary"): return None return ( "Vizuālais konteksts no attēla vai kameras šai pašai sesijai:\n" f"- avots: {vision_context.get('source', 'unknown')}\n" f"- modelis: {vision_context.get('model', 'unknown')}\n" f"- kopsavilkums: {vision_context['summary']}\n" f"- detekcijas: {vision_context.get('detections', 0)}\n" f"- izmērs: {vision_context.get('width', 0)}x{vision_context.get('height', 0)}" ) def _build_graceful_fallback_response( req: GenerateRequest, *, persona_title: str, session_summary: list[str], memory_matches: list[Any], user_focus: list[str], active_threads: list[str], vision_context: dict[str, Any], tool_trace: ToolTrace | None, ) -> str: opening = "Pilnais modelis šobrīd nav pieejams, tāpēc dodu drošu rezerves atbildi no pieejamā konteksta." lines = [ opening, f"Aktīvā persona: {persona_title}.", f"Tavs pieprasījums: {req.message.strip()}.", ] if session_summary: lines.append("Ko saglabāju no iepriekšējās sarunas:") lines.extend(f"- {item}" for item in session_summary[:3]) if user_focus: lines.append("Ko saprotu par taviem mērķiem un preferencēm:") lines.extend(f"- {item}" for item in user_focus[:3]) if active_threads: lines.append("Ko šobrīd uztveru kā aktīvos pavedienus:") lines.extend(f"- {item}" for item in active_threads[:3]) elif memory_matches: lines.append("Saistītais iepriekšējais konteksts:") lines.extend(f"- {match.content}" for match in memory_matches[:2]) if vision_context.get("summary"): lines.append(f"Vizuāli redzamais konteksts: {vision_context['summary']}.") if tool_trace is not None and tool_trace.steps: lines.append("Papildu grounding no rīkiem:") lines.extend(f"- {step.summary}" for step in tool_trace.steps[:3]) lines.append( "Nākamais labākais solis: pasaki, vai gribi īsu kopsavilkumu, detalizētu plānu vai konkrētu izpildes soli." 
) return "\n".join(lines) def _build_generation_messages( req: GenerateRequest, *, persona: Any, emotional_context: Any, history: list[dict[str, str]], session_summary: list[str], memory_matches: list[Any], user_focus: list[str], active_threads: list[str], vision_context: dict[str, Any], tool_trace: ToolTrace | None, ) -> list[dict[str, str]]: messages = [ { "role": "system", "content": build_system_prompt( req.profile, emotional_context, persona_id=persona.id, ), }, { "role": "system", "content": _build_assistant_runtime_contract( user_message=req.message, profile=req.profile, persona_title=persona.title, persona_summary=persona.summary, communication_style=persona.communication_style, session_summary=session_summary, memory_matches=memory_matches, user_focus=user_focus, active_threads=active_threads, vision_context=vision_context, tool_trace=tool_trace, ), }, ] memory_summary_message = _build_memory_summary_message(session_summary) if memory_summary_message: messages.append({"role": "system", "content": memory_summary_message}) memory_match_message = _build_memory_match_message(memory_matches) if memory_match_message: messages.append({"role": "system", "content": memory_match_message}) user_focus_message = _build_user_focus_message(user_focus) if user_focus_message: messages.append({"role": "system", "content": user_focus_message}) active_threads_message = _build_active_threads_message(active_threads) if active_threads_message: messages.append({"role": "system", "content": active_threads_message}) vision_message = _build_vision_message(vision_context) if vision_message: messages.append({"role": "system", "content": vision_message}) tool_message = build_tool_context_message(tool_trace) if tool_trace is not None else None if tool_message: messages.append({"role": "system", "content": tool_message}) messages.extend(history) messages.append({"role": "user", "content": req.message}) return messages async def _generate_runtime_response( req: GenerateRequest, *, 
messages: list[dict[str, str]], session_summary: list[str], memory_matches: list[Any], persona: Any, user_focus: list[str], active_threads: list[str], vision_context: dict[str, Any], tool_trace: ToolTrace | None, request_id: str, ) -> tuple[str, str, int, bool]: pipe = get_pipeline() if pipe is None: fallback_result = complete_with_hf_fallback( messages, fallback_model=req.fallback_model, max_new_tokens=req.max_new_tokens, temperature=req.temperature, ) if fallback_result is not None: model_name, response_text = fallback_result return response_text, model_name, _estimate_tokens_used(messages, response_text), True logger.warning( "Teksta modelis nav pieejams request_id=%s; izmanto rezerves atbildi.", request_id ) response_text = _build_graceful_fallback_response( req, persona_title=persona.title, session_summary=session_summary, memory_matches=memory_matches, user_focus=user_focus, active_threads=active_threads, vision_context=vision_context, tool_trace=tool_trace, ) return ( response_text, FALLBACK_MODEL_NAME, _estimate_tokens_used(messages, response_text), True, ) try: raw_output = await _run_pipeline( pipe, messages, max_new_tokens=req.max_new_tokens, temperature=req.temperature, ) response_text = _sanitize_response_text(_extract_response_text(raw_output), messages) tokens_used = _extract_usage_tokens(raw_output) or _estimate_tokens_used( messages, response_text ) return response_text, resolve_text_model(), tokens_used, False except ValueError as exc: logger.error( "Maris AI pipeline atgrieza nederīgu formātu request_id=%s: %s", request_id, exc ) except Exception as exc: # noqa: BLE001 logger.error("Maris AI pipeline kļūda request_id=%s: %s", request_id, exc) fallback_result = complete_with_hf_fallback( messages, fallback_model=req.fallback_model, max_new_tokens=req.max_new_tokens, temperature=req.temperature, ) if fallback_result is not None: model_name, response_text = fallback_result return response_text, model_name, _estimate_tokens_used(messages, 
response_text), True response_text = _build_graceful_fallback_response( req, persona_title=persona.title, session_summary=session_summary, memory_matches=memory_matches, user_focus=user_focus, active_threads=active_threads, vision_context=vision_context, tool_trace=tool_trace, ) return response_text, FALLBACK_MODEL_NAME, _estimate_tokens_used(messages, response_text), True @router.post("/generate", response_model=GenerateResponse) async def generate(req: GenerateRequest) -> GenerateResponse: """Ģenerē tekstu ar Qwen3 modeli.""" request_id = uuid4().hex[:12] started_at = perf_counter() context = await _prepare_generation_context(req, request_id=request_id, started_at=started_at) response_text, model_name, tokens_used, fallback_used = await _generate_runtime_response( req, messages=context.messages, session_summary=context.session_summary, memory_matches=context.memory_matches, persona=context.persona, user_focus=context.user_focus, active_threads=context.active_threads, vision_context=context.vision_context, tool_trace=context.tool_trace, request_id=context.request_id, ) latency_ms = int((perf_counter() - started_at) * 1000) await _persist_generation_result( context, response_text=response_text, tokens_used=tokens_used, latency_ms=latency_ms, fallback_used=fallback_used, ) return _build_generate_response( context, response_text=response_text, model_name=model_name, tokens_used=tokens_used, latency_ms=latency_ms, ) @router.post("/generate/stream") async def generate_stream(req: GenerateRequest, request: Request) -> StreamingResponse: """Straumē teksta ģenerēšanu ar īstiem modeļa delta tokeniem.""" request_id = uuid4().hex[:12] started_at = perf_counter() context = await _prepare_generation_context(req, request_id=request_id, started_at=started_at) async def event_stream() -> Any: response_parts: list[str] = [] pipe = get_pipeline() async def emit_complete( response_text: str, *, model_name: str, tokens_used: int, fallback_used: bool, ) -> str: latency_ms = 
int((perf_counter() - started_at) * 1000) await _persist_generation_result( context, response_text=response_text, tokens_used=tokens_used, latency_ms=latency_ms, fallback_used=fallback_used, ) payload = _build_generate_response( context, response_text=response_text, model_name=model_name, tokens_used=tokens_used, latency_ms=latency_ms, ) return _encode_stream_event("complete", payload.model_dump()) async def emit_fallback(reason: Exception | None = None) -> Any: if reason is not None: logger.error( "Straumētā teksta ģenerēšana pārgāja uz fallback request_id=%s: %s", context.request_id, reason, ) fallback_result = complete_with_hf_fallback( context.messages, fallback_model=context.req.fallback_model, max_new_tokens=context.req.max_new_tokens, temperature=context.req.temperature, ) if fallback_result is not None: model_name, response_text = fallback_result else: model_name = FALLBACK_MODEL_NAME response_text = _build_graceful_fallback_response( context.req, persona_title=context.persona.title, session_summary=context.session_summary, memory_matches=context.memory_matches, user_focus=context.user_focus, active_threads=context.active_threads, vision_context=context.vision_context, tool_trace=context.tool_trace, ) for delta in _chunk_text_for_streaming(response_text): if await request.is_disconnected(): return yield _encode_stream_event("delta", {"delta": delta}) if await request.is_disconnected(): return yield await emit_complete( response_text, model_name=model_name, tokens_used=_estimate_tokens_used(context.messages, response_text), fallback_used=model_name != resolve_text_model(), ) if pipe is None: async for item in emit_fallback(): yield item return try: from transformers import ( # type: ignore StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, ) tokenizer, model, model_inputs = _prepare_stream_model_inputs(pipe, context.messages) stop_event = threading.Event() queue: asyncio.Queue[str | None] = asyncio.Queue() errors: SimpleQueue[BaseException] = 
SimpleQueue() loop = asyncio.get_running_loop() class CancelOnDisconnect(StoppingCriteria): def __call__(self, input_ids: Any, scores: Any, **kwargs: Any) -> bool: return stop_event.is_set() streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, ) generation_kwargs: dict[str, Any] = { **model_inputs, "streamer": streamer, "stopping_criteria": StoppingCriteriaList([CancelOnDisconnect()]), } generation_config = _build_generation_config( pipe, max_new_tokens=context.req.max_new_tokens, temperature=context.req.temperature, ) if generation_config is not None: generation_kwargs["generation_config"] = generation_config else: generation_kwargs["max_new_tokens"] = context.req.max_new_tokens generation_kwargs["do_sample"] = context.req.temperature > 0 if context.req.temperature > 0: generation_kwargs["temperature"] = context.req.temperature eos_token_id = getattr(tokenizer, "eos_token_id", None) pad_token_id = getattr(tokenizer, "pad_token_id", None) if pad_token_id is None: pad_token_id = eos_token_id if generation_config is not None: if eos_token_id is not None: generation_config.eos_token_id = eos_token_id if pad_token_id is not None: generation_config.pad_token_id = pad_token_id else: if eos_token_id is not None: generation_kwargs["eos_token_id"] = eos_token_id if pad_token_id is not None: generation_kwargs["pad_token_id"] = pad_token_id def run_generation() -> None: try: model.generate(**generation_kwargs) except BaseException as exc: # noqa: BLE001 errors.put(exc) def pump_streamer() -> None: try: for chunk in streamer: loop.call_soon_threadsafe(queue.put_nowait, chunk) except BaseException as exc: # noqa: BLE001 errors.put(exc) finally: loop.call_soon_threadsafe(queue.put_nowait, None) generation_thread = threading.Thread(target=run_generation, daemon=True) pump_thread = threading.Thread(target=pump_streamer, daemon=True) generation_thread.start() pump_thread.start() cancelled = False try: while True: if await 
request.is_disconnected(): cancelled = True stop_event.set() break try: item = await asyncio.wait_for( queue.get(), timeout=STREAM_QUEUE_POLL_SECONDS, ) except TimeoutError: continue if item is None: break if not item: continue response_parts.append(item) yield _encode_stream_event("delta", {"delta": item}) finally: stop_event.set() await asyncio.to_thread(generation_thread.join, 1.0) await asyncio.to_thread(pump_thread.join, 1.0) if cancelled: logger.info( "Straumētā teksta ģenerēšana atcelta request_id=%s session_id=%s", context.request_id, context.session_id, ) return response_text = "".join(response_parts).strip() first_error = errors.get_nowait() if not errors.empty() else None if first_error is not None and not response_text: raise RuntimeError(str(first_error)) if not response_text: raise ValueError("Straumētā ģenerēšana neatgrieza nevienu tokenu.") yield await emit_complete( response_text, model_name=resolve_text_model(), tokens_used=_estimate_tokens_used(context.messages, response_text), fallback_used=False, ) except Exception as exc: # noqa: BLE001 async for item in emit_fallback(exc): yield item return StreamingResponse( event_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", }, )