File size: 2,323 Bytes
ef83e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddc5c21
ef83e66
 
ddc5c21
 
 
ef83e66
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import annotations

import math
import re
from dataclasses import dataclass, field
from typing import Dict, List

from .semantic_encoder import embed_text, cosine_similarity


def _normalize(score: float) -> float:
    """Clamp ``score`` into the closed interval ``[0.0, 1.0]``.

    NaN is mapped to ``0.0``. Without the explicit check, ``min(1.0, nan)``
    returns ``1.0``: every comparison against NaN is false, so ``min`` keeps
    its first argument. A degenerate score would then be reported as
    maximal fitness.
    """
    if math.isnan(score):
        return 0.0
    return max(0.0, min(1.0, score))


# Keywords that suggest the user wants fresh, time-sensitive information.
_FRESHNESS_TOKENS = (
    "news", "today", "latest", "current", "breaking", "update", "recent",
    "now", "trending", "happening", "what's new", "what is new",
)
# Strong news cues: matching any of these lifts the freshness floor to 0.7.
_NEWS_TOKENS = frozenset({"news", "breaking", "latest"})
# One pattern per token, anchored at the start of a word. Plain substring
# matching fired on words that merely *contain* a token ("know"/"snow" ->
# "now", "concurrent" -> "current", "heartbreaking" -> "breaking"). The
# leading ``\b`` rejects those but still accepts inflected forms such as
# "updates", "recently" or "today's".
_FRESHNESS_PATTERNS = {
    token: re.compile(r"\b" + re.escape(token)) for token in _FRESHNESS_TOKENS
}


@dataclass
class ToolScoringService:
    """
    Heuristic + semantic tool fitness scoring.

    :meth:`score` blends several kinds of evidence into a fitness value in
    ``[0, 1]`` for each tool:

    * cosine similarity between the message embedding and a fixed
      natural-language description ("domain prompt") of each tool;
    * the caller-supplied ``intent`` label;
    * cheap cues: whether any RAG results were retrieved, and
      freshness keywords in the message.

    Domain prompts are embedded once in ``__post_init__``. Each ``score``
    call embeds only the incoming message.
    """

    # Tool key -> description that is embedded as that tool's reference
    # vector. ``score`` reads the "rag", "web" and "llm" keys, so a custom
    # mapping must provide all three.
    _domain_prompts: Dict[str, str] = field(default_factory=lambda: {
        "rag": "internal company policy, handbook, corporate procedure, proprietary",
        "web": "latest external news, public web search, trending topics, live data",
        "llm": "casual chit chat, brainstorming, creative writing, general knowledge"
    })
    # Embeddings of ``_domain_prompts`` (same keys); computed in __post_init__.
    _domain_vectors: Dict[str, List[float]] = field(init=False)

    def __post_init__(self):
        """Embed every domain prompt once so ``score`` can reuse the vectors."""
        self._domain_vectors = {
            name: embed_text(prompt)
            for name, prompt in self._domain_prompts.items()
        }

    def score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        """Return per-tool fitness scores for ``message``.

        Args:
            message: The user message to route.
            intent: Intent label supplied by the caller. Only ``"rag"``,
                ``"web"`` and ``"general"`` add a bonus here; any other value
                contributes nothing.
            rag_results: Documents retrieved for ``message``. Only their
                truthiness (whether there are any hits) is used.

        Returns:
            A dict with keys ``"rag_fitness"``, ``"web_fitness"`` and
            ``"llm_only"``. Each value is clamped to ``[0, 1]`` and rounded
            to three decimals.
        """
        embedding = embed_text(message)
        rag_sem = cosine_similarity(embedding, self._domain_vectors["rag"])
        web_sem = cosine_similarity(embedding, self._domain_vectors["web"])
        llm_sem = cosine_similarity(embedding, self._domain_vectors["llm"])

        # The weights in each blend sum to 1.0. The final clamp in _normalize
        # absorbs negative similarities; this assumes cosine_similarity
        # returns values in [-1, 1], to be confirmed against semantic_encoder.
        rag_signal = 0.4 * rag_sem + 0.4 * (1 if rag_results else 0) + 0.2 * (1 if intent == "rag" else 0)
        web_signal = 0.5 * web_sem + 0.3 * (1 if intent == "web" else 0) + 0.2 * self._freshness_signal(message)
        llm_signal = 0.6 * llm_sem + 0.4 * (1 if intent == "general" else 0)

        return {
            "rag_fitness": round(_normalize(rag_signal), 3),
            "web_fitness": round(_normalize(web_signal), 3),
            "llm_only": round(_normalize(llm_signal), 3)
        }

    @staticmethod
    def _freshness_signal(message: str) -> float:
        """Score in ``[0, 1]`` how strongly ``message`` asks for recent information.

        Each distinct keyword in ``_FRESHNESS_TOKENS`` found at a word start
        counts once. With a strong news cue ("news", "breaking", "latest")
        the score starts at 0.7 and adds 0.1 per hit. Otherwise it is
        ``hits / 3``. Both forms are capped at 1.0. Matching is
        case-insensitive.
        """
        msg = message.lower()
        matched = {
            token for token, pattern in _FRESHNESS_PATTERNS.items()
            if pattern.search(msg)
        }
        hits = len(matched)
        # Boost score for news-related queries.
        if matched & _NEWS_TOKENS:
            return min(1.0, 0.7 + (hits * 0.1))  # Start at 0.7 for news queries
        return min(1.0, hits / 3.0)