Spaces:
Running
on
Zero
Running
on
Zero
jedick
committed on
Commit
·
158fae7
1
Parent(s):
503a0b6
Further simplify ToolCallingLLM
Browse files- app.py +2 -2
- graph.py +2 -4
- mods/tool_calling_llm.py +9 -115
app.py
CHANGED
|
@@ -401,7 +401,7 @@ with gr.Blocks(
|
|
| 401 |
info_text = f"""
|
| 402 |
**Database:** {len(sources)} emails from {start} to {end}.
|
| 403 |
**Features:** RAG, today's date, hybrid search (dense+sparse), thinking output (local),
|
| 404 |
-
multiple retrievals
|
| 405 |
**Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
|
| 406 |
"""
|
| 407 |
return info_text
|
|
@@ -453,7 +453,7 @@ with gr.Blocks(
|
|
| 453 |
)
|
| 454 |
multi_turn_questions = [
|
| 455 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
| 456 |
-
"Did those authors report bugs before 2025?",
|
| 457 |
]
|
| 458 |
gr.Examples(
|
| 459 |
examples=[[q] for q in multi_turn_questions],
|
|
|
|
| 401 |
info_text = f"""
|
| 402 |
**Database:** {len(sources)} emails from {start} to {end}.
|
| 403 |
**Features:** RAG, today's date, hybrid search (dense+sparse), thinking output (local),
|
| 404 |
+
multiple retrievals (remote), citations output (remote), chat memory.
|
| 405 |
**Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
|
| 406 |
"""
|
| 407 |
return info_text
|
|
|
|
| 453 |
)
|
| 454 |
multi_turn_questions = [
|
| 455 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
| 456 |
+
"Did those authors report bugs before 2025? /think",
|
| 457 |
]
|
| 458 |
gr.Examples(
|
| 459 |
examples=[[q] for q in multi_turn_questions],
|
graph.py
CHANGED
|
@@ -71,7 +71,7 @@ def normalize_messages(messages):
|
|
| 71 |
return messages
|
| 72 |
|
| 73 |
|
| 74 |
-
def ToolifyHF(chat_model, system_message
|
| 75 |
"""
|
| 76 |
Get a Hugging Face model ready for bind_tools().
|
| 77 |
"""
|
|
@@ -86,8 +86,6 @@ def ToolifyHF(chat_model, system_message, system_message_suffix=""):
|
|
| 86 |
chat_model = HuggingFaceWithTools(
|
| 87 |
llm=chat_model.llm,
|
| 88 |
tool_system_prompt_template=tool_system_prompt_template,
|
| 89 |
-
# Suffix is for any additional context (not templated)
|
| 90 |
-
system_message_suffix=system_message_suffix,
|
| 91 |
)
|
| 92 |
|
| 93 |
return chat_model
|
|
@@ -195,7 +193,7 @@ def BuildGraph(
|
|
| 195 |
if is_local:
|
| 196 |
# For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
|
| 197 |
query_model = ToolifyHF(
|
| 198 |
-
chat_model, query_prompt(chat_model, think=think_query)
|
| 199 |
).bind_tools([retrieve_emails])
|
| 200 |
# Don't use answer_with_citations tool because responses with are sometimes unparseable
|
| 201 |
generate_model = chat_model
|
|
|
|
| 71 |
return messages
|
| 72 |
|
| 73 |
|
| 74 |
+
def ToolifyHF(chat_model, system_message):
|
| 75 |
"""
|
| 76 |
Get a Hugging Face model ready for bind_tools().
|
| 77 |
"""
|
|
|
|
| 86 |
chat_model = HuggingFaceWithTools(
|
| 87 |
llm=chat_model.llm,
|
| 88 |
tool_system_prompt_template=tool_system_prompt_template,
|
|
|
|
|
|
|
| 89 |
)
|
| 90 |
|
| 91 |
return chat_model
|
|
|
|
| 193 |
if is_local:
|
| 194 |
# For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
|
| 195 |
query_model = ToolifyHF(
|
| 196 |
+
chat_model, query_prompt(chat_model, think=think_query)
|
| 197 |
).bind_tools([retrieve_emails])
|
| 198 |
# Don't use answer_with_citations tool because responses with are sometimes unparseable
|
| 199 |
generate_model = chat_model
|
mods/tool_calling_llm.py
CHANGED
|
@@ -49,51 +49,6 @@ You must always select one of the above tools and respond with only a JSON objec
|
|
| 49 |
""" # noqa: E501
|
| 50 |
|
| 51 |
|
| 52 |
-
def _is_pydantic_class(obj: Any) -> bool:
|
| 53 |
-
"""
|
| 54 |
-
Checks if the tool provided is a Pydantic class.
|
| 55 |
-
"""
|
| 56 |
-
return isinstance(obj, type) and (
|
| 57 |
-
issubclass(obj, BaseModel) or BaseModel in obj.__bases__
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def _is_pydantic_object(obj: Any) -> bool:
|
| 62 |
-
"""
|
| 63 |
-
Checks if the tool provided is a Pydantic object.
|
| 64 |
-
"""
|
| 65 |
-
return isinstance(obj, BaseModel)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def RawJSONDecoder(index):
|
| 69 |
-
class _RawJSONDecoder(json.JSONDecoder):
|
| 70 |
-
end = None
|
| 71 |
-
|
| 72 |
-
def decode(self, s, *_):
|
| 73 |
-
data, self.__class__.end = self.raw_decode(s, index)
|
| 74 |
-
return data
|
| 75 |
-
|
| 76 |
-
return _RawJSONDecoder
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def extract_json(s, index=0):
|
| 80 |
-
while (index := s.find("{", index)) != -1:
|
| 81 |
-
try:
|
| 82 |
-
yield json.loads(s, cls=(decoder := RawJSONDecoder(index)))
|
| 83 |
-
index = decoder.end
|
| 84 |
-
except json.JSONDecodeError:
|
| 85 |
-
index += 1
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def parse_json_garbage(s: str) -> Any:
|
| 89 |
-
# Find the first occurrence of a JSON opening brace or bracket
|
| 90 |
-
candidates = list(extract_json(s))
|
| 91 |
-
if len(candidates) >= 1:
|
| 92 |
-
return candidates[0]
|
| 93 |
-
|
| 94 |
-
raise ValueError("Not a valid JSON string")
|
| 95 |
-
|
| 96 |
-
|
| 97 |
def extract_think(content):
|
| 98 |
# Added by Cursor 20250726 jmd
|
| 99 |
# Extract content within <think>...</think>
|
|
@@ -162,7 +117,7 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 162 |
|
| 163 |
Tool calling:
|
| 164 |
```
|
| 165 |
-
from
|
| 166 |
|
| 167 |
class GetWeather(BaseModel):
|
| 168 |
'''Get the current weather in a given location'''
|
|
@@ -188,78 +143,25 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 188 |
""" # noqa: E501
|
| 189 |
|
| 190 |
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
|
| 191 |
-
# Suffix to add to the system prompt that is not templated 20250717 jmd
|
| 192 |
-
system_message_suffix: str = ""
|
| 193 |
-
|
| 194 |
-
override_bind_tools: bool = True
|
| 195 |
|
| 196 |
def __init__(self, **kwargs: Any) -> None:
|
| 197 |
-
override_bind_tools = True
|
| 198 |
-
if "override_bind_tools" in kwargs:
|
| 199 |
-
override_bind_tools = kwargs["override_bind_tools"]
|
| 200 |
-
del kwargs["override_bind_tools"]
|
| 201 |
super().__init__(**kwargs)
|
| 202 |
-
self.override_bind_tools = override_bind_tools
|
| 203 |
-
|
| 204 |
-
def bind_tools(
|
| 205 |
-
self,
|
| 206 |
-
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
| 207 |
-
**kwargs: Any,
|
| 208 |
-
) -> Runnable[LanguageModelInput, BaseMessage]:
|
| 209 |
-
if self.override_bind_tools:
|
| 210 |
-
return self.bind(functions=tools, **kwargs)
|
| 211 |
-
else:
|
| 212 |
-
return super().bind_tools(tools, **kwargs)
|
| 213 |
|
| 214 |
def _generate_system_message_and_functions(
|
| 215 |
self,
|
| 216 |
kwargs: Dict[str, Any],
|
| 217 |
) -> Tuple[BaseMessage, List]:
|
| 218 |
-
functions = kwargs.get("tools",
|
| 219 |
-
functions = [
|
| 220 |
-
(
|
| 221 |
-
fn["function"]
|
| 222 |
-
if (
|
| 223 |
-
not _is_pydantic_class(fn)
|
| 224 |
-
and not _is_pydantic_object(fn)
|
| 225 |
-
and "name" not in fn.keys()
|
| 226 |
-
and "function" in fn.keys()
|
| 227 |
-
and "name" in fn["function"].keys()
|
| 228 |
-
)
|
| 229 |
-
else fn
|
| 230 |
-
)
|
| 231 |
-
for fn in functions
|
| 232 |
-
]
|
| 233 |
-
|
| 234 |
-
# langchain_openai/chat_models/base.py:
|
| 235 |
-
# NOTE: Using bind_tools is recommended instead, as the `functions` and
|
| 236 |
-
# `function_call` request parameters are officially marked as
|
| 237 |
-
# deprecated by OpenAI.
|
| 238 |
-
|
| 239 |
-
# if "functions" in kwargs:
|
| 240 |
-
# del kwargs["functions"]
|
| 241 |
-
# if "function_call" in kwargs:
|
| 242 |
-
# functions = [
|
| 243 |
-
# fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
|
| 244 |
-
# ]
|
| 245 |
-
# if not functions:
|
| 246 |
-
# raise ValueError(
|
| 247 |
-
# "If `function_call` is specified, you must also pass a "
|
| 248 |
-
# "matching function in `functions`."
|
| 249 |
-
# )
|
| 250 |
-
# del kwargs["function_call"]
|
| 251 |
|
|
|
|
| 252 |
functions = [convert_to_openai_tool(fn) for fn in functions]
|
|
|
|
| 253 |
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
| 254 |
self.tool_system_prompt_template
|
| 255 |
)
|
| 256 |
system_message = system_message_prompt_template.format(
|
| 257 |
tools=json.dumps(functions, indent=2)
|
| 258 |
)
|
| 259 |
-
# Add extra context after the formatted system message 20250717 jmd
|
| 260 |
-
system_message = SystemMessage(
|
| 261 |
-
system_message.content + self.system_message_suffix
|
| 262 |
-
)
|
| 263 |
return system_message, functions
|
| 264 |
|
| 265 |
def _process_response(
|
|
@@ -275,16 +177,8 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 275 |
try:
|
| 276 |
parsed_json_result = json.loads(post_think)
|
| 277 |
except json.JSONDecodeError:
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
print(post_think)
|
| 281 |
-
parsed_json_result = parse_json_garbage(post_think)
|
| 282 |
-
except Exception:
|
| 283 |
-
# Return entire response if JSON is missing or wasn't parsed
|
| 284 |
-
return AIMessage(content=response_message.content)
|
| 285 |
-
|
| 286 |
-
print("parsed_json_result")
|
| 287 |
-
print(parsed_json_result)
|
| 288 |
|
| 289 |
# Get tool name from output
|
| 290 |
called_tool_name = (
|
|
@@ -299,7 +193,7 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 299 |
)
|
| 300 |
if called_tool is None:
|
| 301 |
# Issue a warning and return the generated content 20250727 jmd
|
| 302 |
-
warnings.warn(f"Called tool ({
|
| 303 |
return AIMessage(content=response_message.content)
|
| 304 |
|
| 305 |
# Get tool arguments from output
|
|
@@ -314,7 +208,7 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 314 |
)
|
| 315 |
|
| 316 |
# Put together response message
|
| 317 |
-
|
| 318 |
content=f"<think>\n{think_text}\n</think>",
|
| 319 |
tool_calls=[
|
| 320 |
ToolCall(
|
|
@@ -325,7 +219,7 @@ class ToolCallingLLM(BaseChatModel, ABC):
|
|
| 325 |
],
|
| 326 |
)
|
| 327 |
|
| 328 |
-
return
|
| 329 |
|
| 330 |
def _generate(
|
| 331 |
self,
|
|
|
|
| 49 |
""" # noqa: E501
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def extract_think(content):
|
| 53 |
# Added by Cursor 20250726 jmd
|
| 54 |
# Extract content within <think>...</think>
|
|
|
|
| 117 |
|
| 118 |
Tool calling:
|
| 119 |
```
|
| 120 |
+
from pydantic import BaseModel, Field
|
| 121 |
|
| 122 |
class GetWeather(BaseModel):
|
| 123 |
'''Get the current weather in a given location'''
|
|
|
|
| 143 |
""" # noqa: E501
|
| 144 |
|
| 145 |
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
def __init__(self, **kwargs: Any) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def _generate_system_message_and_functions(
|
| 151 |
self,
|
| 152 |
kwargs: Dict[str, Any],
|
| 153 |
) -> Tuple[BaseMessage, List]:
|
| 154 |
+
functions = kwargs.get("tools", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
# Convert functions to OpenAI tool schema
|
| 157 |
functions = [convert_to_openai_tool(fn) for fn in functions]
|
| 158 |
+
# Create system message with tool descriptions
|
| 159 |
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
| 160 |
self.tool_system_prompt_template
|
| 161 |
)
|
| 162 |
system_message = system_message_prompt_template.format(
|
| 163 |
tools=json.dumps(functions, indent=2)
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
return system_message, functions
|
| 166 |
|
| 167 |
def _process_response(
|
|
|
|
| 177 |
try:
|
| 178 |
parsed_json_result = json.loads(post_think)
|
| 179 |
except json.JSONDecodeError:
|
| 180 |
+
# Return entire response if JSON wasn't parsed (or is missing)
|
| 181 |
+
return AIMessage(content=response_message.content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# Get tool name from output
|
| 184 |
called_tool_name = (
|
|
|
|
| 193 |
)
|
| 194 |
if called_tool is None:
|
| 195 |
# Issue a warning and return the generated content 20250727 jmd
|
| 196 |
+
warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
|
| 197 |
return AIMessage(content=response_message.content)
|
| 198 |
|
| 199 |
# Get tool arguments from output
|
|
|
|
| 208 |
)
|
| 209 |
|
| 210 |
# Put together response message
|
| 211 |
+
response_message = AIMessage(
|
| 212 |
content=f"<think>\n{think_text}\n</think>",
|
| 213 |
tool_calls=[
|
| 214 |
ToolCall(
|
|
|
|
| 219 |
],
|
| 220 |
)
|
| 221 |
|
| 222 |
+
return response_message
|
| 223 |
|
| 224 |
def _generate(
|
| 225 |
self,
|