File size: 6,582 Bytes
17a78b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import logging
from datetime import date
from langchain_core.messages import SystemMessage
from langgraph.graph import END

from src.agent.state import AgentState
from src.agent.prompts import get_system_prompt
from src.config import settings
from src.tools import all_tools

# Subsystem logger; handlers/level are configured by the application entry point.
logger = logging.getLogger("cashy.agent")

# Default models per provider — used when settings.model_name is not set.
# The "free-tier" entry is always enforced regardless of model_name (see create_model).
DEFAULT_MODELS = {
    "openai": "gpt-5-mini",
    "anthropic": "claude-sonnet-4-20250514",
    "google": "gemini-2.5-flash",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
    "free-tier": "Qwen/Qwen2.5-7B-Instruct",
}

# Capture Space's HF token at startup (before BYOK overwrites it).
# The free-tier path reads this snapshot instead of settings.hf_token so that
# a user-supplied key can never replace the Space's own credential.
_SPACE_HF_TOKEN = settings.hf_token


def create_model():
    """Create the LLM chat model with tools bound. Supports multiple providers.

    Provider, model name and credentials are read from ``settings``.

    Returns:
        A chat model with ``all_tools`` bound (descriptions sanitized for
        HuggingFace-backed providers).

    Raises:
        ValueError: if no provider is configured or the provider is unknown.
    """
    provider = settings.resolved_provider
    if not provider:
        raise ValueError(
            "No API key configured. Please select a provider and enter your API key in the sidebar."
        )
    # Validate up front: with settings.model_name unset, an unknown provider
    # would otherwise surface as a KeyError from the DEFAULT_MODELS lookup
    # below instead of the intended ValueError.
    if provider not in DEFAULT_MODELS:
        raise ValueError(f"Unknown LLM provider: {provider}")
    model_name = settings.model_name or DEFAULT_MODELS[provider]

    logger.info("Initializing LLM: %s (provider=%s)", model_name, provider)

    # Provider SDKs are imported lazily so only the selected backend's
    # dependency is required at runtime.
    if provider == "openai":
        from langchain_openai import ChatOpenAI

        chat_model = ChatOpenAI(
            model=model_name,
            api_key=settings.openai_api_key,
            max_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )

    elif provider == "anthropic":
        from langchain_anthropic import ChatAnthropic

        chat_model = ChatAnthropic(
            model=model_name,
            api_key=settings.anthropic_api_key,
            max_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )

    elif provider == "google":
        from langchain_google_genai import ChatGoogleGenerativeAI

        chat_model = ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=settings.google_api_key,
            max_output_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )

    elif provider == "free-tier":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

        # Free tier is always locked to the default model and uses the
        # Space's own HF token captured at import time (before BYOK could
        # overwrite settings.hf_token).
        model_name = DEFAULT_MODELS["free-tier"]
        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            task="text-generation",
            max_new_tokens=settings.model_max_tokens,
            huggingfacehub_api_token=_SPACE_HF_TOKEN,
        )
        chat_model = ChatHuggingFace(llm=llm)

    else:  # provider == "huggingface" — guaranteed by the validation above
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            provider=settings.hf_inference_provider,
            task="text-generation",
            max_new_tokens=settings.model_max_tokens,
            huggingfacehub_api_token=settings.hf_token,
        )
        chat_model = ChatHuggingFace(llm=llm)

    # HuggingFace's HTTP transport requires latin-1 compatible text, so tool
    # descriptions are sanitized for HF-backed providers only.
    tools = _sanitize_tools(all_tools) if provider in ("huggingface", "free-tier") else all_tools
    model = chat_model.bind_tools(tools)
    # Log the count of the tools actually bound (sanitized or not).
    logger.info("Model ready with %d tools bound", len(tools))
    return model


# Module-level model instance (created once)
# Lazily-initialized singleton cache; populated by get_model() and cleared
# by reset_model() when provider/key settings change.
model_with_tools = None


