import logging
from datetime import date
from langchain_core.messages import SystemMessage
from langgraph.graph import END
from src.agent.state import AgentState
from src.agent.prompts import get_system_prompt
from src.config import settings
from src.tools import all_tools
# Module logger for the agent package.
logger = logging.getLogger("cashy.agent")
# Default models per provider — used when settings.model_name is unset
# (the "free-tier" entry is always forced in create_model).
DEFAULT_MODELS = {
    "openai": "gpt-5-mini",
    "anthropic": "claude-sonnet-4-20250514",
    "google": "gemini-2.5-flash",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
    "free-tier": "Qwen/Qwen2.5-7B-Instruct",
}
# Capture Space's HF token at startup (before BYOK overwrites it)
_SPACE_HF_TOKEN = settings.hf_token
def create_model():
    """Create the LLM chat model with tools bound. Supports multiple providers.

    Provider and model name are resolved from settings; each provider's SDK
    is imported lazily so only the selected backend needs to be installed.

    Returns:
        A chat model with the agent's tools bound via bind_tools().

    Raises:
        ValueError: if no API key is configured, or the resolved provider
            is not one of the supported backends.
    """
    provider = settings.resolved_provider
    if not provider:
        raise ValueError(
            "No API key configured. Please select a provider and enter your API key in the sidebar."
        )
    # .get() so an unrecognized provider falls through to the explicit
    # ValueError below instead of raising an opaque KeyError here.
    model_name = settings.model_name or DEFAULT_MODELS.get(provider)
    logger.info("Initializing LLM: %s (provider=%s)", model_name, provider)
    if provider == "openai":
        from langchain_openai import ChatOpenAI
        chat_model = ChatOpenAI(
            model=model_name,
            api_key=settings.openai_api_key,
            max_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )
    elif provider == "anthropic":
        from langchain_anthropic import ChatAnthropic
        chat_model = ChatAnthropic(
            model=model_name,
            api_key=settings.anthropic_api_key,
            max_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )
    elif provider == "google":
        from langchain_google_genai import ChatGoogleGenerativeAI
        chat_model = ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=settings.google_api_key,
            max_output_tokens=settings.model_max_tokens,
            temperature=settings.model_temperature,
        )
    elif provider == "free-tier":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
        # Free tier runs on the Space's own token captured at startup,
        # with the model pinned regardless of settings.model_name.
        model_name = DEFAULT_MODELS["free-tier"]  # always locked
        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            task="text-generation",
            max_new_tokens=settings.model_max_tokens,
            huggingfacehub_api_token=_SPACE_HF_TOKEN,
        )
        chat_model = ChatHuggingFace(llm=llm)
    elif provider == "huggingface":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            provider=settings.hf_inference_provider,
            task="text-generation",
            max_new_tokens=settings.model_max_tokens,
            huggingfacehub_api_token=settings.hf_token,
        )
        chat_model = ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unknown LLM provider: {provider}")
    # HF transports require latin-1 text, so sanitize tool descriptions there.
    tools = _sanitize_tools(all_tools) if provider in ("huggingface", "free-tier") else all_tools
    model = chat_model.bind_tools(tools)
    logger.info("Model ready with %d tools bound", len(tools))
    return model
# Cached tool-bound model; populated on the first get_model() call.
model_with_tools = None

def get_model():
    """Return the process-wide model, building it lazily on first use."""
    global model_with_tools
    if model_with_tools is not None:
        return model_with_tools
    model_with_tools = create_model()
    return model_with_tools
def reset_model():
    """Clear the cached model so the next call creates a fresh one.

    After this, the next get_model() call re-runs create_model() with the
    current settings.
    """
    global model_with_tools
    model_with_tools = None
    logger.info("Model cache cleared — next query will reinitialize")
def _sanitize_for_latin1(text: str) -> str:
"""Replace non-latin-1 Unicode characters for HuggingFace's HTTP transport."""
result = []
for c in text:
try:
c.encode("latin-1")
result.append(c)
except UnicodeEncodeError:
# Common replacements
if c in ("\u2014", "\u2013"):
result.append("-")
elif c in ("\u201c", "\u201d"):
result.append('"')
elif c in ("\u2018", "\u2019"):
result.append("'")
elif c == "\u2026":
result.append("...")
elif c == "\u2192":
result.append("->")
else:
result.append("?")
return "".join(result)
def _sanitize_tools(tools: list) -> list:
    """Return copies of tools with latin-1 safe descriptions.

    Runs each tool's description — and the description of every field in its
    args schema — through _sanitize_for_latin1, for providers whose transport
    requires latin-1 text (see create_model).

    Args:
        tools: Tool objects; each may expose ``description`` and a pydantic
            ``args_schema`` with ``model_fields``.

    Returns:
        A new list of deep-copied tools with sanitized descriptions.
    """
    import copy
    sanitized = []
    for tool in tools:
        t = copy.deepcopy(tool)
        if hasattr(t, "description"):
            t.description = _sanitize_for_latin1(t.description)
        if hasattr(t, "args_schema") and t.args_schema:
            # NOTE(review): copy.deepcopy does not copy classes — if
            # args_schema is a pydantic model *class*, this loop mutates the
            # shared schema's field descriptions in place rather than a copy.
            # Sanitization is idempotent so this is likely harmless, but
            # confirm args_schema is per-tool-instance.
            for field_name, field_info in t.args_schema.model_fields.items():
                if field_info.description:
                    field_info.description = _sanitize_for_latin1(field_info.description)
        sanitized.append(t)
    return sanitized
def call_model(state: AgentState) -> dict:
    """Invoke the LLM with system prompt and tools.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation history.

    Returns:
        dict with a single-element ``messages`` list holding the model's
        response (merged into state by LangGraph).
    """
    model = get_model()
    today = date.today()
    # System prompt template exposes {today} and {year} placeholders.
    prompt = get_system_prompt(settings.app_mode).format(today=today.isoformat(), year=today.year)
    messages = [SystemMessage(content=prompt)] + state["messages"]
    # HuggingFace Inference API requires latin-1 compatible text
    if settings.resolved_provider in ("huggingface", "free-tier"):
        logger.debug("Sanitizing %d messages for latin-1 compatibility", len(messages))
        for msg in messages:
            if isinstance(msg.content, str):
                # NOTE(review): mutates the message objects in place — the
                # history messages held in state are altered, not copies.
                # Confirm downstream consumers expect sanitized content.
                msg.content = _sanitize_for_latin1(msg.content)
    logger.debug("Calling LLM (%d messages in state)", len(state["messages"]))
    response = model.invoke(messages)
    if response.tool_calls:
        # Log which tools the model asked for; args only at debug level.
        tool_names = [tc["name"] for tc in response.tool_calls]
        logger.info("LLM requested tools: %s", ", ".join(tool_names))
        for tc in response.tool_calls:
            logger.debug(" -> %s(%s)", tc["name"], tc["args"])
    else:
        logger.info("LLM final response (%d chars)", len(response.content))
    return {"messages": [response]}
def should_continue(state: AgentState) -> str:
    """Conditional-edge router: "tools" when the last message carries tool
    calls, END otherwise."""
    if state["messages"][-1].tool_calls:
        logger.debug("Routing to tools node")
        return "tools"
    logger.debug("Routing to END")
    return END
|