# MarisUK's picture
# Maris AI model sync
# f440f03 verified
"""Teksta ģenerēšana ar Maris AI modeli."""
from __future__ import annotations
import asyncio
import copy
import inspect
import json
import logging
import math
import re
import threading
from contextlib import suppress
from dataclasses import dataclass
from queue import SimpleQueue
from time import monotonic, perf_counter
from typing import Any, Literal
from uuid import uuid4
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator
from maris_core.memory_context import memory_store
from maris_core.orchestrator.routing import build_system_prompt, resolve_text_model
from maris_core.personas import resolve_persona
from maris_core.text.tools import (
MAX_TOOL_STEPS_CAP,
MAX_TOOL_STEPS_DEFAULT,
ToolTrace,
build_tool_context_message,
execute_tool_trace,
plan_tool_use,
)
from maris_core.training.hf_compat import maris_hf_compatible_path
from maris_core.utils.emotional_context import analyze_emotional_context
from maris_core.utils.env import get_env_any, validate_hf_model
from maris_core.utils.hf_inference import create_hf_inference_client
from maris_core.utils.hf_integration import HFIntegration
logger = logging.getLogger(__name__)
router = APIRouter()
# Upper bound (characters) for the request message and for each history entry.
TEXT_GENERATE_MESSAGE_MAX_CHARS = 12_000
# Maximum number of prior conversation messages accepted in one request.
TEXT_GENERATE_HISTORY_LIMIT = 32
# Defaults used when a request does not set its own generation parameters.
DEFAULT_MAX_NEW_TOKENS = 2_048
DEFAULT_TEMPERATURE = 0.7
# Model name reported when the canned, context-only fallback answer is returned.
FALLBACK_MODEL_NAME = "MarisUK/maris-assistant-runtime-fallback"
# Environment variables passed (in this order) to get_env_any to find a default
# Hugging Face inference fallback model.
RUNTIME_FALLBACK_MODEL_ENV = "MARIS_RUNTIME_FALLBACK_MODEL"
HF_RUNTIME_FALLBACK_MODEL_ENV = "HF_RUNTIME_FALLBACK_MODEL"
# Target size (characters) of each piece when a reply is split for streaming.
STREAM_CHUNK_CHARS = 48
# Poll interval for a streaming queue; presumably used by streaming code outside this view.
STREAM_QUEUE_POLL_SECONDS = 0.1
# Minimum wait after a failed model load before another background load may start.
PIPELINE_RETRY_COOLDOWN_SECONDS = 30.0
# Approximate conversion from regex fragments to model tokens for fallback usage estimates
# (estimated tokens = fragments / ratio).
TOKEN_ESTIMATION_EXPANSION_RATIO = 0.75
# Matches word runs and single punctuation characters; used only for token estimates.
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)
class ConversationMessage(BaseModel):
    """One prior chat turn supplied by the client as part of the request history."""

    model_config = ConfigDict()
    # Only these three chat roles are accepted.
    role: Literal["system", "user", "assistant"]
    content: str = Field(min_length=1, max_length=TEXT_GENERATE_MESSAGE_MAX_CHARS)

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: str) -> str:
        """Reject whitespace-only content; a valid value is returned unstripped."""
        if not value.strip():
            raise ValueError("content nedrīkst būt tukšs.")
        return value
class GenerateRequest(BaseModel):
    """Request body for text generation.

    Carries the user message, optional chat history, sampling settings,
    tool-step budget, profile/persona/session selectors, an optional Hugging
    Face fallback model and an optional vision-context dict.
    """

    model_config = ConfigDict()
    message: str = Field(min_length=1, max_length=TEXT_GENERATE_MESSAGE_MAX_CHARS)
    # Prior turns; the list length is capped at TEXT_GENERATE_HISTORY_LIMIT.
    history: list[ConversationMessage] = Field(
        default_factory=list,
        max_length=TEXT_GENERATE_HISTORY_LIMIT,
    )
    # NOTE(review): only a lower bound is enforced, so a client can request an
    # arbitrarily large generation budget — consider adding an ``le=`` cap.
    max_new_tokens: int = Field(
        default=DEFAULT_MAX_NEW_TOKENS,
        ge=32,
    )
    temperature: float = Field(default=DEFAULT_TEMPERATURE, ge=0.0, le=1.5)
    max_tool_steps: int = Field(default=MAX_TOOL_STEPS_DEFAULT, ge=1, le=MAX_TOOL_STEPS_CAP)
    profile: str | None = Field(default=None, max_length=40)
    persona_id: str | None = Field(default=None, max_length=64)
    session_id: str | None = Field(default=None, max_length=120)
    # Explicit Hugging Face model to use when the local pipeline is unavailable.
    fallback_model: str | None = Field(default=None, max_length=160)
    # Free-form vision analysis result; keys such as "summary" are read downstream.
    vision_context: dict[str, Any] | None = None

    @field_validator("message")
    @classmethod
    def validate_message(cls, value: str) -> str:
        """Reject whitespace-only messages; a valid value is returned unstripped."""
        if not value.strip():
            raise ValueError("message nedrīkst būt tukšs.")
        return value

    @field_validator("profile", "persona_id", "session_id")
    @classmethod
    def normalize_optional_text(cls, value: str | None) -> str | None:
        """Strip surrounding whitespace and turn blank values into ``None``."""
        normalized = (value or "").strip()
        return normalized or None

    @field_validator("fallback_model")
    @classmethod
    def validate_fallback_model(cls, value: str | None) -> str | None:
        """Normalize ``fallback_model`` and check it with ``validate_hf_model``.

        Blank values become ``None``. A ``RuntimeError`` from
        ``validate_hf_model`` is re-raised as ``ValueError`` so pydantic
        reports it as a field validation error.
        """
        normalized = (value or "").strip()
        if not normalized:
            return None
        try:
            return validate_hf_model(normalized, "fallback_model")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
class GenerateResponse(BaseModel):
    """Response body returned after a text generation request completes."""

    response: str
    # Model that produced ``response``: the local model, an HF fallback model,
    # or FALLBACK_MODEL_NAME for the canned context-only answer.
    model: str
    # Token usage reported by the runtime, or a rough estimate when unavailable.
    tokens_used: int = 0
    detected_emotion: str = "neutral"
    emotion_confidence: float = 0.0
    response_style: str = "clear_grounded"
    persona_id: str = "assistant"
    persona_title: str = "Core Assistant"
    persona_summary: str = ""
    request_id: str
    session_id: str
    latency_ms: int = 0
    # Number of chat messages (system + history + user) sent to the model.
    prompt_messages: int = 0
    # Number of memory-store matches injected into the prompt.
    memory_matches: int = 0
    tool_trace: ToolTrace | None = None
@dataclass(slots=True)
class GenerationContext:
    """Per-request state gathered before generation and reused afterwards."""

    req: GenerateRequest
    request_id: str
    # Start timestamp supplied by the caller, presumably for latency measurement — confirm clock.
    started_at: float
    hf: HFIntegration
    # Result of analyze_emotional_context(); its emotion, confidence and
    # response_style attributes are read elsewhere in this module.
    emotional_context: Any
    # Resolved persona; id, title, summary and communication_style are read elsewhere.
    persona: Any
    session_id: str
    # Client history converted to role/content dicts.
    history: list[dict[str, str]]
    # Memory-store hits; items expose role, source, score and content.
    memory_matches: list[Any]
    user_focus: list[str]
    active_threads: list[str]
    # Vision context from the request, or {} when none was sent.
    vision_context: dict[str, Any]
    session_summary: list[str]
    # Final chat message list sent to the model.
    messages: list[dict[str, str]]
    tool_trace: ToolTrace | None
