| """Structured tool-use helpers for main chat text generation.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import html |
| import json |
| import re |
| from html.parser import HTMLParser |
| from pathlib import Path |
| from typing import Any, Literal |
| from urllib.parse import urlparse |
|
|
| import httpx |
| from pydantic import BaseModel, Field |
|
|
# Substrings looked up in the lowercased user message to decide that a web
# search is warranted (recency / release / pricing signals). Several entries
# are word stems so that inflected forms match as well.
WEB_SEARCH_KEYWORDS = (
    "latest",
    "current",
    "today",
    "news",
    "jaunākais",
    "aktuāl",
    "2025",
    "2026",
    "release",
    "pricing",
    "versij",
)
# Substrings signalling that the user wants sources, citations or verification;
# checked together with WEB_SEARCH_KEYWORDS by _should_use_web_search.
WEB_GROUNDING_KEYWORDS = (
    "avot",
    "source",
    "cite",
    "citation",
    "citē",
    "citats",
    "verify",
    "verif",
    "pārbaud",
    "oficiāl",
    "official",
)
# Substrings that make a message look like a repository / documentation / file
# question (see _should_use_workspace_grounding).
WORKSPACE_KEYWORDS = (
    "readme",
    "repo",
    "repository",
    "docs",
    "documentation",
    "file",
    "fails",
    "backend",
    "frontend",
    "core-python",
    "src/",
    ".py",
    ".rs",
    ".ts",
)
# Substrings that indicate work on existing code (debugging, refactoring,
# patches); they also enable workspace grounding.
CODE_GROUNDING_KEYWORDS = (
    "debug",
    "debugging",
    "bug",
    "bugfix",
    "fix",
    "salabo",
    "kļūd",
    "refactor",
    "refaktor",
    "diff",
    "patch",
    "repo-level",
    "stack trace",
    "failing test",
    "test fails",
    "unsafe",
    "nedroš",
    "large file",
    "large-file",
    "existing code",
    "esoš",
)
# Lowercased words dropped from workspace search queries by
# _extract_workspace_search_terms.
WORKSPACE_STOPWORDS = {
    "the",
    "and",
    "for",
    "that",
    "this",
    "with",
    "from",
    "your",
    "into",
    "about",
    "what",
    "when",
    "where",
    "kā",
    "kas",
    "par",
    "vai",
    "lai",
    "uz",
    "pie",
    "šajā",
    "repo",
    "repository",
    "failā",
    "fails",
    "kodā",
    "code",
    "helperi",
    "helper",
    "funkciju",
    "function",
    "uzraksti",
    "parādi",
    "izveido",
    "starp",
}
# File suffixes (lowercase) that workspace search scans and that mark a bare
# token without "/" as a path candidate.
WORKSPACE_EXTENSIONS = {".md", ".txt", ".py", ".rs", ".ts", ".tsx", ".json", ".yml", ".yaml"}
# Step budget used when the caller passes no max_steps, and the hard upper
# bound applied to caller-supplied values (see _normalize_max_steps).
MAX_TOOL_STEPS_DEFAULT = 12
MAX_TOOL_STEPS_CAP = 24
# Endpoint queried by _execute_web_search with format=json.
DEFAULT_WEB_SEARCH_ENDPOINT = "https://api.duckduckgo.com/"
# Timeout (seconds) of the HTTP client that execute_tool_trace creates itself.
DEFAULT_TOOL_TIMEOUT_SECONDS = 8.0
# Fetched page text is truncated to this many characters after cleaning.
DEFAULT_WEB_FETCH_MAX_CHARS = 4_000
# Maximum number of sources kept per workspace search and in the final trace.
MAX_GROUNDING_SOURCES = 8
# Maximum number of URLs / paths / sources turned into calls per planning step.
MAX_TOOL_FOLLOW_UPS = 2
# Upper bound on the number of distinct search terms taken from a query.
MAX_SEARCH_TERMS = 8
# Terms whose length is less than or equal to this value are discarded.
MIN_SEARCH_TERM_LENGTH = 2
# Score multiplier for a term that matches the file path.
PATH_MATCH_WEIGHT = 2
# Only this many leading lines of a file are scanned for a snippet and line score.
MAX_WORKSPACE_SCAN_LINES = 400
# User-Agent header of the HTTP client that execute_tool_trace creates itself.
HTTP_USER_AGENT = "Maris-MI/1.0"
# Default workspace root: parents[3] of this file's resolved path.
WORKSPACE_ROOT = Path(__file__).resolve().parents[3]
# Punctuation stripped from extracted URLs and from path tokens / search terms.
URL_STRIP_CHARS = ").,;"
PATH_STRIP_CHARS = ").,;:"
# Matches http(s) URLs up to the next whitespace, angle bracket or double quote.
URL_PATTERN = re.compile(r"https?://[^\s<>\"]+", flags=re.IGNORECASE)
# Title lookup used by _execute_web_fetch when the response is not declared as HTML.
HTML_TITLE_PATTERN = re.compile(r"<title[^>]*>(.*?)</title>", flags=re.IGNORECASE | re.DOTALL)
|
|
|
|
class GroundingSource(BaseModel):
    # One piece of evidence returned by a tool call.
    kind: str  # name of the producing tool, e.g. "web_search" or "workspace_read"
    label: str  # display label: heading/title, hostname or workspace-relative path
    uri: str | None = None  # URL or file path of the source, when available
    snippet: str | None = None  # short text excerpt, when available
    line_start: int | None = None  # 1-based line of the snippet (workspace sources)
|
|
|
|
class ToolCallRecord(BaseModel):
    # Outcome of one executed tool call.
    name: str  # tool name, e.g. "web_fetch"
    arguments: dict[str, Any] = Field(default_factory=dict)  # arguments the tool ran with
    status: Literal["completed", "failed", "skipped"] = "completed"
    summary: str = ""  # human-readable outcome; surfaced in the context message
    sources: list[GroundingSource] = Field(default_factory=list)  # evidence gathered by the call
|
|
|
|
class ToolTrace(BaseModel):
    # Plan and/or result of a tool-use run for one chat message.
    mode: Literal["direct", "tool_augmented", "multi_step"] = "direct"
    reasoning: str = ""  # why tool use was chosen (joined planner reasons)
    steps: list[ToolCallRecord] = Field(default_factory=list)  # executed calls, in order
    # De-duplicated sources collected across all steps.
    grounding_sources: list[GroundingSource] = Field(default_factory=list)
|
|
|
|
class PlannedToolCall(BaseModel):
    # A tool invocation queued for execution by execute_tool_trace.
    name: Literal["web_search", "web_fetch", "workspace_search", "workspace_read"]
    arguments: dict[str, Any] = Field(default_factory=dict)  # keyword-style tool arguments
|
|
|
|
| class _HTMLTextExtractor(HTMLParser): |
| def __init__(self) -> None: |
| super().__init__() |
| self._ignored_depth = 0 |
| self._title_depth = 0 |
| self._title_parts: list[str] = [] |
| self._text_parts: list[str] = [] |
|
|
| @property |
| def title(self) -> str: |
| return " ".join("".join(self._title_parts).split()) |
|
|
| @property |
| def text(self) -> str: |
| return " ".join("".join(self._text_parts).split()) |
|
|
| def handle_starttag(self, tag: str, _attrs: list[tuple[str, str | None]]) -> None: |
| normalized = tag.lower() |
| if normalized in {"script", "style"}: |
| self._ignored_depth += 1 |
| elif normalized == "title": |
| self._title_depth += 1 |
|
|
| def handle_endtag(self, tag: str) -> None: |
| normalized = tag.lower() |
| if normalized in {"script", "style"} and self._ignored_depth > 0: |
| self._ignored_depth -= 1 |
| elif normalized == "title" and self._title_depth > 0: |
| self._title_depth -= 1 |
|
|
| def handle_data(self, data: str) -> None: |
| if self._ignored_depth > 0: |
| return |
| if self._title_depth > 0: |
| self._title_parts.append(data) |
| self._text_parts.append(data) |
|
|
|
|
def plan_tool_use(message: str) -> ToolTrace | None:
    """Decide whether *message* warrants tool use and explain why.

    Returns ``None`` when neither web nor workspace grounding is indicated.
    Otherwise returns an unexecuted trace whose mode is ``"tool_augmented"``
    for a single signal and ``"multi_step"`` when both a web signal and a
    workspace signal are present.
    """
    found_urls = _extract_urls(message)
    path_candidates = _extract_workspace_path_candidates(message)

    reasons: list[str] = []
    # Explicit links take precedence over keyword-based web search.
    if found_urls:
        reasons.append("pieprasījumā jau ir konkrētas ārējās saites")
    elif _should_use_web_search(message):
        reasons.append("pieprasījumā ir aktuālitātes vai ārēja fakta signāli")
    if _should_use_workspace_grounding(message, path_candidates):
        reasons.append("pieprasījums izskatās pēc repo/docs/faila jautājuma")

    if not reasons:
        return None

    chosen_mode: Literal["tool_augmented", "multi_step"] = (
        "tool_augmented" if len(reasons) == 1 else "multi_step"
    )
    return ToolTrace(mode=chosen_mode, reasoning=" un ".join(reasons))
|
|
|
|
async def execute_tool_trace(
    planned_trace: ToolTrace,
    *,
    message: str,
    workspace_root: Path | None = None,
    client: httpx.AsyncClient | None = None,
    max_steps: int | None = None,
) -> ToolTrace:
    """Run the tool calls derived from *message* and collect their results.

    Args:
        planned_trace: Planner output; supplies the reasoning and fallback mode.
        message: User message from which the initial calls are derived.
        workspace_root: Root for workspace tools; defaults to ``WORKSPACE_ROOT``.
        client: Optional HTTP client. When omitted, a client is created here
            and closed before returning.
        max_steps: Optional step budget, normalized by ``_normalize_max_steps``.

    Returns:
        A trace with the executed steps and at most ``MAX_GROUNDING_SOURCES``
        merged sources. The mode is ``"multi_step"`` when more than one step
        ran, otherwise the planned mode.
    """
    resolved_root = (workspace_root or WORKSPACE_ROOT).resolve()
    http_client = client or httpx.AsyncClient(
        timeout=DEFAULT_TOOL_TIMEOUT_SECONDS,
        follow_redirects=True,
        headers={"User-Agent": HTTP_USER_AGENT},
    )
    owns_client = client is None
    step_limit = _normalize_max_steps(max_steps)
    queue = _initial_tool_calls(message, resolved_root)

    records: list[ToolCallRecord] = []
    collected: list[GroundingSource] = []
    queued_keys: set[str] = set()
    done_keys: set[str] = set()

    try:
        while queue and len(records) < step_limit:
            current = queue.pop(0)
            current_key = _tool_call_key(current)
            if current_key in done_keys:
                continue
            done_keys.add(current_key)

            if current.name in ("web_search", "web_fetch"):
                web_tool = (
                    _execute_web_search if current.name == "web_search" else _execute_web_fetch
                )
                outcome = await web_tool(current.arguments, client=http_client)
            else:
                # Workspace tools do blocking file I/O, so they run in a worker thread.
                fs_tool = (
                    _execute_workspace_search
                    if current.name == "workspace_search"
                    else _execute_workspace_read
                )
                outcome = await asyncio.to_thread(fs_tool, current.arguments, resolved_root)

            records.append(outcome)
            collected = _merge_grounding_sources(collected, outcome.sources)

            for candidate in _follow_up_calls(current, outcome):
                candidate_key = _tool_call_key(candidate)
                if candidate_key in done_keys or candidate_key in queued_keys:
                    continue
                # Keep the backlog bounded by the step budget.
                if len(queue) >= step_limit:
                    break
                queue.append(candidate)
                queued_keys.add(candidate_key)
    finally:
        if owns_client:
            await http_client.aclose()

    final_mode = "multi_step" if len(records) > 1 else planned_trace.mode
    return ToolTrace(
        mode=final_mode,
        reasoning=planned_trace.reasoning,
        steps=records,
        grounding_sources=collected[:MAX_GROUNDING_SOURCES],
    )
|
|
|
|
def build_tool_context_message(trace: ToolTrace) -> str | None:
    """Render an executed trace as a text block for the model prompt.

    Returns ``None`` when the trace has no steps. Otherwise the text lists the
    mode, the reasoning, every step with up to three of its sources, up to
    four uncertainty notes (summaries of steps that failed or produced no
    sources) and closing instructions on how to use the grounding.
    """
    if not trace.steps:
        return None

    rationale = trace.reasoning or "tool use aktivizēts pēc pieprasījuma tipa."
    out = [
        "Tool grounding context:",
        f"- režīms: {trace.mode}",
        f"- izvēles pamatojums: {rationale}",
    ]

    caveats: list[str] = []
    for step in trace.steps:
        if step.status != "completed" or not step.sources:
            caveats.append(step.summary)
        out.append(f"- {step.name} [{step.status}]: {step.summary}")
        for source in step.sources[:3]:
            bullet = f" • {source.label}"
            if source.line_start:
                bullet += f" (line {source.line_start})"
            if source.uri:
                bullet += f" <{source.uri}>"
            if source.snippet:
                bullet += f" — {source.snippet}"
            out.append(bullet)

    if caveats:
        out.append("- Nenoteiktības signāli, kurus nedrīkst noklusēt:")
        for note in caveats[:4]:
            out.append(f" • {note}")

    if trace.grounding_sources:
        closing = "- Gala atbildē piesien secinājumus pie konkrētiem avotiem un skaidri nosauc, kas palika nepārbaudīts."
    else:
        closing = "- Rīki neieguva pietiekamu grounding; pasaki, ka secinājumi ir ierobežoti un var prasīt papildu pārbaudi."
    out.append(closing)
    out.append(
        "- Gala atbildē balsti secinājumus tikai uz šo kontekstu vai skaidri nosauc nenoteiktību."
    )
    return "\n".join(out)
|
|
|
|
def _initial_tool_calls(message: str, root: Path) -> list[PlannedToolCall]:
    """Derive the first batch of tool calls from the user message.

    Explicit URLs become ``web_fetch`` calls; without URLs, a keyword match
    yields one ``web_search``. Existing workspace files mentioned in the
    message become ``workspace_read`` calls, and a ``workspace_search`` is
    appended when the message looks like a workspace question. URL and path
    calls are each capped at ``MAX_TOOL_FOLLOW_UPS``.
    """
    query_text = message.strip()
    planned: list[PlannedToolCall] = []

    linked = _extract_urls(message)
    if linked:
        for link in linked[:MAX_TOOL_FOLLOW_UPS]:
            planned.append(PlannedToolCall(name="web_fetch", arguments={"url": link}))
    elif _should_use_web_search(message):
        planned.append(PlannedToolCall(name="web_search", arguments={"query": query_text}))

    mentioned = _extract_workspace_path_candidates(message)
    existing_files = _resolve_workspace_candidates(root, mentioned)
    for file_path in existing_files[:MAX_TOOL_FOLLOW_UPS]:
        planned.append(
            PlannedToolCall(
                name="workspace_read",
                arguments={"path": str(file_path), "start_line": 1},
            )
        )

    file_names = [str(file_path) for file_path in existing_files]
    if _should_use_workspace_grounding(message, file_names):
        planned.append(PlannedToolCall(name="workspace_search", arguments={"query": query_text}))
    return planned
|
|
|
|
async def _execute_web_search(
    arguments: dict[str, Any], *, client: httpx.AsyncClient
) -> ToolCallRecord:
    """Query the web search endpoint and convert the JSON answer into sources.

    Args:
        arguments: Tool arguments; ``query`` holds the search text.
        client: HTTP client used for the request.

    Returns:
        A ``failed`` record when the query is missing, the request fails or
        the body is not valid JSON. Otherwise a ``completed`` record whose
        sources are the abstract (when present) followed by entries from the
        first three related topics; an empty source list is reported in the
        summary as an uncertainty signal.
    """
    query = str(arguments.get("query", "")).strip()
    if not query:
        return ToolCallRecord(
            name="web_search",
            arguments=arguments,
            status="failed",
            summary="Trūkst query parametra web_search rīkam.",
        )
    try:
        response = await client.get(
            DEFAULT_WEB_SEARCH_ENDPOINT,
            params={
                "q": query,
                "format": "json",
                "no_redirect": "1",
                "no_html": "1",
                "skip_disambig": "1",
            },
        )
        response.raise_for_status()
        # An empty or non-JSON body raises a ValueError subclass here; report
        # it as a failed step instead of aborting the whole trace.
        payload = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        return ToolCallRecord(
            name="web_search",
            arguments=arguments,
            status="failed",
            summary=f"Web search neizdevās: {exc}",
        )
    if not isinstance(payload, dict):
        # Unexpected JSON shape: treat it as "no structured results".
        payload = {}

    sources: list[GroundingSource] = []
    # ``or ""`` keeps an explicit JSON null from becoming the string "None".
    abstract = str(payload.get("AbstractText") or "").strip()
    abstract_url = str(payload.get("AbstractURL") or "").strip()
    if abstract:
        sources.append(
            GroundingSource(
                kind="web_search",
                label=str(payload.get("Heading") or query),
                uri=abstract_url or None,
                snippet=abstract,
            )
        )
    related = payload.get("RelatedTopics")
    if not isinstance(related, list):
        related = []
    for topic in related[:3]:
        if not isinstance(topic, dict):
            continue
        text = str(topic.get("Text") or "").strip()
        url = str(topic.get("FirstURL") or "").strip()
        if text:
            sources.append(
                GroundingSource(
                    kind="web_search",
                    label=text.split(" - ")[0][:120],
                    uri=url or None,
                    snippet=text[:280],
                )
            )
    summary = "Atrasti ārējie avoti aktuālai vai pārbaudāmai informācijai."
    if not sources:
        summary = (
            "Web search neatgrieza strukturētus rezultātus; gala atbildē jānorāda nenoteiktība."
        )
    return ToolCallRecord(
        name="web_search",
        arguments=arguments,
        status="completed",
        summary=summary,
        sources=sources,
    )
|
|
|
|
async def _execute_web_fetch(
    arguments: dict[str, Any], *, client: httpx.AsyncClient
) -> ToolCallRecord:
    """Download one http(s) URL and return its cleaned text as a source.

    Args:
        arguments: Tool arguments; ``url`` holds the address to fetch.
        client: HTTP client used for the request.

    Returns:
        A ``failed`` record when the URL is missing, malformed, not http(s),
        or the request fails. Otherwise a ``completed`` record with a single
        source (title or hostname as label, text truncated to 800 characters),
        or with no sources when the page yields no readable text.
    """
    url = str(arguments.get("url", "")).strip()
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. a broken IPv6 literal).
        parsed = None
    if not url or parsed is None or parsed.scheme not in {"http", "https"}:
        return ToolCallRecord(
            name="web_fetch",
            arguments=arguments,
            status="failed",
            summary="Trūkst derīga http/https URL parametra web_fetch rīkam.",
        )
    # NOTE(review): the URL comes from the user message or from search results
    # and is fetched as-is; no filtering of loopback/private addresses happens
    # in this function.
    try:
        response = await client.get(url)
        response.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # httpx.InvalidURL does not derive from httpx.HTTPError, so it has to
        # be listed explicitly to keep a bad link from aborting the trace.
        return ToolCallRecord(
            name="web_fetch",
            arguments=arguments,
            status="failed",
            summary=f"Neizdevās nolasīt ārējo avotu {url}: {exc}",
        )
    content_type = response.headers.get("content-type", "").lower()
    raw_text = response.text
    title = None
    if "html" in content_type:
        extractor = _HTMLTextExtractor()
        extractor.feed(raw_text)
        extractor.close()
        title = extractor.title or None
        raw_text = extractor.text
    elif title_match := HTML_TITLE_PATTERN.search(raw_text):
        # Not declared as HTML but still carries a <title>; use it as the label.
        title = html.unescape(title_match.group(1).strip()) or None
    cleaned = " ".join(html.unescape(raw_text).split())[:DEFAULT_WEB_FETCH_MAX_CHARS]
    hostname = parsed.netloc or url
    label = title or hostname
    summary = f"Nolasīts ārējais avots no {hostname}."
    if not cleaned:
        summary = f"Avots {hostname} neatgrieza nolasāmu teksta saturu; jāatzīst nenoteiktība."
    return ToolCallRecord(
        name="web_fetch",
        arguments=arguments,
        status="completed",
        summary=summary,
        sources=[
            GroundingSource(
                kind="web_fetch",
                label=label[:160],
                uri=url,
                snippet=cleaned[:800] or None,
            )
        ]
        if cleaned
        else [],
    )
|
|
|
|
def _execute_workspace_search(arguments: dict[str, Any], root: Path) -> ToolCallRecord:
    """Search workspace files under *root* for the terms of ``query``.

    Files with a suffix in ``WORKSPACE_EXTENSIONS`` that are readable as UTF-8
    are matched by substring against their lowercased content and relative
    path. Each hit is scored by the number of matched terms, path matches
    (weighted by ``PATH_MATCH_WEIGHT``) and the best single-line match count
    within the first ``MAX_WORKSPACE_SCAN_LINES`` lines.

    Args:
        arguments: Tool arguments; ``query`` holds the search text.
        root: Directory that is searched recursively.

    Returns:
        A ``failed`` record when the query is missing; otherwise a
        ``completed`` record with up to ``MAX_GROUNDING_SOURCES`` sources
        ordered by descending score, then by label.
    """
    query = str(arguments.get("query", "")).strip()
    if not query:
        return ToolCallRecord(
            name="workspace_search",
            arguments=arguments,
            status="failed",
            summary="Trūkst query parametra workspace_search rīkam.",
        )

    terms = _extract_workspace_search_terms(query)
    scored_sources: list[tuple[int, GroundingSource]] = []
    # Without usable terms no file can match, so skip walking and reading the
    # whole tree; the result is the same "nothing found" record.
    candidate_paths = sorted(root.rglob("*")) if terms else []
    for path in candidate_paths:
        if not path.is_file() or path.suffix.lower() not in WORKSPACE_EXTENSIONS:
            continue
        try:
            content = path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError):
            # Unreadable or non-UTF-8 files are skipped on purpose.
            continue
        relative_path = str(path.relative_to(root))
        lowered_path = relative_path.lower()
        lowered = content.lower()
        matched_terms = [term for term in terms if term in lowered or term in lowered_path]
        if not matched_terms:
            continue
        line_start = None
        snippet = ""
        content_score = 0
        for index, line in enumerate(content.splitlines()[:MAX_WORKSPACE_SCAN_LINES], start=1):
            line_lower = line.lower()
            line_matches = sum(1 for term in matched_terms if term in line_lower)
            if line_matches == 0:
                continue
            if line_start is None:
                # The first matching line provides the snippet and its location.
                line_start = index
                snippet = line.strip()
            content_score = max(content_score, line_matches)
        path_match_count = sum(1 for term in matched_terms if term in lowered_path)
        path_score = path_match_count * PATH_MATCH_WEIGHT
        score = len(set(matched_terms)) + path_score + content_score
        scored_sources.append(
            (
                score,
                GroundingSource(
                    kind="workspace_search",
                    label=relative_path,
                    uri=str(path),
                    snippet=snippet[:280] or None,
                    line_start=line_start,
                ),
            )
        )
    sources = [
        source for _, source in sorted(scored_sources, key=lambda item: (-item[0], item[1].label))
    ][:MAX_GROUNDING_SOURCES]

    summary = "Atrasti atbilstoši repozitorija faili un dokumentācija."
    if not sources:
        summary = (
            "Repo netika atrasti tieši atbilstoši faili; gala atbildē neapgalvo neko nepārbaudītu."
        )
    return ToolCallRecord(
        name="workspace_search",
        arguments=arguments,
        status="completed",
        summary=summary,
        sources=sources,
    )
|
|
|
|
def _execute_workspace_read(arguments: dict[str, Any], root: Path) -> ToolCallRecord:
    """Read a 40-line excerpt of a workspace file.

    Args:
        arguments: Tool arguments; ``path`` is absolute or relative to *root*,
            ``start_line`` is the 1-based first line of the excerpt (values
            that are missing, falsy or not convertible to ``int`` mean 1).
        root: Workspace root; files outside of it are refused.

    Returns:
        A ``failed`` record when the path is missing, invalid, outside the
        workspace, unreadable or not UTF-8 text. Otherwise a ``completed``
        record with one source holding the excerpt (truncated to 800
        characters).
    """
    raw_path = str(arguments.get("path", "")).strip()
    if not raw_path:
        return ToolCallRecord(
            name="workspace_read",
            arguments=arguments,
            status="failed",
            summary="Trūkst path parametra workspace_read rīkam.",
        )
    try:
        target = Path(raw_path)
        target = (root / target).resolve() if not target.is_absolute() else target.resolve()
    except (OSError, ValueError, RuntimeError) as exc:
        # Malformed paths (embedded NUL, symlink loops, ...) must not crash the trace.
        return ToolCallRecord(
            name="workspace_read",
            arguments=arguments,
            status="failed",
            summary=f"Neizdevās nolasīt failu: {exc}",
        )
    # Containment is checked on the resolved path so ".." and symlinks cannot escape.
    if target != root and root not in target.parents:
        return ToolCallRecord(
            name="workspace_read",
            arguments=arguments,
            status="failed",
            summary="Pieprasītais fails ir ārpus atļautās workspace saknes.",
        )
    try:
        content = target.read_text(encoding="utf-8")
    except (UnicodeDecodeError, OSError) as exc:
        # Binary or otherwise non-UTF-8 files are reported like any other read
        # failure, matching how workspace search skips such files.
        return ToolCallRecord(
            name="workspace_read",
            arguments=arguments,
            status="failed",
            summary=f"Neizdevās nolasīt failu: {exc}",
        )

    try:
        start_line = max(1, int(arguments.get("start_line", 1) or 1))
    except (TypeError, ValueError):
        start_line = 1
    lines = content.splitlines()
    end_line = min(len(lines), start_line + 39)
    excerpt = "\n".join(lines[start_line - 1 : end_line]).strip()
    relative = target.relative_to(root)
    return ToolCallRecord(
        name="workspace_read",
        arguments=arguments,
        status="completed",
        summary=(
            f"Nolasīts fails {relative} ar fokusētu izgriezumu no {start_line}. līnijas."
        ),
        sources=[
            GroundingSource(
                kind="workspace_read",
                label=str(relative),
                uri=str(target),
                snippet=excerpt[:800] or None,
                line_start=start_line,
            )
        ],
    )
|
|
|
|
def _normalize_max_steps(value: int | None) -> int:
    """Return the step budget: the default for ``None``, else *value* clamped
    to the range ``1..MAX_TOOL_STEPS_CAP``."""
    if value is not None:
        requested = int(value)
        capped = requested if requested < MAX_TOOL_STEPS_CAP else MAX_TOOL_STEPS_CAP
        return capped if capped > 1 else 1
    return MAX_TOOL_STEPS_DEFAULT
|
|
|
|
def _extract_urls(message: str) -> list[str]:
    """Return the distinct http(s) URLs found in *message*, in order of appearance.

    Trailing punctuation from ``URL_STRIP_CHARS`` is removed, except that a
    closing parenthesis is kept while it is balanced by an opening one inside
    the URL (e.g. ``.../wiki/Python_(programming_language)``). Repeated URLs
    are returned once so they do not use up the caller's limited fetch slots.
    """
    urls: list[str] = []
    for match in URL_PATTERN.findall(message):
        candidate = match
        while candidate and candidate[-1] in URL_STRIP_CHARS:
            if candidate[-1] == ")" and candidate.count("(") >= candidate.count(")"):
                # The parenthesis belongs to the URL, not to the sentence.
                break
            candidate = candidate[:-1]
        if candidate and candidate not in urls:
            urls.append(candidate)
    return urls
|
|
|
|
def _extract_workspace_path_candidates(message: str) -> list[str]:
    """Return whitespace-separated tokens of *message* that look like file paths.

    A token qualifies when it contains a path separator or ends in one of
    ``WORKSPACE_EXTENSIONS``; URLs are excluded. Backslashes are normalized to
    forward slashes and duplicates are dropped while keeping order.

    Surrounding quotes, backticks, brackets and sentence punctuation are
    removed first. Leading dots are kept so that ``./src/x.py`` and
    ``.github/ci.yml`` stay relative paths.
    """
    # Same punctuation as before on the left, minus "." (which is part of the
    # path there), plus opening wrappers commonly put around paths in chat.
    leading_chars = PATH_STRIP_CHARS.replace(".", "") + "`'\"([{<"
    trailing_chars = PATH_STRIP_CHARS + "`'\"]}>?!"
    candidates: list[str] = []
    for raw_token in message.split():
        token = raw_token.lstrip(leading_chars).rstrip(trailing_chars)
        if not token:
            continue
        normalized = token.replace("\\", "/")
        suffix = Path(normalized).suffix.lower()
        if "/" not in normalized and suffix not in WORKSPACE_EXTENSIONS:
            continue
        if normalized.startswith(("http://", "https://")):
            continue
        if normalized not in candidates:
            candidates.append(normalized)
    return candidates
|
|
|
|
def _extract_workspace_search_terms(query: str) -> list[str]:
    """Split *query* into distinct lowercase search terms.

    Terms no longer than ``MIN_SEARCH_TERM_LENGTH`` and terms listed in
    ``WORKSPACE_STOPWORDS`` are dropped; at most ``MAX_SEARCH_TERMS`` terms are
    returned, in order of first appearance.
    """
    collected: list[str] = []
    for piece in re.split(r"[\s/\\`'\"():,\[\]{}<>]+", query):
        word = piece.strip().strip(PATH_STRIP_CHARS).lower()
        too_short = len(word) <= MIN_SEARCH_TERM_LENGTH
        if too_short or word in WORKSPACE_STOPWORDS or word in collected:
            continue
        collected.append(word)
        if len(collected) >= MAX_SEARCH_TERMS:
            break
    return collected
|
|
|
|
def _should_use_workspace_grounding(message: str, workspace_candidates: list[str]) -> bool:
    """Tell whether *message* looks like a question about the workspace.

    True when the message contains a workspace keyword, when path candidates
    were found, when it contains a path separator outside of URLs, or when it
    contains a code-grounding keyword.

    Slashes that belong to http(s) URLs are ignored for the separator check:
    the path-candidate extraction excludes URLs as well, and a plain link
    should not by itself trigger a repository search.
    """
    normalized = message.strip().lower()
    if any(token in normalized for token in WORKSPACE_KEYWORDS):
        return True
    if workspace_candidates:
        return True
    without_urls = URL_PATTERN.sub(" ", message)
    if "/" in without_urls or "\\" in without_urls:
        return True
    return any(token in normalized for token in CODE_GROUNDING_KEYWORDS)
|
|
|
|
def _should_use_web_search(message: str) -> bool:
    """Tell whether *message* contains a recency or source/verification keyword.

    Matching is a case-insensitive substring test; a blank message never matches.
    """
    lowered = message.strip().lower()
    if lowered:
        for keyword_group in (WEB_SEARCH_KEYWORDS, WEB_GROUNDING_KEYWORDS):
            for keyword in keyword_group:
                if keyword in lowered:
                    return True
    return False
|
|
|
|
| def _resolve_workspace_candidates(root: Path, candidates: list[str]) -> list[Path]: |
| resolved: list[Path] = [] |
| for candidate in candidates: |
| target = (root / candidate).resolve() |
| if not target.exists() or not target.is_file(): |
| continue |
| if root not in target.parents and target != root: |
| continue |
| if target not in resolved: |
| resolved.append(target) |
| return resolved |
|
|
|
|
| def _tool_call_key(call: PlannedToolCall) -> str: |
| return json.dumps( |
| {"name": call.name, "arguments": call.arguments}, |
| sort_keys=True, |
| ensure_ascii=False, |
| ) |
|
|
|
|
| def _follow_up_calls( |
| call: PlannedToolCall, |
| record: ToolCallRecord, |
| ) -> list[PlannedToolCall]: |
| if record.status != "completed" or not record.sources: |
| return [] |
| if call.name == "web_search": |
| return [ |
| PlannedToolCall(name="web_fetch", arguments={"url": source.uri}) |
| for source in record.sources[:MAX_TOOL_FOLLOW_UPS] |
| if source.uri |
| ] |
| if call.name == "workspace_search": |
| return [ |
| PlannedToolCall( |
| name="workspace_read", |
| arguments={ |
| "path": source.uri, |
| "start_line": source.line_start or 1, |
| }, |
| ) |
| for source in record.sources[:MAX_TOOL_FOLLOW_UPS] |
| if source.uri |
| ] |
| return [] |
|
|
|
|
| def _merge_grounding_sources( |
| current: list[GroundingSource], |
| new_sources: list[GroundingSource], |
| ) -> list[GroundingSource]: |
| seen = { |
| (source.kind, source.label, source.uri, source.snippet, source.line_start) |
| for source in current |
| } |
| merged = list(current) |
| for source in new_sources: |
| key = (source.kind, source.label, source.uri, source.snippet, source.line_start) |
| if key in seen: |
| continue |
| seen.add(key) |
| merged.append(source) |
| return merged |
|
|