def get_model():
    """Return the shared model instance, building it lazily on first access."""
    global model_with_tools
    if model_with_tools is not None:
        return model_with_tools
    model_with_tools = create_model()
    return model_with_tools


def reset_model():
    """Drop the cached model; the next get_model() call builds a fresh one."""
    global model_with_tools
    model_with_tools = None
    logger.info("Model cache cleared — next query will reinitialize")


def _sanitize_for_latin1(text: str) -> str:
    """Replace non-latin-1 Unicode characters for HuggingFace's HTTP transport."""
    result = []
    for c in text:
        try:
            c.encode("latin-1")
            result.append(c)
        except UnicodeEncodeError:
            # Common replacements
            if c in ("\u2014", "\u2013"):
                result.append("-")
            elif c in ("\u201c", "\u201d"):
                result.append('"')
            elif c in ("\u2018", "\u2019"):
                result.append("'")
            elif c == "\u2026":
                result.append("...")
            elif c == "\u2192":
                result.append("->")
            else:
                result.append("?")
    return "".join(result)


def _sanitize_tools(tools: list) -> list:
    """Return copies of tools with latin-1 safe descriptions.

    Both the tool-level description and every argument description on the
    tool's pydantic args_schema are passed through _sanitize_for_latin1.
    """
    import copy
    sanitized = []
    for tool in tools:
        t = copy.deepcopy(tool)
        if hasattr(t, "description"):
            t.description = _sanitize_for_latin1(t.description)
        # NOTE(review): deepcopy treats classes as atomic, so if args_schema
        # is a pydantic model *class* the copy shares it with the original
        # tool and the description edits below mutate the original's schema
        # too. Looks harmless because sanitization is idempotent — confirm.
        if hasattr(t, "args_schema") and t.args_schema:
            for field_name, field_info in t.args_schema.model_fields.items():
                if field_info.description:
                    field_info.description = _sanitize_for_latin1(field_info.description)
        sanitized.append(t)
    return sanitized


def call_model(state: AgentState) -> dict:
    """Invoke the LLM with system prompt and tools.

    Builds the message list as [system prompt] + state["messages"], invokes
    the bound model, and returns {"messages": [response]} for the graph to
    merge into state.
    """
    model = get_model()
    today = date.today()
    # The system prompt template exposes {today} (ISO date) and {year}.
    prompt = get_system_prompt(settings.app_mode).format(today=today.isoformat(), year=today.year)
    messages = [SystemMessage(content=prompt)] + state["messages"]

    # HuggingFace Inference API requires latin-1 compatible text
    if settings.resolved_provider in ("huggingface", "free-tier"):
        logger.debug("Sanitizing %d messages for latin-1 compatibility", len(messages))
        # NOTE(review): this edits the message objects held in graph state in
        # place (the list is new but the elements are shared), so sanitization
        # persists across turns — confirm that is intended.
        for msg in messages:
            if isinstance(msg.content, str):
                msg.content = _sanitize_for_latin1(msg.content)

    logger.debug("Calling LLM (%d messages in state)", len(state["messages"]))

    response = model.invoke(messages)

    # Log either the requested tool calls or the size of the final answer.
    if response.tool_calls:
        tool_names = [tc["name"] for tc in response.tool_calls]
        logger.info("LLM requested tools: %s", ", ".join(tool_names))
        for tc in response.tool_calls:
            logger.debug("  -> %s(%s)", tc["name"], tc["args"])
    else:
        logger.info("LLM final response (%d chars)", len(response.content))

    return {"messages": [response]}


def should_continue(state: AgentState) -> str:
    """Route to tools if the model made tool calls, otherwise end."""
    latest = state["messages"][-1]
    # No pending tool calls means the model produced its final answer.
    if not latest.tool_calls:
        logger.debug("Routing to END")
        return END
    logger.debug("Routing to tools node")
    return "tools"