File size: 20,386 Bytes
a91cc9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
from agent import Agent
import os
import json
from litellm import completion
from tools.web_search import GoogleClaimSearch ## NOTE: (optional) custom tool for web search
from tools.address_locator import GoogleGeocodeValidate ## NOTE: (optional) custom tool for address validation
from typing import List, Dict
import re
from utils import parse_output
import time
from pydantic import BaseModel, Field
import logging
import litellm
from litellm import completion, completion_cost
# Module-wide litellm switch set at import time.
# NOTE(review): presumably makes litellm drop request kwargs a provider does not
# support (e.g. the `reasoning_effort` / `extra_body` passed by the agents below)
# instead of raising — confirm against the litellm docs.
litellm.drop_params = True

class VerificationResult(BaseModel):
    # Structured verdict requested from the model: passed as `response_format`
    # to `completion` in WebSearchAgent._verify and parsed back from its JSON text.
    # Keep the field order stable: code in this file may unpack
    # `model_dump().values()` positionally as (verify, evidence, result).
    # (Comments rather than a docstring, so the model's JSON schema is unchanged.)
    verify: str
    evidence: List[str]  # list of evidence strings
    result: str


class Entities(BaseModel):
    # One extracted item: an entity and the claim made about it. WebSearchAgent.chat
    # reads these same two keys ("entity", "claim") from each dict it verifies.
    entity: str
    claim: str
    
class EntitiesList(BaseModel):
    # Top-level structured output of EntityExtractor (used as `response_format`);
    # the list of extracted items lives under the key "verify".
    verify: list[Entities]


class SimpleAgent(Agent):
    """Plain chat agent backed by litellm ``completion``.

    Conversation turns are recorded through the base-class ``store_chat``.
    With ``keep_history=False`` the history is cut back to its first entry
    after every turn, so each prompt is answered without earlier context.
    """

    def __init__(self,
                 model,
                 name,
                 description,
                 keep_history=True,
                 is_local=False,
                 **gen_kwargs):
        """Create the agent.

        Args:
            model: litellm model identifier. When ``is_local`` is true it is
                prefixed with ``"openai/"``.
            name: forwarded to ``Agent.__init__``.
            description: forwarded to ``Agent.__init__``.
            keep_history: if false, history is reset to its first entry after
                every turn.
            is_local: send requests to ``http://localhost:8000/v1`` (with a
                fixed ``thinking_budget`` of 512) and skip cost accounting.
            **gen_kwargs: extra keyword arguments forwarded to ``completion``.
                For local models ``base_url`` and ``extra_body`` are overwritten.
        """
        super().__init__(model, name, description)
        self.keep_history = keep_history
        self.is_local = is_local
        self.gen_kwargs = gen_kwargs
        if self.is_local:
            self.model = "openai/" + self.model
            self.gen_kwargs['base_url'] = "http://localhost:8000/v1"
            self.gen_kwargs['extra_body'] = {
                "thinking_budget": 512
            }

    def chat(self, prompt) -> tuple[str, str]:
        """Send ``prompt`` and return ``(text, reasoning)``.

        The user turn is recorded once; every retry re-sends the same history.
        ``reasoning`` is the response's ``reasoning_content`` when present and
        non-empty, otherwise ``None``.

        Raises:
            RuntimeError: after ``self.max_retry`` failed attempts; the last
                underlying error is chained as ``__cause__``.
        """
        self.store_chat("user", prompt)
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                res = completion(
                    model=self.model,
                    messages=self.history,
                    **self.gen_kwargs
                )
                res_json = res.choices[0].message.model_dump()
                # A missing (None) content raises here and is retried like any other failure.
                text = res_json['content'].strip()
                if not self.is_local:
                    # Cost is only accumulated for hosted models.
                    self.cost += completion_cost(completion_response=res)

                reasoning = res_json.get('reasoning_content') or None
                self.store_chat("assistant", text, reasoning)
                if not self.keep_history:
                    self.history = [self.history[0]]

                return text, reasoning
            except Exception as e:
                last_error = e
                logging.exception(f"Retrying ({attempt}/{self.max_retry}) … {e}")

        # Without kept history, drop the unanswered user turn so the next call
        # does not send two consecutive user messages.
        if not self.keep_history:
            self.history = [self.history[0]]
        raise RuntimeError("Model failed after max_retries") from last_error