def _build_pipeline() -> Any:
    """Build the Maris AI text-generation pipeline.

    Resolves the configured model name, opens it through
    ``maris_hf_compatible_path`` and creates a ``transformers``
    text-generation pipeline from the resulting path. Any exception is
    recorded in ``_pipeline_last_error``, logged, and turned into a ``None``
    return value instead of being raised.

    Side effects: updates ``_pipeline_runtime_model`` and
    ``_pipeline_compatibility_restore_active`` under ``_pipeline_lock``.
    """
    global _pipeline_last_error  # noqa: PLW0603
    global _pipeline_runtime_model  # noqa: PLW0603
    global _pipeline_compatibility_restore_active  # noqa: PLW0603
    try:
        # Imported lazily so the module can be imported without transformers installed.
        from transformers import pipeline  # type: ignore
        model_name = resolve_text_model()
        logger.info("Ielādē teksta modeli: %s", model_name)
        with maris_hf_compatible_path(model_name) as compatible_model_path:
            # A path different from the model name means the helper redirected loading
            # (presumably to a restored/adapted copy — confirm in hf_compat).
            compatibility_restore_active = compatible_model_path != model_name
            if compatibility_restore_active:
                logger.info(
                    "Runtime loaderis aktivizēja compatibility restore modelim %s.",
                    model_name,
                )
            # Publish model metadata for readiness reporting before the slow load begins.
            with _pipeline_lock:
                _pipeline_runtime_model = model_name
                _pipeline_compatibility_restore_active = compatibility_restore_active
            # Loaded inside the context so the compatible path stays valid while loading.
            # NOTE(review): trust_remote_code=True executes Python code shipped with the
            # model repository; resolve_text_model must only point at trusted repos.
            return pipeline(
                "text-generation",
                model=compatible_model_path,
                device_map="auto",
                trust_remote_code=True,
            )
    except Exception as exc:  # noqa: BLE001
        with _pipeline_lock:
            # Keep a best-effort model name for readiness payloads even on failure.
            if not _pipeline_runtime_model:
                with suppress(Exception):
                    _pipeline_runtime_model = resolve_text_model()
            _pipeline_last_error = str(exc)
        logger.error("Nevar ielādēt Maris AI teksta modeli: %s", exc)
        return None
# Module-level loader state. In the visible code it is read and written while
# holding ``_pipeline_lock``.
# Loaded transformers pipeline, or None until a background load succeeds.
_pipeline: Any = None
# True while a background loader thread is running.
_pipeline_loading = False
# monotonic() time of the last failed load; 0.0 means no failure is pending.
_pipeline_last_failure_at = 0.0
# Most recent load error message, exposed through readiness payloads.
_pipeline_last_error: str | None = None
# Model name resolved for the current or most recent load attempt.
_pipeline_runtime_model = ""
# Whether maris_hf_compatible_path redirected the most recent load.
_pipeline_compatibility_restore_active = False
_pipeline_lock = threading.Lock()
def _load_pipeline_in_background() -> None:
    """Loader-thread target: build the pipeline and publish the outcome.

    After ``_build_pipeline`` returns, the result is stored and
    ``_pipeline_loading`` is cleared. A success clears the failure timestamp
    and the last error. A failure records the failure time (starting the retry
    cooldown) and sets a generic error if ``_build_pipeline`` recorded none.
    """
    global _pipeline  # noqa: PLW0603
    global _pipeline_loading  # noqa: PLW0603
    global _pipeline_last_failure_at  # noqa: PLW0603
    global _pipeline_last_error  # noqa: PLW0603
    # The slow load runs without holding the lock so readiness checks are not blocked.
    pipeline = _build_pipeline()
    with _pipeline_lock:
        _pipeline = pipeline
        _pipeline_loading = False
        _pipeline_last_failure_at = monotonic() if pipeline is None else 0.0
        if pipeline is not None:
            _pipeline_last_error = None
        elif not _pipeline_last_error:
            _pipeline_last_error = "pipeline_initialization_failed"
def _should_retry_pipeline_load(now: float) -> bool:
    """Tell whether a new pipeline load may be started at time ``now``.

    Allowed when no failure is recorded (timestamp <= 0) or when at least
    ``PIPELINE_RETRY_COOLDOWN_SECONDS`` have passed since the last failure.
    Called in this module while ``_pipeline_lock`` is held.
    """
    last_failure = _pipeline_last_failure_at
    if last_failure <= 0:
        return True
    elapsed = now - last_failure
    return elapsed >= PIPELINE_RETRY_COOLDOWN_SECONDS
def _retry_after_seconds(now: float) -> int:
    """Whole seconds (rounded up, never below 1) until the retry cooldown ends."""
    elapsed = now - _pipeline_last_failure_at
    seconds_left = math.ceil(PIPELINE_RETRY_COOLDOWN_SECONDS - elapsed)
    return seconds_left if seconds_left > 1 else 1
def _start_pipeline_load_locked() -> None:
    """Flag the pipeline as loading and start the daemon loader thread.

    Called in this module while ``_pipeline_lock`` is held (hence ``_locked``).
    """
    global _pipeline_loading  # noqa: PLW0603
    _pipeline_loading = True
    logger.info("Sākam teksta modeļa ielādi fonā; pirmās atbildes izmantos fallback režīmu.")
    threading.Thread(
        target=_load_pipeline_in_background,
        name="maris-text-pipeline-loader",
        daemon=True,
    ).start()
def _resolve_pipeline_readiness_locked(now: float, *, start_loading: bool) -> dict[str, Any]:
    """Compute the text-model readiness snapshot while ``_pipeline_lock`` is held.

    States:
      * ``ready`` – the pipeline is loaded;
      * ``warming_up`` – a load is running, or was just started here;
      * ``retry_cooldown`` – a recent load failed; adds ``retry_after_seconds``;
      * ``cold`` – idle and ``start_loading`` is False.

    Every snapshot also includes ``model`` and ``compatibility_restore_active``,
    plus ``last_error`` when one is recorded. When ``start_loading`` is True and
    a load is allowed, a background load is started.
    """
    details: dict[str, Any] = {
        "model": _pipeline_runtime_model or _safe_runtime_model_name(),
        "compatibility_restore_active": _pipeline_compatibility_restore_active,
    }
    if _pipeline_last_error:
        details["last_error"] = _pipeline_last_error

    def _snapshot(ready: bool, state: str, **extra: Any) -> dict[str, Any]:
        # Key order: ready, state, state-specific extras, then shared details.
        return {"ready": ready, "state": state, **extra, **details}

    if _pipeline is not None:
        return _snapshot(True, "ready")
    if _pipeline_loading:
        return _snapshot(False, "warming_up")
    if not _should_retry_pipeline_load(now):
        return _snapshot(
            False,
            "retry_cooldown",
            retry_after_seconds=_retry_after_seconds(now),
        )
    if not start_loading:
        return _snapshot(False, "cold")
    _start_pipeline_load_locked()
    return _snapshot(False, "warming_up")
def _safe_runtime_model_name() -> str:
    """Return the configured text model name, or ``""`` if resolving it fails."""
    with suppress(Exception):
        return resolve_text_model()
    return ""
