Spaces:
Sleeping
Sleeping
| from agent import Agent | |
| import os | |
| import json | |
| from litellm import completion | |
| from tools.web_search import GoogleClaimSearch ## NOTE: (optional) custom tool for web search | |
| from tools.address_locator import GoogleGeocodeValidate ## NOTE: (optional) custom tool for address validation | |
| from typing import List, Dict | |
| import re | |
| from utils import parse_output | |
| import time | |
| from pydantic import BaseModel, Field | |
| import logging | |
| import litellm | |
| from litellm import completion, completion_cost | |
# Module-wide LiteLLM switch, set once at import time.
# NOTE(review): per the LiteLLM docs this makes LiteLLM drop request
# parameters the target provider does not support instead of raising --
# confirm against the pinned LiteLLM version.
litellm.drop_params = True
# Structured verdict for one claim; used in this file as the
# `response_format` of the verification completion and to validate its JSON.
# Documented with `#` comments on purpose: pydantic copies a class docstring
# into the generated JSON schema, which would change what is sent to the model.
# Keep the field order stable -- code may unpack `model_dump().values()`
# positionally as (verify, evidence, result).
class VerificationResult(BaseModel):
    # What was checked (free text produced by the model).
    verify: str
    # Evidence items supporting the verdict.
    evidence: List[str]
    # Final verdict text. NOTE(review): allowed values are not visible here.
    result: str
# One extracted item to verify: an entity plus the claim made about it.
# Documented with `#` comments rather than a docstring so the JSON schema
# derived from this model (it is nested in a `response_format`) is unchanged.
class Entities(BaseModel):
    # The entity the claim is about.
    entity: str
    # The claim text associated with that entity.
    claim: str
# Top-level container for extracted entity/claim pairs; used in this file as
# the `response_format` of the extraction call and to validate its JSON.
# Documented with `#` comments rather than a docstring so the generated JSON
# schema is unchanged.
class EntitiesList(BaseModel):
    # Entity/claim pairs to verify; an empty list means nothing was extracted.
    verify: list[Entities]
class SimpleAgent(Agent):
    """Plain chat agent that sends its message history to a LiteLLM model.

    `history`, `store_chat`, `max_retry` and `cost` are inherited from
    `Agent` (defined outside this file); this class assumes `history[0]` is
    the system message.
    """

    def __init__(self,
                 model,
                 name,
                 description,
                 keep_history=True,
                 is_local=False,
                 **gen_kwargs):
        """Configure the agent.

        Args:
            model: LiteLLM model identifier.
            name: Agent name, forwarded to `Agent.__init__`.
            description: Agent description, forwarded to `Agent.__init__`.
            keep_history: When False the history is trimmed back to its first
                message after every `chat` call.
            is_local: When True the model is addressed through the
                OpenAI-compatible route at http://localhost:8000/v1 and no
                cost is tracked.
            **gen_kwargs: Extra keyword arguments passed to
                `litellm.completion` on every call.
        """
        super().__init__(model, name, description)
        self.keep_history = keep_history
        self.is_local = is_local
        self.gen_kwargs = gen_kwargs
        if self.is_local:
            # Local server: route through LiteLLM's "openai/" provider prefix.
            self.model = "openai/" + self.model
            self.gen_kwargs['base_url'] = "http://localhost:8000/v1"
            # NOTE(review): the flattened source does not show whether this
            # assignment was nested under `is_local`; it is kept local-only
            # because the key is server specific -- confirm.
            self.gen_kwargs['extra_body'] = {
                "thinking_budget": 512
            }

    def chat(self, prompt) -> "tuple[str, str | None]":
        """Send `prompt` and return `(text, reasoning)`.

        `reasoning` is the message's `reasoning_content` when present and
        non-empty, otherwise None.

        Raises:
            RuntimeError: when all `max_retry` attempts fail; the last
                underlying error is attached as `__cause__`.
        """
        self.store_chat("user", prompt)
        last_error = None
        try:
            for attempt in range(1, self.max_retry + 1):
                try:
                    res = completion(
                        model=self.model,
                        messages=self.history,
                        **self.gen_kwargs
                    )
                    res_json = res.choices[0].message.model_dump()
                    content = res_json.get('content')
                    if content is None:
                        # Explicit error instead of AttributeError on None.
                        raise ValueError("Model response has no text content.")
                    text = content.strip()
                    if not self.is_local:
                        self.cost += completion_cost(completion_response=res)
                    reasoning = res_json.get('reasoning_content') or None
                    self.store_chat("assistant", text, reasoning)
                    return text, reasoning
                except Exception as err:
                    last_error = err
                    # One-based attempt counter, matching the other agents.
                    logging.exception(f"Retrying ({attempt}/{self.max_retry}) … {err}")
            raise RuntimeError("Model failed after max_retries") from last_error
        finally:
            # Stateless mode: always fall back to the first message, also
            # when every attempt failed, so the unanswered prompt does not
            # leak into the next call.
            if not self.keep_history:
                self.history = [self.history[0]]