class WebSearchAgent(Agent):
    """Claim-verification agent using tool-augmented LLM calls.

    For each claim the model may call one of two tools:
    ``google_claim_search`` (web search) or ``google_geocode_validate``
    (address validation). When a tool is called, a second structured call
    (:meth:`_verify`) turns the conversation into a ``VerificationResult``.
    LLM cost accumulates in ``self.cost``; estimated Google API cost in
    ``self.tool_cost``.
    """

    # Google Custom Search: free queries, then USD per query.
    _SEARCH_FREE_QUERIES = 100
    _SEARCH_PRICE = 5.0 / 1000
    # Geocoding: (inclusive upper bound on cumulative request count, USD per request).
    _GEOCODE_PRICE_TIERS = (
        (10_000, 0.0),
        (100_000, 5.0 / 1000),
        (500_000, 4.0 / 1000),
        (1_000_000, 3.0 / 1000),
        (5_000_000, 1.5 / 1000),
    )
    # USD per geocoding request beyond the last tier.
    _GEOCODE_OVERFLOW_PRICE = 0.38 / 1000

    def __init__(self, model, name, description, reasoning_effort="disable"):
        """Build the agent and its tools.

        Reads ``CUSTOM_SEARCH_API_KEY``, ``GOOGLE_CX_ID`` and
        ``GOOGLE_MAP_API_KEY`` from the environment (``KeyError`` if unset).
        ``reasoning_effort`` is forwarded to every ``completion`` call.
        """
        super().__init__(model, name, description)
        self.reasoning_effort = reasoning_effort
        self.tools_map = {
            "google_claim_search": GoogleClaimSearch(
                api_key=os.environ["CUSTOM_SEARCH_API_KEY"],
                cx=os.environ["GOOGLE_CX_ID"]
            ),
            "google_geocode_validate": GoogleGeocodeValidate(
                api_key=os.environ["GOOGLE_MAP_API_KEY"]
            ),
        }

        # Tool schemas advertised to the model on every tool-enabled call.
        self.tools = [tool.get_info() for tool in self.tools_map.values()]
        self.max_retry = 3
        self.tool_cost = 0.0
        self.tool_calls_count = {
            "google_claim_search": 0,
            "google_geocode_validate": 0
        }

    def chat(self, entity_dicts: list, std_date: str):  # only works for google search tool
        """Verify every ``{"entity", "claim"}`` dict against evidence up to ``std_date``.

        Each entity is handled in a fresh conversation (first history message
        plus one user turn). If any entity fails, the whole batch is retried,
        up to ``self.max_retry`` times.

        Returns:
            A JSON string: a list with one object per input dict holding
            ``entity``, ``claim``, ``search_result`` (tool name and a summary of
            its output), ``verification``, ``evidence`` and ``result``.

        Raises:
            RuntimeError: when every attempt fails; the last error is chained.
        """
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                json_results = []
                for entity_dict in entity_dicts:
                    # Explicit check rather than `assert`, so it also runs under `python -O`.
                    if not entity_dict:
                        raise ValueError("Entity dictionary is empty.")
                    prompt = f"Claim: {entity_dict['claim']}\nEntity: {entity_dict['entity']}\nCutoff date: {std_date}"
                    stored_messages = [self.history[0], {"role": "user", "content": prompt}]

                    stored_messages, result = self._tool_call(stored_messages)
                    if isinstance(result, str):
                        # Direct answer: accept it only if it embeds a JSON verdict,
                        # and parse just the extracted {...} span (it may be fenced
                        # or surrounded by prose).
                        extracted = self._extract_dict_from_string(result)
                        if not extracted:
                            raise ValueError(f"Unexpected result format: {result}")
                        result_dict = json.loads(extracted, strict=False)
                        v = result_dict.get('verify')
                        e = result_dict.get('evidence')
                        r = result_dict.get('result')
                    elif isinstance(result, list):  # tool call result
                        if not result:
                            raise ValueError("Empty result list returned from tool call.")
                        v, e, r = self._verification_fields(self._verify(stored_messages))
                    else:
                        raise ValueError(f"Unexpected result type and value: {type(result)}. Expected str or list.")

                    if v is None or e is None or r is None:
                        raise ValueError(f"Invalid response format from model: {result}")

                    json_results.append({
                        "entity": entity_dict["entity"],
                        "claim": entity_dict['claim'],
                        "search_result": self._summarize_tool_result(stored_messages, result),
                        "verification": v,
                        "evidence": e,
                        "result": r
                    })

                return json.dumps(json_results)
            except Exception as err:
                last_error = err
                logging.exception(f"Retrying ({attempt}/{self.max_retry}) ... {err}")

        raise RuntimeError("Model failed after max_retries") from last_error

    def _tool_call(self, stored_messages: list):
        """Run one tool-enabled completion and execute the first tool call, if any.

        ``stored_messages`` is extended in place and also returned.

        Returns:
            ``(stored_messages, parsed_tool_output)`` when the model called a
            tool (the assistant message and the ``tool`` reply are appended), or
            ``(stored_messages, text)`` when it answered with plain text (the
            assistant message is appended).

        Raises:
            ValueError: if the response has neither a tool call nor text.
        """
        res = completion(
            model=self.model,
            messages=stored_messages,  # system message and the current user input
            tool_choice="auto",
            tools=self.tools,
            reasoning_effort=self.reasoning_effort,
        )
        self.cost += completion_cost(completion_response=res)

        message = res.choices[0].message.model_dump()
        tool_calls = message.get('tool_calls')
        if tool_calls:
            # Only the first tool call is executed; any others are ignored.
            tool_call = tool_calls[0]
            tool = tool_call["function"]["name"]
            tool_args = json.loads(tool_call["function"]["arguments"])

            result = self.tools_map[tool].invoke(**tool_args)
            parsed = json.loads(result)
            self._record_tool_usage(tool, parsed)
            time.sleep(0.5)  # to avoid rate limit issues
            stored_messages.extend([
                message,
                {
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "name": tool,
                    "content": result,
                },
            ])
            return stored_messages, parsed

        # `content` can be None when there is no tool call; treat that as empty.
        text = (message.get("content") or "").strip()
        if text:
            stored_messages.append({"role": "assistant", "content": text})
            return stored_messages, text
        raise ValueError("No tool call found in the response from the model.")

    def _record_tool_usage(self, tool: str, parsed) -> None:
        """Count billable requests for ``tool`` and add their estimated price to ``self.tool_cost``."""
        if tool == "google_claim_search":
            for item in parsed:
                # Items whose text starts with "Search failure" are neither counted nor billed.
                text_block = item['text_block']
                if not isinstance(text_block, str) or not text_block.startswith("Search failure"):
                    self.tool_calls_count[tool] += 1
                    self.tool_cost += self._tool_call_pricing(tool)
        elif tool == "google_geocode_validate":
            # A response whose values are all falsy is treated as an error and not billed.
            if any(parsed[0].values()):
                self.tool_calls_count[tool] += 1
                self.tool_cost += self._tool_call_pricing(tool)

    @staticmethod
    def _summarize_tool_result(stored_messages: list, result) -> dict:
        """Describe which tool produced ``result`` and condense its output.

        The tool name is read from the last stored message (the ``tool`` reply
        appended by ``_tool_call``); a direct model answer carries no name and is
        reported as tool ``"none"``.
        """
        tool_name = stored_messages[-1].get('name', '')
        if tool_name == "google_claim_search":
            # Keep only title and link of each non-empty search hit.
            return {
                "tool": tool_name,
                "search_results": [
                    {"title": sr["title"], "link": sr["link"]} for sr in result if sr
                ],
            }
        if tool_name == "google_geocode_validate":
            return {"tool": tool_name, "search_results": result}
        # possibly an immediate response from the model
        return {"tool": "none", "search_results": ["none"]}

    @staticmethod
    def _verification_fields(text: str) -> tuple:
        """Parse a ``VerificationResult`` JSON string into ``(verify, evidence, result)``."""
        verdict = VerificationResult.model_validate_json(text)
        return verdict.verify, verdict.evidence, verdict.result

    def _extract_dict_from_string(self, input_string):
        """Return the span from the first ``{`` to the last ``}`` of ``input_string``, or ``None``."""
        start_index = input_string.find('{')
        end_index = input_string.rfind('}')

        if start_index != -1 and end_index != -1 and start_index < end_index:
            return input_string[start_index:end_index + 1]
        return None

    def claim_search(self, claim: str):  ### benchmark evaluation
        """Verify a single free-form ``claim`` (benchmark evaluation entry point).

        Returns:
            ``{"claim", "verification", "evidence", "result"}``.

        Raises:
            RuntimeError: when every attempt fails; the last error is chained.
        """
        prompt = f"Claim: {claim}"
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                stored_messages = [self.history[0], {"role": "user", "content": prompt}]

                stored_messages, result = self._tool_call(stored_messages)
                if isinstance(result, str):
                    # A direct answer must already use the <verify>/<evidence>/<result> tag format.
                    if '<verify>' not in result:
                        raise ValueError(f"Unexpected result format: {result}")
                    v, e, r = self._parse_result(result)
                elif isinstance(result, list):  # tool call result
                    if not result:
                        raise ValueError("Empty result list returned from tool call.")
                    # _verify returns JSON text, so it is parsed with model_validate_json.
                    v, e, r = self._verification_fields(self._verify(stored_messages))
                else:
                    raise ValueError(f"Unexpected result type: {type(result)}. Expected str or list.")

                if v is None or r is None or e is None:
                    raise ValueError(f"Invalid response format from model: {result}")

                return {
                    "claim": claim,
                    "verification": v,
                    "evidence": e,
                    "result": r
                }
            except Exception as err:
                last_error = err
                logging.error(f"Retrying ({attempt}/{self.max_retry}) … {err}")
        raise RuntimeError("Model failed after max_retries") from last_error

    def _verify(self, messages):
        """Request a structured verdict on the conversation so far.

        Tool use is disabled (``tool_choice="none"``) and the reply is requested
        in the ``VerificationResult`` format; returns the stripped raw JSON text.
        """
        res = completion(
            model=self.model,
            messages=messages,
            reasoning_effort=self.reasoning_effort,
            response_format=VerificationResult,
            tool_choice="none"
        )
        self.cost += completion_cost(completion_response=res)

        res_json = res.choices[0].message.model_dump()
        return res_json["content"].strip()

    def _parse_result(self, text: str) -> tuple:
        """Extract ``(verify, evidence, result)`` from tag-formatted text.

        The three tags must appear in that order, separated only by whitespace;
        leading text before ``<verify>`` is tolerated. Returns
        ``(None, None, None)`` when the pattern is not found.
        """
        matches = re.search(r"<verify>([\s\S]+?)</verify>\s*<evidence>([\s\S]+?)</evidence>\s*<result>([\s\S]+?)</result>", text)
        if not matches:
            return None, None, None
        return tuple(group.strip() for group in matches.groups())

    def _tool_call_pricing(self, tool: str) -> float:
        """Estimated USD price of the latest request to ``tool``.

        Uses the cumulative count in ``self.tool_calls_count``, which the caller
        has already incremented for this request. Counts are per instance and are
        never reset, so period-based free allowances are only approximated.

        Custom Search JSON API provides 100 search queries per day for free;
        additional requests cost $5 per 1000 queries, up to 10k queries per day.

        Raises:
            ValueError: for a tool without pricing.
        """
        if tool == "google_claim_search":
            # NOTE: daily limit is 10k queries!
            count = self.tool_calls_count[tool]
            return 0.0 if count <= self._SEARCH_FREE_QUERIES else self._SEARCH_PRICE
        if tool == "google_geocode_validate":
            count = self.tool_calls_count[tool]
            for ceiling, price in self._GEOCODE_PRICE_TIERS:
                if count <= ceiling:
                    return price
            return self._GEOCODE_OVERFLOW_PRICE
        raise ValueError(f"No pricing defined for tool: {tool}")