def get_text_model_readiness(*, start_loading: bool = False) -> dict[str, Any]:
    """Return a readiness snapshot of the local text model.

    The dict contains ``ready`` and ``state`` plus model metadata. With
    ``start_loading=True``, a background load is started when one is allowed.
    """
    checked_at = monotonic()
    with _pipeline_lock:
        snapshot = _resolve_pipeline_readiness_locked(checked_at, start_loading=start_loading)
    return snapshot
def get_pipeline() -> Any:
    """Return the loaded text-generation pipeline, or ``None`` if it is not ready.

    The readiness check runs with ``start_loading=True``. Calling this while
    no pipeline is loaded, no load is running and the retry cooldown has
    passed therefore starts a background load. While ``None`` is returned,
    callers use other generation paths.
    """
    # ``_pipeline`` is only read here, so no ``global`` declaration is needed.
    with _pipeline_lock:
        readiness = _resolve_pipeline_readiness_locked(monotonic(), start_loading=True)
        return _pipeline if readiness["ready"] else None
def warm_text_model_runtime() -> None:
    """Ask for a background load of the text model (subject to the retry cooldown).

    The readiness snapshot is discarded; only the side effect matters here.
    """
    _ = get_text_model_readiness(start_loading=True)
def _build_generation_config(
pipe: Any,
*,
max_new_tokens: int,
temperature: float,
) -> Any | None:
try:
from transformers import GenerationConfig # type: ignore
except ImportError:
return None
base_config = getattr(pipe, "generation_config", None)
model = getattr(pipe, "model", None)
if base_config is None:
base_config = getattr(model, "generation_config", None)
if base_config is None:
model_config = getattr(model, "config", None)
config = (
GenerationConfig.from_model_config(model_config)
if model_config is not None
else GenerationConfig()
)
else:
config = copy.deepcopy(base_config)
config.max_new_tokens = max_new_tokens
# Setting max_length to None leaves max_new_tokens as the sole output-length limit.
config.max_length = None
config.do_sample = temperature > 0
if temperature > 0:
config.temperature = temperature
return config
def call_generation_pipeline(
    pipe: Any,
    inputs: Any,
    *,
    max_new_tokens: int,
    temperature: float,
    return_full_text: bool | None = None,
) -> Any:
    """Call a text-generation pipeline with length and temperature settings.

    A prepared ``generation_config`` is passed when one can be built and the
    pipeline's signature accepts it. If that call raises a ``TypeError``
    mentioning ``generation_config`` or ``unexpected keyword``, or no config is
    usable, the call is repeated with plain ``max_new_tokens``/``temperature``
    kwargs. ``return_full_text`` is forwarded only when explicitly given.
    """
    generation_config = _build_generation_config(
        pipe,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    optional_kwargs: dict[str, Any] = (
        {} if return_full_text is None else {"return_full_text": return_full_text}
    )
    if generation_config is not None and _supports_generation_config(pipe):
        try:
            return pipe(inputs, generation_config=generation_config, **optional_kwargs)
        except TypeError as exc:
            reason = str(exc)
            # Re-raise TypeErrors unrelated to the kwarg itself.
            if "generation_config" not in reason and "unexpected keyword" not in reason:
                raise
            logger.debug(
                "Pipeline rejected generation_config, retrying with explicit kwargs: %s", exc
            )
    return pipe(
        inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **optional_kwargs,
    )
def _supports_generation_config(pipe: Any) -> bool:
try:
signature = inspect.signature(pipe)
except (TypeError, ValueError):
return True
if "generation_config" in signature.parameters:
return True
return any(
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
)
def _resolve_session_id(raw_session_id: str | None, request_id: str) -> str:
normalized = (raw_session_id or "").strip()
if normalized:
return normalized
return f"ephemeral-{request_id}"
def _normalize_history(history: list[ConversationMessage]) -> list[dict[str, str]]:
return [item.model_dump() for item in history]
async def _run_pipeline(
    pipe: Any,
    messages: list[dict[str, str]],
    *,
    max_new_tokens: int,
    temperature: float,
) -> Any:
    """Run generation in a worker thread, retrying once with a rendered prompt.

    The chat-style ``messages`` list is tried first. If that raises
    ``AttributeError``, ``TypeError`` or ``ValueError``, the messages are
    rendered into one prompt string and generation is retried with
    ``return_full_text=False``. Errors from the retry propagate.
    """
    sampling = {"max_new_tokens": max_new_tokens, "temperature": temperature}
    try:
        return await asyncio.to_thread(call_generation_pipeline, pipe, messages, **sampling)
    except (AttributeError, TypeError, ValueError) as exc:
        prompt = _render_generation_prompt(pipe, messages)
        logger.info(
            "Falling back to prompt-based text generation after chat payload failure: %s",
            exc,
        )
        return await asyncio.to_thread(
            call_generation_pipeline,
            pipe,
            prompt,
            return_full_text=False,
            **sampling,
        )
def _render_manual_prompt(messages: list[dict[str, str]]) -> str:
lines: list[str] = []
for message in messages:
raw_role = message.get("role", "user")
role = raw_role.strip().lower() if isinstance(raw_role, str) else "user"
if not role:
role = "user"
content = message.get("content", "").strip()
if not content:
continue
lines.append(f"{role.title()}: {content}")
lines.append("Assistant:")
return "\n".join(lines)
def _render_generation_prompt(pipe: Any | None, messages: list[dict[str, str]]) -> str:
if pipe is not None:
tokenizer = getattr(pipe, "tokenizer", None)
apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
else:
apply_chat_template = None
if callable(apply_chat_template):
try:
rendered = apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
except (AttributeError, TypeError, ValueError) as exc:
logger.debug("Chat template rendering failed, using manual prompt fallback: %s", exc)
rendered = None
if isinstance(rendered, str) and rendered.strip():
return rendered
return _render_manual_prompt(messages)
def _resolve_fallback_model(raw_fallback_model: str | None) -> str | None:
    """Pick the Hugging Face fallback model for this request.

    A non-blank explicit value wins. Otherwise the value returned by
    ``get_env_any`` for the runtime/HF fallback environment variables is used.
    ``None`` means no fallback model is configured.
    """
    explicit = (raw_fallback_model or "").strip()
    if explicit:
        return explicit
    from_env = get_env_any(RUNTIME_FALLBACK_MODEL_ENV, HF_RUNTIME_FALLBACK_MODEL_ENV) or ""
    return from_env.strip() or None
def _extract_inference_response_text(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if isinstance(value, list):
for item in reversed(value):
text = _extract_inference_response_text(item)
if text:
return text
return ""
if isinstance(value, dict):
choices = value.get("choices")
if isinstance(choices, list):
text = _extract_inference_response_text(choices)
if text:
return text
for key in ("generated_text", "text", "response", "content", "answer"):
if key in value:
text = _extract_inference_response_text(value[key])
if text:
return text
if "message" in value:
text = _extract_inference_response_text(value["message"])
if text:
return text
choices = getattr(value, "choices", None)
if choices is not None:
text = _extract_inference_response_text(choices)
if text:
return text
message = getattr(value, "message", None)
if message is not None:
text = _extract_inference_response_text(message)
if text:
return text
content = getattr(value, "content", None)
if content is not None:
return _extract_inference_response_text(content)
text = getattr(value, "text", None)
if text is not None:
return _extract_inference_response_text(text)
return ""
def complete_with_hf_fallback(
    messages: list[dict[str, str]],
    *,
    fallback_model: str | None,
    max_new_tokens: int,
    temperature: float,
) -> tuple[str, str] | None:
    """Generate a reply through the Hugging Face Inference API.

    The model is ``fallback_model`` if given, otherwise one configured through
    the fallback environment variables. ``chat_completion`` is tried first. If
    the client lacks it, it fails, or it yields no usable text, the messages
    are rendered into a plain prompt and ``text_generation`` is tried.

    Args:
        messages: Chat messages (role/content dicts) to complete.
        fallback_model: Optional explicit model id.
        max_new_tokens: Generation length budget.
        temperature: Sampling temperature. ``0`` means greedy decoding.

    Returns:
        ``(model_name, response_text)`` on success. ``None`` when no model is
        configured, ``huggingface_hub`` is missing, the client cannot be
        created, both calls fail, or only empty text is returned.
    """
    resolved_model = _resolve_fallback_model(fallback_model)
    if not resolved_model:
        return None
    try:
        from huggingface_hub import InferenceClient  # type: ignore
        from huggingface_hub.utils import HfHubHTTPError  # type: ignore
    except ImportError:
        return None
    retryable_hf_errors: tuple[type[Exception], ...] = (
        OSError,
        TypeError,
        ValueError,
        RuntimeError,
        HfHubHTTPError,
    )
    try:
        client = create_hf_inference_client(InferenceClient)
    except retryable_hf_errors as exc:
        # Keep the best-effort contract: a broken client setup means "no fallback",
        # not a failed request.
        logger.warning("HF fallback client creation failed for model %s: %s", resolved_model, exc)
        return None
    try:
        raw_response = client.chat_completion(
            model=resolved_model,
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
        )
        response_text = _sanitize_response_text(
            _extract_inference_response_text(raw_response), messages
        )
        if response_text:
            return resolved_model, response_text
    except AttributeError:
        # Client without chat_completion support: continue with text_generation.
        pass
    except retryable_hf_errors as exc:
        logger.warning("HF fallback chat_completion failed for model %s: %s", resolved_model, exc)
    # Text-generation endpoints (e.g. TGI) reject temperature <= 0. Express a zero
    # temperature as greedy decoding instead, mirroring the local pipeline's
    # ``do_sample = temperature > 0`` rule.
    sampling = temperature > 0
    prompt = _render_generation_prompt(None, messages)
    try:
        raw_response = client.text_generation(
            prompt=prompt,
            model=resolved_model,
            max_new_tokens=max_new_tokens,
            do_sample=sampling,
            temperature=temperature if sampling else None,
            return_full_text=False,
        )
    except AttributeError:
        return None
    except retryable_hf_errors as exc:
        logger.warning("HF fallback text_generation failed for model %s: %s", resolved_model, exc)
        return None
    response_text = _sanitize_response_text(
        _extract_inference_response_text(raw_response), messages
    )
    if not response_text:
        return None
    return resolved_model, response_text
def _chunk_text_for_streaming(text: str) -> list[str]:
    """Split ``text`` into stream-sized pieces whose concatenation equals ``text``.

    A piece is emitted once ``STREAM_CHUNK_CHARS`` characters have accumulated.
    If the pending piece has whitespace before its final character, it is cut
    just after the last such whitespace and the rest carries over. Otherwise
    the whole pending piece is emitted. A shorter tail becomes the final piece.
    """
    pieces: list[str] = []
    start = 0
    # Absolute index just past the latest whitespace inside the pending piece.
    cut_after: int | None = None
    for position, char in enumerate(text):
        if char.isspace():
            cut_after = position + 1
        end = position + 1
        if end - start < STREAM_CHUNK_CHARS:
            continue
        if cut_after is not None and cut_after < end:
            pieces.append(text[start:cut_after])
            start = cut_after
        else:
            pieces.append(text[start:end])
            start = end
        cut_after = None
    if start < len(text):
        pieces.append(text[start:])
    return pieces
def _encode_stream_event(event_name: str, payload: Any) -> str:
return f"event: {event_name}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
def _prepare_stream_model_inputs(
    pipe: Any, messages: list[dict[str, str]]
) -> tuple[Any, Any, dict[str, Any]]:
    """Tokenize the rendered chat prompt for direct streaming generation.

    Returns ``(tokenizer, model, inputs)``. ``inputs`` is a plain dict of
    tokenizer outputs, moved onto ``model.device`` when the model exposes one.
    Values without a ``.to`` method pass through unchanged.

    Raises:
        TypeError: if ``pipe`` has no ``tokenizer`` or no ``model``.

    NOTE(review): the prompt may come from ``apply_chat_template``, which often
    already includes special tokens such as BOS. It is then tokenized with the
    tokenizer's default special-token handling — confirm BOS is not duplicated.
    """
    tokenizer = getattr(pipe, "tokenizer", None)
    model = getattr(pipe, "model", None)
    if tokenizer is None or model is None:
        raise TypeError("Teksta pipeline neatbalsta tiešo streaming ģenerēšanu.")
    encoded = tokenizer(_render_generation_prompt(pipe, messages), return_tensors="pt")
    inputs = encoded if isinstance(encoded, dict) else dict(encoded)
    target_device = getattr(model, "device", None)
    if target_device is None:
        return tokenizer, model, inputs
    placed = {
        name: tensor.to(target_device) if hasattr(tensor, "to") else tensor
        for name, tensor in inputs.items()
    }
    return tokenizer, model, placed
async def _prepare_generation_context(
    req: GenerateRequest,
    *,
    request_id: str,
    started_at: float,
) -> GenerationContext:
    """Collect everything needed to generate a reply for ``req``.

    Steps:
      * analyze the message's emotional context and resolve the persona and session;
      * seed session memory with the client history;
      * read relevant memory, user focus, active threads and the session summary;
      * run the planned tool trace, if any;
      * build the final model message list.

    NOTE(review): the memory_store calls are synchronous inside this coroutine;
    if they perform I/O they block the event loop — confirm.
    """
    hf = HFIntegration()
    emotional_context = analyze_emotional_context(req.message)
    persona = resolve_persona(req.persona_id)
    session_id = _resolve_session_id(req.session_id, request_id)
    history = _normalize_history(req.history)
    # Seeded before the reads below, presumably so they can see the client history.
    memory_store.seed_history(session_id, history)
    memory_matches = memory_store.retrieve_relevant_context(session_id, req.message)
    user_focus = memory_store.summarize_user_focus(session_id, query=req.message)
    active_threads = memory_store.summarize_active_threads(session_id, query=req.message)
    vision_context = req.vision_context or {}
    session_summary = memory_store.summarize_session(session_id)
    # A None plan means no tool run; otherwise execute it within max_tool_steps.
    planned_tool_trace = plan_tool_use(req.message)
    tool_trace = (
        await execute_tool_trace(
            planned_tool_trace,
            message=req.message,
            max_steps=req.max_tool_steps,
        )
        if planned_tool_trace is not None
        else None
    )
    messages = _build_generation_messages(
        req,
        persona=persona,
        emotional_context=emotional_context,
        history=history,
        session_summary=session_summary,
        memory_matches=memory_matches,
        user_focus=user_focus,
        active_threads=active_threads,
        vision_context=vision_context,
        tool_trace=tool_trace,
    )
    logger.info(
        "Teksta ģenerēšana sākta request_id=%s session_id=%s persona=%s profile=%s history=%s memory=%s",
        request_id,
        session_id,
        persona.id,
        req.profile or "general",
        len(history),
        len(memory_matches),
    )
    return GenerationContext(
        req=req,
        request_id=request_id,
        started_at=started_at,
        hf=hf,
        emotional_context=emotional_context,
        persona=persona,
        session_id=session_id,
        history=history,
        memory_matches=memory_matches,
        user_focus=user_focus,
        active_threads=active_threads,
        vision_context=vision_context,
        session_summary=session_summary,
        messages=messages,
        tool_trace=tool_trace,
    )
async def _persist_generation_result(
    context: GenerationContext,
    *,
    response_text: str,
    tokens_used: int,
    latency_ms: int,
    fallback_used: bool,
) -> None:
    """Record a finished exchange in session memory and via ``HFIntegration``.

    The memory writes happen in this order:
      1. the user message;
      2. the vision summary, if present, as an assistant-role entry tagged
         ``source="vision_context"``;
      3. the reply.

    Then ``save_conversation`` is awaited with request metadata: ids,
    persona/profile, emotion, usage, context sizes and tool statistics.
    """
    memory_store.remember_message(context.session_id, "user", context.req.message)
    if context.vision_context.get("summary"):
        memory_store.remember_message(
            context.session_id,
            "assistant",
            str(context.vision_context["summary"]),
            source="vision_context",
        )
    memory_store.remember_message(context.session_id, "assistant", response_text)
    await context.hf.save_conversation(
        context.req.message,
        response_text,
        metadata={
            "request_id": context.request_id,
            "session_id": context.session_id,
            "persona_id": context.persona.id,
            "profile": context.req.profile or "general",
            "detected_emotion": context.emotional_context.emotion,
            "response_style": context.emotional_context.response_style,
            "tokens_used": tokens_used,
            "latency_ms": latency_ms,
            "history_messages": len(context.history),
            "memory_matches": len(context.memory_matches),
            "user_focus_items": len(context.user_focus),
            "active_thread_items": len(context.active_threads),
            "vision_context": bool(context.vision_context.get("summary")),
            "fallback_used": fallback_used,
            "session_summary_items": len(context.session_summary),
            # Requests without a tool trace are recorded as direct answers.
            "tool_mode": context.tool_trace.mode if context.tool_trace is not None else "direct",
            "tool_steps": len(context.tool_trace.steps) if context.tool_trace is not None else 0,
            "tool_grounding_sources": (
                len(context.tool_trace.grounding_sources) if context.tool_trace is not None else 0
            ),
            "requested_fallback_model": context.req.fallback_model,
        },
    )
def _build_generate_response(
    context: GenerationContext,
    *,
    response_text: str,
    model_name: str,
    tokens_used: int,
    latency_ms: int,
) -> GenerateResponse:
    """Log completion and assemble the API response from the request context."""
    logger.info(
        "Teksta ģenerēšana pabeigta request_id=%s session_id=%s latency_ms=%s tokens_used=%s",
        context.request_id,
        context.session_id,
        latency_ms,
        tokens_used,
    )
    emotion = context.emotional_context
    persona = context.persona
    return GenerateResponse(
        request_id=context.request_id,
        session_id=context.session_id,
        response=response_text,
        model=model_name,
        tokens_used=tokens_used,
        latency_ms=latency_ms,
        detected_emotion=emotion.emotion,
        emotion_confidence=emotion.confidence,
        response_style=emotion.response_style,
        persona_id=persona.id,
        persona_title=persona.title,
        persona_summary=persona.summary,
        prompt_messages=len(context.messages),
        memory_matches=len(context.memory_matches),
        tool_trace=context.tool_trace,
    )
def _extract_response_text(payload: Any) -> str:
    """Return the generated text from a local pipeline result.

    For a non-empty list result, only the first element is inspected.

    Raises:
        ValueError: if no non-empty text can be found.
    """
    candidate = payload[0] if isinstance(payload, list) and payload else payload
    text = _extract_text_candidate(candidate)
    if not text:
        raise ValueError("Neizdevās nolasīt ģenerēto atbildi no modeļa rezultāta.")
    return text
def _sanitize_response_text(response_text: str, messages: list[dict[str, str]]) -> str:
    """Clean a raw model reply before it reaches the user.

    Steps:
      * strip whitespace;
      * if the reply echoes the manually rendered prompt, keep only the text
        after its last occurrence;
      * remove every leading speaker marker (``Assistant:``, ``Maris:``,
        ``<|assistant|>``, …) in any order.

    Returns ``""`` for blank input.

    NOTE(review): if the reply is *only* the echoed prompt, the full echo
    (system instructions included) is kept — consider returning ``""`` then.
    """
    cleaned = response_text.strip()
    if not cleaned:
        return ""
    manual_prompt = _render_manual_prompt(messages)
    if manual_prompt and manual_prompt in cleaned:
        trimmed = cleaned.rsplit(manual_prompt, maxsplit=1)[-1].strip()
        if trimmed:
            cleaned = trimmed
    role_markers = ("Assistant:", "assistant:", "Maris:", "maris:", "<|assistant|>", "<assistant>")
    # Repeat until nothing matches, so mixed prefixes such as "Maris: Assistant: ..."
    # are fully removed regardless of the marker order in the tuple above.
    marker_removed = True
    while marker_removed:
        marker_removed = False
        for marker in role_markers:
            if cleaned.startswith(marker):
                cleaned = cleaned[len(marker) :].lstrip()
                marker_removed = True
    return cleaned.strip()
def _extract_text_candidate(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if isinstance(value, list):
for item in reversed(value):
text = _extract_text_candidate(item)
if text:
return text
return ""
if isinstance(value, dict):
for key in ("generated_text", "text", "response", "content", "answer"):
if key in value:
text = _extract_text_candidate(value[key])
if text:
return text
if "message" in value:
text = _extract_text_candidate(value["message"])
if text:
return text
return ""
def _extract_usage_tokens(payload: Any) -> int | None:
if isinstance(payload, dict):
usage = payload.get("usage")
if isinstance(usage, dict):
for key in ("total_tokens", "generated_tokens", "completion_tokens"):
token_value = usage.get(key)
if isinstance(token_value, int) and token_value >= 0:
return token_value
for key in ("total_tokens", "generated_tokens", "token_count", "tokens_used"):
token_value = payload.get(key)
if isinstance(token_value, int) and token_value >= 0:
return token_value
for nested_key in ("generated_text", "message", "details"):
if nested_key in payload:
tokens = _extract_usage_tokens(payload[nested_key])
if tokens is not None:
return tokens
if isinstance(payload, list):
for item in payload:
tokens = _extract_usage_tokens(item)
if tokens is not None:
return tokens
return None
def _estimate_tokens_used(messages: list[dict[str, str]], response_text: str) -> int:
    """Return a rough token estimate when the runtime does not expose real token usage.

    The joined prompt contents and the response are estimated separately:
    regex fragments are counted, divided by ``TOKEN_ESTIMATION_EXPANSION_RATIO``
    and rounded up, with a floor of 1 per side. The two estimates are summed.
    """

    def _side_estimate(text: str) -> int:
        fragment_count = len(_TOKEN_PATTERN.findall(text))
        return max(1, math.ceil(fragment_count / TOKEN_ESTIMATION_EXPANSION_RATIO))

    joined_prompt = "\n".join(entry.get("content", "") for entry in messages)
    return _side_estimate(joined_prompt) + _side_estimate(response_text)
def _build_assistant_runtime_contract(
    *,
    user_message: str,
    profile: str | None,
    persona_title: str,
    persona_summary: str,
    communication_style: str,
    session_summary: list[str],
    memory_matches: list[Any],
    user_focus: list[str],
    active_threads: list[str],
    vision_context: dict[str, Any],
    tool_trace: ToolTrace | None,
) -> str:
    """Build the system-prompt "runtime contract" that governs Maris's answer.

    The base rules are always included. Extra lines are added for a session
    summary, user focus, active threads, memory matches, a vision summary and
    a tool trace with steps. A coding-specific contract is appended when the
    request looks like a coding request.
    """
    summary = (persona_summary or "").strip()
    # Keep title and summary visibly separate. Without a summary, avoid a dangling separator.
    persona_line = f"{persona_title} — {summary}" if summary else persona_title
    instructions = [
        "Assistant runtime contract:",
        "- Uzturi vienotu Maris identitāti visā atbildē.",
        f"- Turpini aktīvo personu konsekventi: {persona_line}",
        f"- Runā stilā: {communication_style}.",
        "- Balsti atbildi lietotāja ziņā, vēsturē, atmiņā un vizuālajā kontekstā, ja tāds ir dots.",
        "- Ja informācijas nepietiek, skaidri nosauc trūkstošo un uzdod vienu precizējošu jautājumu, nevis izdomā faktus.",
        "- Iekšējo domāšanu nerādi; lietotājam dod tikai secinājumus, pamatojumu un nākamo soli.",
        "- Ja pieprasījums ir riskants vai faktus vajag pārbaudīt ārpus dotā konteksta, pasaki to tieši.",
        "- Ja atmiņa, tool grounding vai vizuālais konteksts konfliktē ar jaunāko lietotāja ziņu, prioritizē jaunāko ziņu un skaidri nosauc konfliktu.",
        "- Neatsaucies uz ārējiem avotiem, tool rezultātiem vai sesijas faktiem, ja tie tiešām nav dotajā kontekstā.",
    ]
    if session_summary:
        instructions.append("- Sesijas mugurkauls, kas jāsaglabā:")
        instructions.extend(f"  {item}" for item in session_summary)
    if user_focus:
        instructions.append("- Lietotāja ilgtermiņa fokuss, preferences un mērķi:")
        instructions.extend(f"  {item}" for item in user_focus)
    if active_threads:
        instructions.append("- Aktīvie pavedieni, kas jāturpina bez konteksta zaudēšanas:")
        instructions.extend(f"  {item}" for item in active_threads)
    if memory_matches:
        instructions.append("- Prioritizē atmiņas faktus, kas tieši atkārtojas šajā pieprasījumā.")
    if vision_context.get("summary"):
        instructions.append(
            "- Piesien secinājumus pie vizuālā novērojuma, nevis tikai vispārīgām frāzēm."
        )
    if tool_trace is not None and tool_trace.steps:
        instructions.append(
            "- Ja ir tool grounding rezultāti, izmanto tos kā primāro pamatojuma avotu un nenoklusē nenoteiktību."
        )
    if _is_coding_request(user_message, profile):
        instructions.extend(
            [
                "Coding response contract:",
                "- Dod tieši izpildāmu vai ļoti precīzi pielāgojamu risinājumu, nevis miglainu pseidokodu.",
                "- Pirms vai pēc koda īsi nosauc pieņēmumus, ja prasībā kaut kas nav pilnīgi skaidrs.",
                "- Pievieno būtiskos edge cases, validāciju, kļūdu apstrādi un drošības apsvērumus, ja tie ir svarīgi.",
                "- Ja lietotājs prasa izmaiņas esošā kodā, saglabā risinājumu minimālu un saderīgu ar esošo kontekstu.",
                "- Ja dod kodu, piemini kā to pārbaudīt ar testu, manuālu pārbaudi vai izsaukuma piemēru.",
            ]
        )
    return "\n".join(instructions)
def _is_coding_request(message: str, profile: str | None) -> bool:
normalized_profile = (profile or "").strip().lower()
if normalized_profile == "coder":
return True
lowered = message.lower()
return any(
token in lowered
for token in (
"kod",
"python",
"rust",
"typescript",
"javascript",
"sql",
"regex",
"api",
"bug",
"test",
"refactor",
"stack trace",
)
)
def _build_memory_summary_message(session_summary: list[str]) -> str | None:
if not session_summary:
return None
return "Sesijas kopsavilkums ilgākai konsekvencei:\n" + "\n".join(
f"- {item}" for item in session_summary
)
def _build_memory_match_message(memory_matches: list[Any]) -> str | None:
if not memory_matches:
return None
memory_lines = [
f"- {match.role} [{match.source}; {match.score:.2f}]: {match.content}"
for match in memory_matches
]
return "Saistītā atmiņa no iepriekšējām sesijām vai sarunas vēstures:\n" + "\n".join(
memory_lines
)
def _build_user_focus_message(user_focus: list[str]) -> str | None:
if not user_focus:
return None
return "Lietotāja ilgtermiņa fokuss šai sesijai:\n" + "\n".join(
f"- {item}" for item in user_focus
)
def _build_active_threads_message(active_threads: list[str]) -> str | None:
if not active_threads:
return None
return "Aktīvie pavedieni šai sesijai:\n" + "\n".join(f"- {item}" for item in active_threads)
def _build_vision_message(vision_context: dict[str, Any]) -> str | None:
if not vision_context.get("summary"):
return None
return (
"Vizuālais konteksts no attēla vai kameras šai pašai sesijai:\n"
f"- avots: {vision_context.get('source', 'unknown')}\n"
f"- modelis: {vision_context.get('model', 'unknown')}\n"
f"- kopsavilkums: {vision_context['summary']}\n"
f"- detekcijas: {vision_context.get('detections', 0)}\n"
f"- izmērs: {vision_context.get('width', 0)}x{vision_context.get('height', 0)}"
)
def _build_graceful_fallback_response(
req: GenerateRequest,
*,
persona_title: str,
session_summary: list[str],
memory_matches: list[Any],
user_focus: list[str],
active_threads: list[str],
vision_context: dict[str, Any],
tool_trace: ToolTrace | None,
) -> str:
opening = "Pilnais modelis šobrīd nav pieejams, tāpēc dodu drošu rezerves atbildi no pieejamā konteksta."
lines = [
opening,
f"Aktīvā persona: {persona_title}.",
f"Tavs pieprasījums: {req.message.strip()}.",
]
if session_summary:
lines.append("Ko saglabāju no iepriekšējās sarunas:")
lines.extend(f"- {item}" for item in session_summary[:3])
if user_focus:
lines.append("Ko saprotu par taviem mērķiem un preferencēm:")
lines.extend(f"- {item}" for item in user_focus[:3])
if active_threads:
lines.append("Ko šobrīd uztveru kā aktīvos pavedienus:")
lines.extend(f"- {item}" for item in active_threads[:3])
elif memory_matches:
lines.append("Saistītais iepriekšējais konteksts:")
lines.extend(f"- {match.content}" for match in memory_matches[:2])
if vision_context.get("summary"):
lines.append(f"Vizuāli redzamais konteksts: {vision_context['summary']}.")
if tool_trace is not None and tool_trace.steps:
lines.append("Papildu grounding no rīkiem:")
lines.extend(f"- {step.summary}" for step in tool_trace.steps[:3])
lines.append(
"Nākamais labākais solis: pasaki, vai gribi īsu kopsavilkumu, detalizētu plānu vai konkrētu izpildes soli."
)
return "\n".join(lines)
def _build_generation_messages(
    req: GenerateRequest,
    *,
    persona: Any,
    emotional_context: Any,
    history: list[dict[str, str]],
    session_summary: list[str],
    memory_matches: list[Any],
    user_focus: list[str],
    active_threads: list[str],
    vision_context: dict[str, Any],
    tool_trace: ToolTrace | None,
) -> list[dict[str, str]]:
    """Assemble the chat message list sent to the model.

    Message order:
      1. the base system prompt;
      2. the runtime contract;
      3. optional system context notes (session summary, memory matches, user
         focus, active threads, vision, tool grounding), each only if non-empty;
      4. the client history;
      5. the current user message.
    """
    base_prompt = build_system_prompt(
        req.profile,
        emotional_context,
        persona_id=persona.id,
    )
    runtime_contract = _build_assistant_runtime_contract(
        user_message=req.message,
        profile=req.profile,
        persona_title=persona.title,
        persona_summary=persona.summary,
        communication_style=persona.communication_style,
        session_summary=session_summary,
        memory_matches=memory_matches,
        user_focus=user_focus,
        active_threads=active_threads,
        vision_context=vision_context,
        tool_trace=tool_trace,
    )
    optional_context = (
        _build_memory_summary_message(session_summary),
        _build_memory_match_message(memory_matches),
        _build_user_focus_message(user_focus),
        _build_active_threads_message(active_threads),
        _build_vision_message(vision_context),
        build_tool_context_message(tool_trace) if tool_trace is not None else None,
    )
    system_texts = [base_prompt, runtime_contract, *(text for text in optional_context if text)]
    return [
        *({"role": "system", "content": text} for text in system_texts),
        *history,
        {"role": "user", "content": req.message},
    ]
async def _generate_runtime_response(
    req: GenerateRequest,
    *,
    messages: list[dict[str, str]],
    session_summary: list[str],
    memory_matches: list[Any],
    persona: Any,
    user_focus: list[str],
    active_threads: list[str],
    vision_context: dict[str, Any],
    tool_trace: ToolTrace | None,
    request_id: str,
) -> tuple[str, str, int, bool]:
    """Produce the assistant reply for a non-streamed request.

    Order of attempts:
      1. the local text pipeline (when ``get_pipeline()`` returns one);
      2. the HF fallback model via ``complete_with_hf_fallback``;
      3. a locally built graceful response labelled ``FALLBACK_MODEL_NAME``.

    Returns:
        ``(response_text, model_name, tokens_used, fallback_used)``.
        ``fallback_used`` is ``False`` only when the local pipeline produced
        the answer. Token usage comes from the pipeline output when available,
        otherwise it is estimated from the messages and the response text.
    """
    pipe = get_pipeline()
    if pipe is not None:
        try:
            raw_output = await _run_pipeline(
                pipe,
                messages,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
            )
            response_text = _sanitize_response_text(_extract_response_text(raw_output), messages)
            tokens_used = _extract_usage_tokens(raw_output) or _estimate_tokens_used(
                messages, response_text
            )
            return response_text, resolve_text_model(), tokens_used, False
        except ValueError as exc:
            logger.error(
                "Maris AI pipeline atgrieza nederīgu formātu request_id=%s: %s", request_id, exc
            )
        except Exception as exc:  # noqa: BLE001
            # logger.exception keeps the traceback, which is needed to debug pipeline crashes.
            logger.exception("Maris AI pipeline kļūda request_id=%s: %s", request_id, exc)

    # complete_with_hf_fallback is synchronous (its result is used directly, not awaited);
    # run it in a worker thread so a slow remote call does not block the event loop.
    fallback_result = await asyncio.to_thread(
        complete_with_hf_fallback,
        messages,
        fallback_model=req.fallback_model,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
    )
    if fallback_result is not None:
        model_name, response_text = fallback_result
    else:
        if pipe is None:
            logger.warning(
                "Teksta modelis nav pieejams request_id=%s; izmanto rezerves atbildi.", request_id
            )
        model_name = FALLBACK_MODEL_NAME
        response_text = _build_graceful_fallback_response(
            req,
            persona_title=persona.title,
            session_summary=session_summary,
            memory_matches=memory_matches,
            user_focus=user_focus,
            active_threads=active_threads,
            vision_context=vision_context,
            tool_trace=tool_trace,
        )
    return response_text, model_name, _estimate_tokens_used(messages, response_text), True
@router.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate a complete (non-streamed) text response for the request.

    Prepares the generation context, obtains the reply from the runtime
    (local pipeline or one of its fallbacks), persists the turn together with
    its latency, and returns the response payload.
    """
    request_id = uuid4().hex[:12]
    started_at = perf_counter()
    context = await _prepare_generation_context(req, request_id=request_id, started_at=started_at)

    outcome = await _generate_runtime_response(
        req,
        messages=context.messages,
        session_summary=context.session_summary,
        memory_matches=context.memory_matches,
        persona=context.persona,
        user_focus=context.user_focus,
        active_threads=context.active_threads,
        vision_context=context.vision_context,
        tool_trace=context.tool_trace,
        request_id=context.request_id,
    )
    response_text, model_name, tokens_used, fallback_used = outcome

    # Latency covers everything from request start through generation.
    elapsed_ms = int((perf_counter() - started_at) * 1000)
    shared = {
        "response_text": response_text,
        "tokens_used": tokens_used,
        "latency_ms": elapsed_ms,
    }
    await _persist_generation_result(context, fallback_used=fallback_used, **shared)
    return _build_generate_response(context, model_name=model_name, **shared)
def _build_stream_generation_kwargs(
    pipe: Any,
    tokenizer: Any,
    model_inputs: Any,
    *,
    streamer: Any,
    stopping_criteria: Any,
    max_new_tokens: int,
    temperature: float,
) -> dict[str, Any]:
    """Build the ``model.generate`` keyword arguments for a streamed request.

    Sampling limits go through the pipeline's generation config when
    ``_build_generation_config`` returns one; otherwise they are passed as plain
    keyword arguments. EOS/PAD token ids come from the tokenizer, and PAD falls
    back to EOS when the tokenizer defines no PAD token.
    """
    generation_kwargs: dict[str, Any] = {
        **model_inputs,
        "streamer": streamer,
        "stopping_criteria": stopping_criteria,
    }
    generation_config = _build_generation_config(
        pipe,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id
    if generation_config is not None:
        if eos_token_id is not None:
            generation_config.eos_token_id = eos_token_id
        if pad_token_id is not None:
            generation_config.pad_token_id = pad_token_id
        generation_kwargs["generation_config"] = generation_config
    else:
        generation_kwargs["max_new_tokens"] = max_new_tokens
        generation_kwargs["do_sample"] = temperature > 0
        if temperature > 0:
            generation_kwargs["temperature"] = temperature
        if eos_token_id is not None:
            generation_kwargs["eos_token_id"] = eos_token_id
        if pad_token_id is not None:
            generation_kwargs["pad_token_id"] = pad_token_id
    return generation_kwargs


@router.post("/generate/stream")
async def generate_stream(req: GenerateRequest, request: Request) -> StreamingResponse:
    """Stream text generation as server-sent events using real model delta tokens.

    With a local pipeline, decoded text is forwarded as ``delta`` events while
    ``model.generate`` runs in a background thread, followed by one ``complete``
    event with the full response payload. When no pipeline is available, or the
    pipeline fails before any text reached the client, the HF fallback (or the
    graceful canned response) is chunked into ``delta`` events instead.
    """
    request_id = uuid4().hex[:12]
    started_at = perf_counter()
    context = await _prepare_generation_context(req, request_id=request_id, started_at=started_at)

    async def event_stream() -> Any:
        response_parts: list[str] = []
        pipe = get_pipeline()

        async def emit_complete(
            response_text: str,
            *,
            model_name: str,
            tokens_used: int,
            fallback_used: bool,
        ) -> str:
            """Persist the finished turn and encode the final ``complete`` event."""
            latency_ms = int((perf_counter() - started_at) * 1000)
            await _persist_generation_result(
                context,
                response_text=response_text,
                tokens_used=tokens_used,
                latency_ms=latency_ms,
                fallback_used=fallback_used,
            )
            payload = _build_generate_response(
                context,
                response_text=response_text,
                model_name=model_name,
                tokens_used=tokens_used,
                latency_ms=latency_ms,
            )
            return _encode_stream_event("complete", payload.model_dump())

        async def emit_fallback(reason: Exception | None = None) -> Any:
            """Stream a fallback answer: HF fallback model first, canned response last."""
            if reason is not None:
                logger.error(
                    "Straumētā teksta ģenerēšana pārgāja uz fallback request_id=%s: %s",
                    context.request_id,
                    reason,
                )
            # complete_with_hf_fallback is synchronous; run it off the event loop so a
            # slow remote call does not stall every other request on this worker.
            fallback_result = await asyncio.to_thread(
                complete_with_hf_fallback,
                context.messages,
                fallback_model=context.req.fallback_model,
                max_new_tokens=context.req.max_new_tokens,
                temperature=context.req.temperature,
            )
            if fallback_result is not None:
                model_name, response_text = fallback_result
            else:
                model_name = FALLBACK_MODEL_NAME
                response_text = _build_graceful_fallback_response(
                    context.req,
                    persona_title=context.persona.title,
                    session_summary=context.session_summary,
                    memory_matches=context.memory_matches,
                    user_focus=context.user_focus,
                    active_threads=context.active_threads,
                    vision_context=context.vision_context,
                    tool_trace=context.tool_trace,
                )
            for delta in _chunk_text_for_streaming(response_text):
                if await request.is_disconnected():
                    return
                yield _encode_stream_event("delta", {"delta": delta})
            if await request.is_disconnected():
                return
            yield await emit_complete(
                response_text,
                model_name=model_name,
                tokens_used=_estimate_tokens_used(context.messages, response_text),
                # Same semantics as the non-streaming /generate endpoint: any answer
                # not produced by the local pipeline is reported as a fallback.
                fallback_used=True,
            )

        if pipe is None:
            async for item in emit_fallback():
                yield item
            return
        try:
            from transformers import (  # type: ignore
                StoppingCriteria,
                StoppingCriteriaList,
                TextIteratorStreamer,
            )

            tokenizer, model, model_inputs = _prepare_stream_model_inputs(pipe, context.messages)
            stop_event = threading.Event()
            queue: asyncio.Queue[str | None] = asyncio.Queue()
            errors: SimpleQueue[BaseException] = SimpleQueue()
            loop = asyncio.get_running_loop()

            class CancelOnDisconnect(StoppingCriteria):
                """Ask ``model.generate`` to stop once ``stop_event`` is set."""

                def __call__(self, input_ids: Any, scores: Any, **kwargs: Any) -> bool:
                    return stop_event.is_set()

            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            generation_kwargs = _build_stream_generation_kwargs(
                pipe,
                tokenizer,
                model_inputs,
                streamer=streamer,
                stopping_criteria=StoppingCriteriaList([CancelOnDisconnect()]),
                max_new_tokens=context.req.max_new_tokens,
                temperature=context.req.temperature,
            )

            def run_generation() -> None:
                try:
                    model.generate(**generation_kwargs)
                except BaseException as exc:  # noqa: BLE001
                    errors.put(exc)
                    # generate() only ends the streamer when it finishes normally.
                    # Without this, pump_streamer would block forever waiting for
                    # the next chunk and the response would never complete.
                    with suppress(Exception):
                        streamer.end()

            def pump_streamer() -> None:
                # Bridge the blocking streamer iterator into the asyncio queue;
                # ``None`` marks the end of the stream.
                try:
                    for chunk in streamer:
                        loop.call_soon_threadsafe(queue.put_nowait, chunk)
                except BaseException as exc:  # noqa: BLE001
                    errors.put(exc)
                finally:
                    loop.call_soon_threadsafe(queue.put_nowait, None)

            generation_thread = threading.Thread(target=run_generation, daemon=True)
            pump_thread = threading.Thread(target=pump_streamer, daemon=True)
            generation_thread.start()
            pump_thread.start()
            cancelled = False
            try:
                while True:
                    if await request.is_disconnected():
                        cancelled = True
                        stop_event.set()
                        break
                    try:
                        # Short poll so client disconnects are noticed between tokens.
                        item = await asyncio.wait_for(
                            queue.get(),
                            timeout=STREAM_QUEUE_POLL_SECONDS,
                        )
                    except TimeoutError:
                        continue
                    if item is None:
                        break
                    if not item:
                        continue
                    response_parts.append(item)
                    yield _encode_stream_event("delta", {"delta": item})
            finally:
                stop_event.set()
                # Bounded joins: daemon threads are not waited on indefinitely.
                await asyncio.to_thread(generation_thread.join, 1.0)
                await asyncio.to_thread(pump_thread.join, 1.0)
            if cancelled:
                logger.info(
                    "Straumētā teksta ģenerēšana atcelta request_id=%s session_id=%s",
                    context.request_id,
                    context.session_id,
                )
                return
            response_text = "".join(response_parts).strip()
            first_error = errors.get_nowait() if not errors.empty() else None
            if first_error is not None:
                if not response_text:
                    raise RuntimeError(str(first_error)) from first_error
                # Partial text already reached the client; keep it but record the failure.
                logger.warning(
                    "Straumētā teksta ģenerēšana beidzās ar kļūdu pēc daļējas atbildes "
                    "request_id=%s: %s",
                    context.request_id,
                    first_error,
                )
            if not response_text:
                raise ValueError("Straumētā ģenerēšana neatgrieza nevienu tokenu.")
            yield await emit_complete(
                response_text,
                model_name=resolve_text_model(),
                tokens_used=_estimate_tokens_used(context.messages, response_text),
                fallback_used=False,
            )
        except Exception as exc:  # noqa: BLE001
            if any(part.strip() for part in response_parts):
                # Model text was already streamed; appending a fallback answer would
                # give the client two different responses concatenated together.
                logger.exception(
                    "Straumētā teksta ģenerēšana neizdevās pēc daļējas atbildes request_id=%s",
                    context.request_id,
                )
                return
            async for item in emit_fallback(exc):
                yield item

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )