IntegraChat / backend /api /services /tool_scoring.py
nothingworry's picture
Reasoning traces, smarter tools, deterministic backend tests.
ef83e66
raw
history blame
2.08 kB
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List
from .semantic_encoder import embed_text, cosine_similarity
def _normalize(score: float) -> float:
return max(0.0, min(1.0, score))
@dataclass
class ToolScoringService:
    """
    Heuristic + semantic tool fitness scoring.

    Blends cosine similarity against per-tool anchor prompts with the
    classified intent, retrieval hits, and a keyword freshness heuristic
    to rate how well each tool fits an incoming message.
    """

    # Anchor prompts describing each tool's ideal domain; embedded once in
    # __post_init__ so score() only has to embed the incoming message.
    _domain_prompts: Dict[str, str] = field(default_factory=lambda: {
        "rag": "internal company policy, handbook, corporate procedure, proprietary",
        "web": "latest external news, public web search, trending topics, live data",
        "llm": "casual chit chat, brainstorming, creative writing, general knowledge"
    })
    # Populated in __post_init__; maps domain name -> anchor embedding.
    _domain_vectors: Dict[str, List[float]] = field(init=False)

    def __post_init__(self):
        # Pre-compute one embedding per anchor prompt.
        vectors: Dict[str, List[float]] = {}
        for domain, anchor in self._domain_prompts.items():
            vectors[domain] = embed_text(anchor)
        self._domain_vectors = vectors

    def score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        """Rate each tool's fitness for *message* on a 0..1 scale.

        Weights semantic similarity to the domain anchors together with
        the upstream *intent* label, presence of *rag_results*, and a
        recency-keyword signal. Returns scores rounded to 3 decimals,
        keyed "rag_fitness" / "web_fitness" / "llm_only".
        """
        message_vec = embed_text(message)
        similarity = {
            domain: cosine_similarity(message_vec, anchor_vec)
            for domain, anchor_vec in self._domain_vectors.items()
        }

        rag_raw = (
            0.4 * similarity["rag"]
            + 0.4 * (1 if rag_results else 0)
            + 0.2 * (1 if intent == "rag" else 0)
        )
        web_raw = (
            0.5 * similarity["web"]
            + 0.3 * (1 if intent == "web" else 0)
            + 0.2 * self._freshness_signal(message)
        )
        llm_raw = (
            0.6 * similarity["llm"]
            + 0.4 * (1 if intent == "general" else 0)
        )

        return {
            "rag_fitness": round(_normalize(rag_raw), 3),
            "web_fitness": round(_normalize(web_raw), 3),
            "llm_only": round(_normalize(llm_raw), 3)
        }

    @staticmethod
    def _freshness_signal(message: str) -> float:
        """Score 0..1 for how strongly *message* asks for fresh/live data.

        Counts case-insensitive substring hits of recency keywords and
        saturates at three hits.
        """
        recency_markers = ("news", "today", "latest", "current", "breaking", "update", "recent", "now")
        lowered = message.lower()
        matches = [marker for marker in recency_markers if marker in lowered]
        return min(1.0, len(matches) / 3.0)