class WebSearchAgent(Agent):
    """Agent that verifies claims with the help of Google tools.

    The model may call `google_claim_search` or `google_geocode_validate`;
    the tool output is then judged by a second, structured completion
    (`VerificationResult`). LLM cost is accumulated in `self.cost`
    (inherited from `Agent`), tool cost in `self.tool_cost`.
    `self.history[0]` is assumed to be the system message.
    """

    def __init__(self, model, name, description, reasoning_effort="disable"):
        """Build the tool registry.

        Raises:
            KeyError: if CUSTOM_SEARCH_API_KEY, GOOGLE_CX_ID or
                GOOGLE_MAP_API_KEY is missing from the environment.
        """
        super().__init__(model, name, description)
        self.reasoning_effort = reasoning_effort
        self.tools_map = {
            "google_claim_search": GoogleClaimSearch(
                api_key=os.environ["CUSTOM_SEARCH_API_KEY"],
                cx=os.environ["GOOGLE_CX_ID"]
            ),
            "google_geocode_validate": GoogleGeocodeValidate(
                api_key=os.environ["GOOGLE_MAP_API_KEY"]
            )
        }
        # Tool descriptions handed to the model on every tool-enabled call.
        self.tools = [tool.get_info() for tool in self.tools_map.values()]
        self.max_retry = 3
        self.tool_cost = 0.0
        self.tool_calls_count = {
            "google_claim_search": 0,
            "google_geocode_validate": 0
        }

    def chat(self, entity_dicts: list, std_date: str):  # only works for google search tool
        """Verify every entity/claim pair and return the results as JSON.

        Each entity is retried on its own (up to `max_retry` times), so one
        failing entity no longer re-runs the tool and LLM calls of entities
        that were already verified.

        Args:
            entity_dicts: Dicts with the keys 'entity' and 'claim'.
            std_date: Cutoff date inserted into the prompt.

        Returns:
            A JSON string: a list with one object per entity containing
            entity, claim, search_result, verification, evidence and result.

        Raises:
            RuntimeError: when an entity still fails after `max_retry` tries.
        """
        json_results = [
            self._check_entity(entity_dict, std_date)
            for entity_dict in entity_dicts
        ]
        return json.dumps(json_results)

    def _check_entity(self, entity_dict: dict, std_date: str) -> dict:
        """Run tool call + verification for one entity, with retries."""
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                # Explicit check: `assert` would vanish under `python -O`.
                if not entity_dict:
                    raise ValueError("Entity dictionary is empty.")
                prompt = f"Claim: {entity_dict['claim']}\nEntity: {entity_dict['entity']}\nCutoff date: {std_date}"
                # Every entity (and every retry) starts from the system message only.
                stored_messages = [self.history[0], {"role": "user", "content": prompt}]
                stored_messages, result = self._tool_call(stored_messages)
                if isinstance(result, str):
                    # The model answered directly; accept it only if it
                    # contains the expected JSON object.
                    extracted = self._extract_dict_from_string(result)
                    if not extracted:
                        raise ValueError(f"Unexpected result format: {result}")
                    # Parse the extracted object, not the whole reply, so
                    # surrounding prose or code fences do not break parsing.
                    result_dict = json.loads(extracted, strict=False)
                    verdict = result_dict.get('verify')
                    evidence = result_dict.get('evidence')
                    outcome = result_dict.get('result')
                elif isinstance(result, list):  # tool call result
                    if not result:
                        raise ValueError("Empty result list returned from tool call.")
                    verdict, evidence, outcome = self._structured_verdict(stored_messages)
                else:
                    raise ValueError(f"Unexpected result type and value: {type(result)}. Expected str or list.")
                if verdict is None or outcome is None or evidence is None:
                    raise ValueError(f"Invalid response format from model: {result}")
                return {
                    "entity": entity_dict["entity"],
                    "claim": entity_dict['claim'],
                    "search_result": self._summarize_search(stored_messages, result),
                    "verification": verdict,
                    "evidence": evidence,
                    "result": outcome
                }
            except Exception as err:
                last_error = err
                logging.exception(f"Retrying ({attempt}/{self.max_retry}) ... {err}")
        raise RuntimeError("Model failed after max_retries") from last_error

    def _structured_verdict(self, stored_messages: list) -> tuple:
        """Ask the model for a verdict and return (verify, evidence, result)."""
        # `_verify` returns a JSON string, hence model_validate_json.
        parsed = VerificationResult.model_validate_json(self._verify(stored_messages))
        # Access by name instead of relying on field declaration order.
        return parsed.verify, parsed.evidence, parsed.result

    def _summarize_search(self, stored_messages: list, result) -> dict:
        """Condense the tool output stored in the last message for reporting."""
        # Only tool messages carry a 'name'; a direct model answer does not.
        tool_name = stored_messages[-1].get('name', '')
        if tool_name == "google_claim_search":
            hits = [{"title": sr["title"], "link": sr["link"]} for sr in result if sr]
            return {"tool": "google_claim_search", "search_results": hits}
        if tool_name == "google_geocode_validate":
            return {"tool": "google_geocode_validate", "search_results": result}
        # possibly an immediate response from the model
        return {"tool": "none", "search_results": ["none"]}

    def _tool_call(self, stored_messages: list):
        """Let the model either call one tool or answer directly.

        Returns:
            `(stored_messages, result)`. After a tool call `result` is the
            JSON-decoded tool output and the assistant + tool messages have
            been appended; otherwise `result` is the model's text and a
            plain assistant message has been appended.

        Raises:
            ValueError: if the reply has neither a tool call nor text.
        """
        res = completion(
            model=self.model,
            messages=stored_messages,  # system message and the current user message
            tool_choice="auto",
            tools=self.tools,
            reasoning_effort=self.reasoning_effort
        )
        self.cost += completion_cost(completion_response=res)
        message = res.choices[0].message.model_dump()
        tool_calls = message.get('tool_calls')
        if tool_calls:
            tool_call = tool_calls[0]
            # Only the first call is executed; keep just that one in the
            # assistant message so every recorded call has a tool response.
            message['tool_calls'] = [tool_call]
            tool = tool_call["function"]["name"]
            arguments = json.loads(tool_call["function"]["arguments"])
            raw_result = self.tools_map[tool].invoke(**arguments)
            # Decode once and reuse for accounting and for the return value.
            parsed_result = json.loads(raw_result)
            self._record_tool_usage(tool, parsed_result)
            time.sleep(0.5)  # to avoid rate limit issues
            stored_messages.extend(
                [
                    message,
                    {
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": tool,
                        "content": raw_result
                    }
                ]
            )
            return stored_messages, parsed_result
        # `content` may be None; treat that like an empty answer.
        text = (message.get("content") or "").strip()
        if text:
            stored_messages.append({"role": "assistant", "content": text})
            return stored_messages, text
        raise ValueError("No tool call found in the response from the model.")

    def _record_tool_usage(self, tool: str, parsed_result) -> None:
        """Update call counters and accumulated tool cost for one tool run."""
        if tool == "google_claim_search":
            for item in parsed_result:
                if not item:
                    continue  # nothing to account for
                block = item['text_block']
                # Entries whose text starts with "Search failure" are not billed.
                if not isinstance(block, str) or not block.startswith("Search failure"):
                    self.tool_calls_count["google_claim_search"] += 1
                    self.tool_cost += self._tool_call_pricing("google_claim_search")
        elif tool == "google_geocode_validate":
            first = parsed_result[0]
            # An entry whose values are all empty/None marks a failed lookup.
            if any(first.values()):
                self.tool_calls_count["google_geocode_validate"] += 1
                self.tool_cost += self._tool_call_pricing("google_geocode_validate")

    def _extract_dict_from_string(self, input_string):
        """Return the text from the first '{' to the last '}', or None."""
        start_index = input_string.find('{')
        end_index = input_string.rfind('}')
        if start_index != -1 and end_index != -1 and start_index < end_index:
            return input_string[start_index:end_index + 1]
        return None

    def claim_search(self, claim: str):  ### benchmark evaluation
        """Verify a single claim and return a dict with the verdict.

        Returns:
            `{"claim", "verification", "evidence", "result"}`.

        Raises:
            RuntimeError: when all `max_retry` attempts fail.
        """
        prompt = f"Claim: {claim}"
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                stored_messages = [self.history[0], {"role": "user", "content": prompt}]
                stored_messages, result = self._tool_call(stored_messages)
                if isinstance(result, str):
                    # Direct answer: only accepted in the tagged output format.
                    if '<verify>' not in result:
                        raise ValueError(f"Unexpected result format: {result}")
                    verdict, evidence, outcome = self._parse_result(result)
                elif isinstance(result, list):  # tool call result
                    if not result:
                        raise ValueError("Empty result list returned from tool call.")
                    verdict, evidence, outcome = self._structured_verdict(stored_messages)
                else:
                    raise ValueError(f"Unexpected result type: {type(result)}. Expected str or list.")
                if verdict is None or outcome is None or evidence is None:
                    raise ValueError(f"Invalid response format from model: {result}")
                return {
                    "claim": claim,
                    "verification": verdict,
                    "evidence": evidence,
                    "result": outcome
                }
            except Exception as err:
                last_error = err
                logging.error(f"Retrying ({attempt}/{self.max_retry}) … {err}")
        raise RuntimeError("Model failed after max_retries") from last_error

    def _verify(self, messages):
        """Request a `VerificationResult`-shaped answer; return its raw text."""
        res = completion(
            model=self.model,
            messages=messages,
            reasoning_effort=self.reasoning_effort,
            response_format=VerificationResult,
            tool_choice="none"
        )
        self.cost += completion_cost(completion_response=res)
        res_json = res.choices[0].message.model_dump()
        content = res_json.get("content")
        if content is None:
            raise ValueError("Verification response has no text content.")
        return content.strip()

    def _parse_result(self, text: str) -> tuple:
        """Extract (verify, evidence, result) from the tagged text format.

        Returns `(None, None, None)` when the three tags are not found in
        order.
        """
        # `search` instead of `match`: the tags may follow leading text.
        matches = re.search(r"<verify>([\s\S]+?)</verify>\s*<evidence>([\s\S]+?)</evidence>\s*<result>([\s\S]+?)</result>", text)
        if not matches:
            return None, None, None
        return tuple(matches.group(i).strip() for i in (1, 2, 3))

    def _tool_call_pricing(self, tool: str) -> float:
        """
        single request pricing for Google Custom Search JSON API and Google Geocoding API.
        Custom Search JSON API provides 100 search queries per day for free. If you need more, you may sign up for billing in the API Console.
        Additional requests cost $5 per 1000 queries, up to 10k queries per day.

        The price is chosen from the running per-instance call counter (the
        counter is never reset, so the free tier is applied once per agent
        instance, not per day). Geocoding tiers are the ones encoded below.
        Unknown tools cost 0.0.
        """
        if tool == "google_claim_search":
            if self.tool_calls_count["google_claim_search"] <= 100:
                return 0.0
            return 5.0 / 1000  # NOTE: daily limit is 10k queries!
        if tool == "google_geocode_validate":
            calls = self.tool_calls_count["google_geocode_validate"]
            # (inclusive upper bound of the tier, price per request)
            tiers = (
                (10000, 0.0),
                (100000, 5.0 / 1000),
                (500000, 4.0 / 1000),
                (1000000, 3.0 / 1000),
                (5000000, 1.5 / 1000),
            )
            for upper_bound, price in tiers:
                if calls <= upper_bound:
                    return price
            return 0.38 / 1000
        # Keeps `self.tool_cost += ...` safe for tools without a price list.
        return 0.0