class EntityExtractor(Agent):
    """Extract ``{"entity", "claim"}`` pairs from a question/response pair.

    The primary model is asked for an ``EntitiesList`` via ``response_format``.
    If that comes back empty, a fallback model is queried without a schema and
    its reply is parsed leniently by :meth:`parse_entities`.
    """

    # Fallback model for the schema-free retry (also the call_completion default).
    _FALLBACK_MODEL = "gemini/gemini-2.5-flash"
    # Fenced code block at the very start of the text, e.g. ```json ... ```.
    _FENCED_RE = re.compile(r"```\S*\s([\s\S]+?)```\s*")
    # Bare {"verify": [...]} object anywhere in the text.
    _VERIFY_OBJ_RE = re.compile(r'\{\s*"verify"\s*:\s*\[.*?\]\s*\}', re.S)

    def __init__(self, model, name, description, reasoning_effort="disable"):
        """Create the extractor; ``reasoning_effort`` is forwarded to the primary call."""
        super().__init__(model, name, description)
        self.reasoning_effort = reasoning_effort
        self.input_format = "Question: {question}\nResponse: {response}"

    def chat(self, question: str, response: str):
        """Return a JSON string listing the extracted ``{"entity", "claim"}`` dicts.

        Returns ``"[]"`` when neither model finds any entity. Any failure (API
        error, validation or parsing error) restarts the attempt, up to
        ``self.max_retry`` times.

        Raises:
            RuntimeError: when every attempt fails; the last error is chained.
        """
        prompt = self.input_format.format(question=question, response=response)
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                messages = [self.history[0], {"role": "user", "content": prompt}]  # system message + input
                res = self.call_completion(
                    model=self.model,
                    messages=messages,
                    response_format=EntitiesList,
                    reasoning_effort=self.reasoning_effort,
                )
                validated = EntitiesList.model_validate_json(res)
                if validated.verify:
                    return json.dumps(validated.model_dump()['verify'])

                logging.info("⚠️ No entities extracted from the response. Regenerating...")
                # Schema-free fallback: its free-form text is parsed manually.
                resp = self.call_completion(
                    model=self._FALLBACK_MODEL,
                    messages=messages,
                    reasoning_effort="low"
                )
                entities = self.parse_entities(resp)['verify']
                if not entities:
                    logging.info("⚠️ Still no entities extracted- this is likely that the response does not contain any entities.")
                    return json.dumps([])
                return json.dumps(entities)
            except Exception as e:
                last_error = e
                logging.error(f"Retrying ({attempt}/{self.max_retry}) … {e}")
        raise RuntimeError("Model failed after max_retries") from last_error

    def parse_entities(self, text: str):
        """Leniently extract the ``{"verify": [...]}`` payload from free-form model text.

        Tries, in order: a fenced code block at the very start of ``text``, then
        a bare ``{"verify": [...]}`` object anywhere in it. The candidate is
        decoded by :meth:`_load_dict_like`. Returns ``{"verify": []}`` when
        neither pattern matches.

        NOTE: a fenced block whose content is not a ``{...}`` literal is
        returned as the raw string, so indexing ``['verify']`` on it fails.
        """
        fenced = self._FENCED_RE.match(text)
        if fenced:
            candidate = fenced.group(1).strip()
        else:
            bare = self._VERIFY_OBJ_RE.search(text)
            if not bare:
                logging.warning(f"No valid entities found in the response.\nResponse: {text}")
                return {"verify": []}
            candidate = bare.group(0).strip()
        return self._load_dict_like(candidate)

    @staticmethod
    def _load_dict_like(candidate: str):
        """Decode a ``{...}`` string as JSON, falling back to a Python literal.

        Strings not wrapped in braces are returned unchanged. The fallback uses
        ``ast.literal_eval`` instead of ``eval`` because the text comes from a
        model and must never be executed as code.
        """
        import ast

        if not (candidate.startswith("{") and candidate.endswith("}")):
            return candidate
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            # e.g. single-quoted keys; raises ValueError/SyntaxError for non-literals.
            return ast.literal_eval(candidate)

    def call_completion(self, model=_FALLBACK_MODEL,
                        messages=None,
                        **kwargs):
        """Call litellm ``completion`` with retries and return the stripped message text.

        Each successful response's cost is added to ``self.cost``.

        Raises:
            RuntimeError: after ``self.max_retry`` failed attempts; the last
                error is chained as ``__cause__``.
        """
        last_error = None
        for attempt in range(1, self.max_retry + 1):
            try:
                resp = completion(
                    model=model,
                    messages=messages,
                    **kwargs
                )
                self.cost += completion_cost(completion_response=resp)
                return resp.choices[0].message.content.strip()
            except Exception as e:
                last_error = e
                logging.warning(f"Retrying ({attempt}/{self.max_retry}) … {e}")
        raise RuntimeError("Model failed after max_retries") from last_error