class EntityExtractor(Agent):
    """Agent that extracts entity/claim pairs from a question/response pair.

    `history`, `max_retry` and `cost` are inherited from `Agent` (defined
    outside this file); `history[0]` is assumed to be the system message.
    """

    def __init__(self, model, name, description, reasoning_effort="disable",
                 fallback_model="gemini/gemini-2.5-flash"):
        """Configure the extractor.

        Args:
            model: LiteLLM model used for the structured extraction.
            name: Agent name, forwarded to `Agent.__init__`.
            description: Agent description, forwarded to `Agent.__init__`.
            reasoning_effort: Passed to the structured extraction call.
            fallback_model: Model used for the free-form second attempt when
                the structured call yields no entities.
        """
        super().__init__(model, name, description)
        self.reasoning_effort = reasoning_effort
        self.fallback_model = fallback_model
        self.input_format = "Question: {question}\nResponse: {response}"

    def chat(self, question: str, response: str):
        """Return the extracted entity/claim pairs as a JSON list string.

        Returns "[]" when neither the structured call nor the fallback call
        finds any entity.

        Raises:
            RuntimeError: when all `max_retry` attempts fail.
        """
        prompt = self.input_format.format(question=question, response=response)
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                messages = [self.history[0], {"role": "user", "content": prompt}]  # system message + input
                return json.dumps(self._extract_once(messages))
            except Exception as err:
                last_error = err
                logging.error(f"Retrying ({attempt}/{self.max_retry}) … {err}")
        raise RuntimeError("Model failed after max_retries") from last_error

    def _extract_once(self, messages: list) -> list:
        """One extraction attempt: structured call, then free-form fallback."""
        res = self.call_completion(
            model=self.model,
            messages=messages,
            response_format=EntitiesList,
            reasoning_effort=self.reasoning_effort,
        )
        try:
            validated = EntitiesList.model_validate_json(res)
            if validated.verify:
                return validated.model_dump()['verify']
            logging.info("⚠️ No entities extracted from the response. Regenerating...")
            # Second try without a response_format; the reply is parsed manually.
            resp = self.call_completion(
                model=self.fallback_model,
                messages=messages,
                reasoning_effort="low"
            )
            entities = self.parse_entities(resp)['verify']
            if not entities:
                logging.info("⚠️ Still no entities extracted- this is likely that the response does not contain any entities.")
                return []
            return entities
        except Exception as err:
            raise ValueError(f"Error validating response: {err}") from err

    def parse_entities(self, text: str):
        """
        Parse the entity list from a free-form model reply.

        Looks for a fenced code block first, then for a bare
        `{"verify": [...]}` object.

        Returns:
            The decoded dict (expected shape `{"verify": [...]}`), or
            `{"verify": []}` when the text contains neither form.

        Raises:
            ValueError / SyntaxError: when a candidate payload is found but
                cannot be decoded into a dict.
        """
        # `search` rather than `match`: the fence may follow introductory text.
        fenced = re.search(r"```\S*\s([\s\S]+?)```\s*", text)
        if fenced:
            payload = fenced.group(1).strip()
        else:
            bare = re.search(r'\{\s*"verify"\s*:\s*\[.*?\]\s*\}', text, re.S)
            if not bare:
                logging.warning("No valid entities found in the response.")
                logging.warning(f"Response: {text}")
                return {"verify": []}
            payload = bare.group(0).strip()
        return self._decode_entities(payload)

    @staticmethod
    def _decode_entities(payload: str) -> dict:
        """Decode a `{...}` payload written as JSON or as a Python dict literal."""
        import ast  # local import: only needed for the literal fallback

        if not (payload.startswith("{") and payload.endswith("}")):
            raise ValueError(f"Entity payload is not an object: {payload}")
        try:
            entities = json.loads(payload)
        except json.JSONDecodeError:
            # SECURITY: this fallback used to be eval() on raw model output,
            # which can execute arbitrary code. literal_eval accepts only
            # Python literals (e.g. single-quoted dicts) and raises otherwise.
            entities = ast.literal_eval(payload)
        if not isinstance(entities, dict):
            raise ValueError(f"Entity payload did not decode to a dict: {payload}")
        return entities

    def call_completion(self, model="gemini/gemini-2.5-flash",
                        messages=None,
                        **kwargs):
        """
        Calls the completion API with the given parameters.

        Adds the call's cost to `self.cost` and returns the stripped text of
        the first choice.

        Raises:
            RuntimeError: when all `max_retry` attempts fail (instead of
                silently returning None).
        """
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                resp = completion(
                    model=model,
                    messages=messages,
                    **kwargs
                )
                self.cost += completion_cost(completion_response=resp)
                return resp.choices[0].message.content.strip()
            except Exception as err:
                last_error = err
                logging.error(f"Retrying ({attempt}/{self.max_retry}) … {err}")
        raise RuntimeError("Completion failed after max_retries") from last_error