kardwalker commited on
Commit
99e67ae
·
verified ·
1 Parent(s): 94aae9f

Create Search_Agent.py

Browse files
Files changed (1) hide show
  1. Search_Agent.py +668 -0
Search_Agent.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt - Updated compatible versions
2
+ """
3
+ langgraph>=0.2.0
4
+ langchain>=0.2.0
5
+ langchain-openai>=0.1.0
6
+ langchain-community>=0.2.0
7
+ sentence-transformers>=2.2.2
8
+ faiss-cpu>=1.7.4
9
+ googlesearch-python>=1.2.3
10
+ duckduckgo-search>=6.1.0
11
+ aiohttp>=3.9.1
12
+ beautifulsoup4>=4.12.2
13
+ redis>=5.0.1
14
+ numpy>=1.24.3
15
+ scikit-learn>=1.3.0
16
+ openai>=1.0.0
17
+ """
18
+ # 293
19
+
20
+ import asyncio
21
+ import json
22
+ import time
23
+ from datetime import datetime, timedelta
24
+ from typing import Dict, List, Optional, TypedDict, Annotated, Any, Tuple
25
+ from enum import Enum
26
+ import hashlib
27
+ import logging
28
+ import re
29
+ import urllib.parse
30
+
31
+ import numpy as np
32
+ from sentence_transformers import SentenceTransformer
33
+ import faiss
34
+ from sklearn.metrics.pairwise import cosine_similarity
35
+
36
+ from langchain_openai import AzureChatOpenAI
37
+ from langchain.schema import Document
38
+ from langgraph.graph import StateGraph, END
39
+ from langchain.tools import Tool
40
+
41
+ import aiohttp
42
+ from bs4 import BeautifulSoup
43
+ import redis
44
+ from googlesearch import search as google_search
45
+ from duckduckgo_search import DDGS # Corrected import
46
+ from dotenv import load_dotenv
47
+ import os
48
+ load_dotenv()
49
+ # Configure logging
50
+ logging.basicConfig(level=logging.INFO)
51
+ logger = logging.getLogger(__name__)
52
+
53
+ # Initialize models
54
+
55
+ llm = AzureChatOpenAI(
56
+ api_key=os.getenv("AZURE_API_KEY"),
57
+ azure_endpoint=os.getenv("Azure_endpoint"),
58
+ api_version="2024-12-01-preview",
59
+ model="gpt-4o-mini",
60
+ streaming=True,
61
+ temperature=0.8,
62
+ max_tokens=512,
63
+ azure_deployment="gpt-4o-mini", # Ensure this matches your deployment name
64
+ )
65
+
66
+ embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
67
+
68
+ # Initialize Redis for caching (optional)
69
+ try:
70
+ redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
71
+ except:
72
+ redis_client = None
73
+ logger.warning("Redis not available, caching disabled")
74
+
75
+ # Constants
76
+ CACHE_TTL = 3600 # 1 hour
77
+ MAX_RESULTS_PER_SOURCE = 10
78
+ RATE_LIMIT_DELAY = 0.5
79
+
80
+ class QueryIntent(Enum):
81
+ FACTUAL = "factual"
82
+ NAVIGATIONAL = "navigational"
83
+ INFORMATIONAL = "informational"
84
+ TRANSACTIONAL = "transactional"
85
+ RESEARCH = "research"
86
+
87
+ class SearchResult(TypedDict):
88
+ title: str
89
+ url: str
90
+ snippet: str
91
+ source: str
92
+ timestamp: str
93
+ relevance_score: float
94
+ authority_score: float
95
+ freshness_score: float
96
+ verified: bool
97
+ content: Optional[str]
98
+
99
+ class AgentState(TypedDict):
100
+ query: str
101
+ intent: Optional[QueryIntent]
102
+ expanded_queries: List[str]
103
+ search_results: List[SearchResult]
104
+ semantic_index: Optional[Any] # FAISS index
105
+ ranked_results: List[SearchResult]
106
+ verified_facts: List[Dict[str, Any]]
107
+ answer: str
108
+ confidence_score: float
109
+ error_log: List[str]
110
+ cache_hits: int
111
+ processing_time: float
112
+ user_context: Dict[str, Any]
113
+ iteration: int
114
+
115
+ class SearchAgent:
116
+ def __init__(self):
117
+ self.memory = {}
118
+ self.user_profiles = {}
119
+
120
+ async def classify_intent(self, state: AgentState) -> AgentState:
121
+ """Classify the query intent to optimize search strategy"""
122
+ try:
123
+ prompt = f"""
124
+ Classify the following search query into one of these intents:
125
+ - FACTUAL: Looking for specific facts or data
126
+ - NAVIGATIONAL: Looking for a specific website or resource
127
+ - INFORMATIONAL: Seeking general information about a topic
128
+ - TRANSACTIONAL: Looking to perform an action or transaction
129
+ - RESEARCH: In-depth research requiring multiple sources
130
+
131
+ Query: {state['query']}
132
+
133
+ Return only the intent category.
134
+ """
135
+
136
+ response = await llm.ainvoke(prompt)
137
+ intent_str = response.content.strip().upper()
138
+ state['intent'] = QueryIntent[intent_str]
139
+
140
+ except Exception as e:
141
+ state['error_log'].append(f"Intent classification error: {str(e)}")
142
+ state['intent'] = QueryIntent.INFORMATIONAL
143
+
144
+ return state
145
+
146
+ async def expand_query(self, state: AgentState) -> AgentState:
147
+ """Expand and refine the query for better results"""
148
+ try:
149
+ prompt = f"""
150
+ Given the search query and intent, generate 3-5 expanded or related queries
151
+ that would help find comprehensive information.
152
+
153
+ Original Query: {state['query']}
154
+ Intent: {state['intent'].value if state['intent'] else 'unknown'}
155
+
156
+ Return queries as a JSON list.
157
+ """
158
+
159
+ response = await llm.ainvoke(prompt)
160
+ expanded = json.loads(response.content)
161
+ state['expanded_queries'] = [state['query']] + expanded[:4]
162
+
163
+ except Exception as e:
164
+ state['error_log'].append(f"Query expansion error: {str(e)}")
165
+ state['expanded_queries'] = [state['query']]
166
+
167
+ return state
168
+
169
+ async def _fetch_snippet(self, url: str) -> str:
170
+ """Fetch snippet from URL"""
171
+ try:
172
+ async with aiohttp.ClientSession() as session:
173
+ async with session.get(url, timeout=10) as response:
174
+ if response.status == 200:
175
+ html = await response.text()
176
+ soup = BeautifulSoup(html, 'html.parser')
177
+ # Extract meta description or first paragraph
178
+ meta_desc = soup.find('meta', attrs={'name': 'description'})
179
+ if meta_desc:
180
+ return meta_desc.get('content', '')[:300]
181
+ # Fallback to first paragraph
182
+ p = soup.find('p')
183
+ if p:
184
+ return p.get_text()[:300]
185
+ except Exception as e:
186
+ logger.error(f"Error fetching snippet from {url}: {str(e)}")
187
+ return ""
188
+
189
+ async def _fetch_content(self, url: str) -> str:
190
+ """Fetch full content from URL"""
191
+ try:
192
+ async with aiohttp.ClientSession() as session:
193
+ async with session.get(url, timeout=15) as response:
194
+ if response.status == 200:
195
+ html = await response.text()
196
+ soup = BeautifulSoup(html, 'html.parser')
197
+ # Remove script and style elements
198
+ for script in soup(["script", "style"]):
199
+ script.decompose()
200
+ # Get text content
201
+ text = soup.get_text()
202
+ # Clean up whitespace
203
+ lines = (line.strip() for line in text.splitlines())
204
+ chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
205
+ text = '\n'.join(chunk for chunk in chunks if chunk)
206
+ return text[:5000] # Limit content length
207
+ except Exception as e:
208
+ logger.error(f"Error fetching content from {url}: {str(e)}")
209
+ return ""
210
+
211
+ def _calculate_authority(self, url: str) -> float:
212
+ """Calculate authority score based on domain"""
213
+ try:
214
+ domain = urllib.parse.urlparse(url).netloc.lower()
215
+
216
+ # High authority domains
217
+ high_authority = ['wikipedia.org', 'gov', 'edu', 'nature.com', 'ieee.org']
218
+ medium_authority = ['medium.com', 'reddit.com', 'stackoverflow.com']
219
+
220
+ if any(auth in domain for auth in high_authority):
221
+ return 0.9
222
+ elif any(auth in domain for auth in medium_authority):
223
+ return 0.6
224
+ elif domain.endswith('.org'):
225
+ return 0.7
226
+ elif domain.endswith('.com'):
227
+ return 0.5
228
+ else:
229
+ return 0.3
230
+
231
+ except Exception:
232
+ return 0.3
233
+
234
+ def _calculate_freshness(self, timestamp: str) -> float:
235
+ """Calculate freshness score based on timestamp"""
236
+ try:
237
+ time_diff = datetime.now() - datetime.fromisoformat(timestamp)
238
+ days_old = time_diff.days
239
+
240
+ if days_old <= 1:
241
+ return 1.0
242
+ elif days_old <= 7:
243
+ return 0.8
244
+ elif days_old <= 30:
245
+ return 0.6
246
+ elif days_old <= 90:
247
+ return 0.4
248
+ else:
249
+ return 0.2
250
+
251
+ except Exception:
252
+ return 0.5
253
+
254
+ async def search_google(self, query: str) -> List[SearchResult]:
255
+ """Search using Google"""
256
+ results = []
257
+ try:
258
+ # Check cache first
259
+ cache_key = f"google:{hashlib.md5(query.encode()).hexdigest()}"
260
+ if redis_client:
261
+ cached = redis_client.get(cache_key)
262
+ if cached:
263
+ return json.loads(cached)
264
+
265
+ # Rate limiting
266
+ await asyncio.sleep(RATE_LIMIT_DELAY)
267
+
268
+ for i, url in enumerate(google_search(query, num_results=MAX_RESULTS_PER_SOURCE)):
269
+ if i >= MAX_RESULTS_PER_SOURCE:
270
+ break
271
+
272
+ # Fetch snippet
273
+ snippet = await self._fetch_snippet(url)
274
+
275
+ result = SearchResult(
276
+ title=url.split('/')[2] if len(url.split('/')) > 2 else url,
277
+ url=url,
278
+ snippet=snippet,
279
+ source="google",
280
+ timestamp=datetime.now().isoformat(),
281
+ relevance_score=0.0,
282
+ authority_score=0.0,
283
+ freshness_score=0.0,
284
+ verified=False,
285
+ content=None
286
+ )
287
+ results.append(result)
288
+
289
+ # Cache results
290
+ if redis_client and results:
291
+ redis_client.setex(cache_key, CACHE_TTL, json.dumps(results))
292
+
293
+ except Exception as e:
294
+ logger.error(f"Google search error: {str(e)}")
295
+
296
+ return results
297
+
298
+ async def search_duckduckgo(self, query: str) -> List[SearchResult]:
299
+ """Search using DuckDuckGo"""
300
+ results = []
301
+ try:
302
+ # Check cache
303
+ cache_key = f"ddg:{hashlib.md5(query.encode()).hexdigest()}"
304
+ if redis_client:
305
+ cached = redis_client.get(cache_key)
306
+ if cached:
307
+ return json.loads(cached)
308
+
309
+ async with DDGS() as ddgs: # Use DDGS directly in async with
310
+ search_results = await ddgs.text(query, max_results=MAX_RESULTS_PER_SOURCE)
311
+
312
+ for r in search_results:
313
+ result = SearchResult(
314
+ title=r.get('title', ''),
315
+ url=r.get('href', ''),
316
+ snippet=r.get('body', ''),
317
+ source="duckduckgo",
318
+ timestamp=datetime.now().isoformat(),
319
+ relevance_score=0.0,
320
+ authority_score=0.0,
321
+ freshness_score=0.0,
322
+ verified=False,
323
+ content=None
324
+ )
325
+ results.append(result)
326
+
327
+ # Cache results
328
+ if redis_client and results:
329
+ redis_client.setex(cache_key, CACHE_TTL, json.dumps(results))
330
+
331
+ except Exception as e:
332
+ logger.error(f"DuckDuckGo search error: {str(e)}")
333
+
334
+ return results
335
+
336
+ async def parallel_search(self, state: AgentState) -> AgentState:
337
+ """Execute parallel searches across multiple sources"""
338
+ all_results = []
339
+
340
+ for query in state['expanded_queries']:
341
+ # Create search tasks
342
+ tasks = [
343
+ self.search_google(query),
344
+ self.search_duckduckgo(query),
345
+ ]
346
+
347
+ # Execute in parallel
348
+ results = await asyncio.gather(*tasks, return_exceptions=True)
349
+
350
+ # Combine results
351
+ for result_set in results:
352
+ if isinstance(result_set, list):
353
+ all_results.extend(result_set)
354
+
355
+ # Remove duplicates based on URL
356
+ seen_urls = set()
357
+ unique_results = []
358
+ for result in all_results:
359
+ if result['url'] not in seen_urls:
360
+ seen_urls.add(result['url'])
361
+ unique_results.append(result)
362
+
363
+ state['search_results'] = unique_results
364
+ return state
365
+
366
+ def create_semantic_index(self, state: AgentState) -> AgentState:
367
+ """Create FAISS index for semantic search"""
368
+ try:
369
+ if not state['search_results']:
370
+ return state
371
+
372
+ # Extract text for embedding
373
+ texts = [f"{r['title']} {r['snippet']}" for r in state['search_results']]
374
+
375
+ # Generate embeddings
376
+ embeddings = embeddings_model.encode(texts)
377
+
378
+ # Create FAISS index
379
+ dimension = embeddings.shape[1]
380
+ index = faiss.IndexFlatL2(dimension)
381
+ index.add(np.array(embeddings).astype('float32'))
382
+
383
+ state['semantic_index'] = {
384
+ 'index': index,
385
+ 'embeddings': embeddings,
386
+ 'texts': texts
387
+ }
388
+
389
+ except Exception as e:
390
+ state['error_log'].append(f"Semantic index creation error: {str(e)}")
391
+
392
+ return state
393
+
394
+ def calculate_scores(self, state: AgentState) -> AgentState:
395
+ """Calculate relevance, authority, and freshness scores"""
396
+ try:
397
+ query_embedding = embeddings_model.encode([state['query']])[0]
398
+
399
+ for i, result in enumerate(state['search_results']):
400
+ # Relevance score (semantic similarity)
401
+ if state.get('semantic_index') and i < len(state['semantic_index']['embeddings']):
402
+ result_embedding = state['semantic_index']['embeddings'][i]
403
+ relevance = cosine_similarity(
404
+ [query_embedding],
405
+ [result_embedding]
406
+ )[0][0]
407
+ result['relevance_score'] = float(relevance)
408
+
409
+ # Authority score (based on domain and source)
410
+ authority = self._calculate_authority(result['url'])
411
+ result['authority_score'] = authority
412
+
413
+ # Freshness score
414
+ freshness = self._calculate_freshness(result['timestamp'])
415
+ result['freshness_score'] = freshness
416
+
417
+ except Exception as e:
418
+ state['error_log'].append(f"Score calculation error: {str(e)}")
419
+
420
+ return state
421
+
422
+ def rank_results(self, state: AgentState) -> AgentState:
423
+ """Rank results using multiple factors"""
424
+ try:
425
+ # Calculate composite scores
426
+ for result in state['search_results']:
427
+ result['composite_score'] = (
428
+ 0.5 * result.get('relevance_score', 0) +
429
+ 0.3 * result.get('authority_score', 0) +
430
+ 0.2 * result.get('freshness_score', 0)
431
+ )
432
+
433
+ # Sort by composite score
434
+ state['ranked_results'] = sorted(
435
+ state['search_results'],
436
+ key=lambda x: x.get('composite_score', 0),
437
+ reverse=True
438
+ )[:20] # Top 20 results
439
+
440
+ except Exception as e:
441
+ state['error_log'].append(f"Ranking error: {str(e)}")
442
+ state['ranked_results'] = state['search_results'][:20]
443
+
444
+ return state
445
+
446
+ async def _extract_and_verify_facts(self, content: str, query: str) -> List[Dict[str, Any]]:
447
+ """Extract and verify facts from content"""
448
+ try:
449
+ prompt = f"""
450
+ Extract key facts from the following content that are relevant to the query: "{query}"
451
+
452
+ Content: {content[:2000]}
453
+
454
+ Return a JSON list of facts with their confidence scores (0-1).
455
+ Format: [{{"fact": "fact statement", "confidence": 0.95}}]
456
+ """
457
+
458
+ response = await llm.ainvoke(prompt)
459
+ facts = json.loads(response.content)
460
+ return facts
461
+
462
+ except Exception as e:
463
+ logger.error(f"Fact extraction error: {str(e)}")
464
+ return []
465
+
466
+ async def verify_facts(self, state: AgentState) -> AgentState:
467
+ """Verify facts from top results"""
468
+ try:
469
+ # Extract potential facts from top results
470
+ top_results = state['ranked_results'][:5]
471
+
472
+ facts = []
473
+ for result in top_results:
474
+ # Fetch full content if needed
475
+ if not result.get('content'):
476
+ result['content'] = await self._fetch_content(result['url'])
477
+
478
+ # Extract and verify facts
479
+ if result['content']:
480
+ verified_facts = await self._extract_and_verify_facts(
481
+ result['content'],
482
+ state['query']
483
+ )
484
+ facts.extend(verified_facts)
485
+
486
+ state['verified_facts'] = facts
487
+
488
+ except Exception as e:
489
+ state['error_log'].append(f"Fact verification error: {str(e)}")
490
+ state['verified_facts'] = []
491
+
492
+ return state
493
+
494
+ async def generate_answer(self, state: AgentState) -> AgentState:
495
+ """Generate final answer with confidence score"""
496
+ try:
497
+ # Prepare context from top results
498
+ context = "\n\n".join([
499
+ f"Source: {r['url']}\nTitle: {r['title']}\nContent: {r['snippet']}"
500
+ for r in state['ranked_results'][:5]
501
+ ])
502
+
503
+ # Include verified facts
504
+ facts_context = "\n".join([
505
+ f"Verified Fact: {f['fact']} (Confidence: {f['confidence']})"
506
+ for f in state.get('verified_facts', [])
507
+ ])
508
+
509
+ prompt = f"""
510
+ Based on the search results and verified facts, provide a comprehensive answer to the query.
511
+ Include source citations and indicate confidence level.
512
+
513
+ Query: {state['query']}
514
+ Intent: {state['intent'].value if state['intent'] else 'unknown'}
515
+
516
+ Search Results:
517
+ {context}
518
+
519
+ Verified Facts:
520
+ {facts_context}
521
+
522
+ Provide a detailed answer with source citations. End with a confidence score (0-1).
523
+ """
524
+
525
+ response = await llm.ainvoke(prompt)
526
+ answer = response.content
527
+
528
+ # Extract confidence score from answer
529
+ confidence_match = re.search(r'confidence.*?(\d+\.?\d*)', answer.lower())
530
+ if confidence_match:
531
+ state['confidence_score'] = float(confidence_match.group(1))
532
+ else:
533
+ state['confidence_score'] = 0.7 # Default confidence
534
+
535
+ state['answer'] = answer
536
+
537
+ except Exception as e:
538
+ state['error_log'].append(f"Answer generation error: {str(e)}")
539
+ state['answer'] = "I apologize, but I encountered an error while generating the answer."
540
+ state['confidence_score'] = 0.0
541
+
542
+ return state
543
+
544
+ def should_expand_query(self, state: AgentState) -> str:
545
+ """Determine if query should be expanded"""
546
+ if len(state['query'].split()) <= 2:
547
+ return "expand"
548
+ return "search"
549
+
550
+ def should_create_index(self, state: AgentState) -> str:
551
+ """Determine if semantic index should be created"""
552
+ if len(state['search_results']) > 5:
553
+ return "create_index"
554
+ return "calculate_scores"
555
+
556
+ def should_verify_facts(self, state: AgentState) -> str:
557
+ """Determine if facts should be verified"""
558
+ if state['intent'] == QueryIntent.FACTUAL and len(state['ranked_results']) > 0:
559
+ return "verify"
560
+ return "generate_answer"
561
+
562
+ def create_search_workflow() -> StateGraph:
563
+ """Create the search workflow graph"""
564
+ agent = SearchAgent()
565
+
566
+ # Create the graph
567
+ workflow = StateGraph(AgentState)
568
+
569
+ # Add nodes
570
+ workflow.add_node("classify_intent", agent.classify_intent)
571
+ workflow.add_node("expand_query", agent.expand_query)
572
+ workflow.add_node("parallel_search", agent.parallel_search)
573
+ workflow.add_node("create_semantic_index", agent.create_semantic_index)
574
+ workflow.add_node("calculate_scores", agent.calculate_scores)
575
+ workflow.add_node("rank_results", agent.rank_results)
576
+ workflow.add_node("verify_facts", agent.verify_facts)
577
+ workflow.add_node("generate_answer", agent.generate_answer)
578
+
579
+ # Set entry point
580
+ workflow.set_entry_point("classify_intent")
581
+
582
+ # Add edges
583
+ workflow.add_edge("classify_intent", "expand_query")
584
+ workflow.add_edge("expand_query", "parallel_search")
585
+
586
+ # Conditional routing
587
+ workflow.add_conditional_edges(
588
+ "parallel_search",
589
+ agent.should_create_index,
590
+ {
591
+ "create_index": "create_semantic_index",
592
+ "calculate_scores": "calculate_scores"
593
+ }
594
+ )
595
+
596
+ workflow.add_edge("create_semantic_index", "calculate_scores")
597
+ workflow.add_edge("calculate_scores", "rank_results")
598
+
599
+ workflow.add_conditional_edges(
600
+ "rank_results",
601
+ agent.should_verify_facts,
602
+ {
603
+ "verify": "verify_facts",
604
+ "generate_answer": "generate_answer"
605
+ }
606
+ )
607
+
608
+ workflow.add_edge("verify_facts", "generate_answer")
609
+ workflow.add_edge("generate_answer", END)
610
+
611
+ return workflow.compile()
612
+
613
+ # Usage example
614
+ async def main():
615
+ """Main function to demonstrate the search agent"""
616
+ # Create the workflow
617
+ app = create_search_workflow()
618
+
619
+ # Initialize state
620
+ initial_state = AgentState(
621
+ query=input("Enter your query: "),
622
+ intent=None,
623
+ expanded_queries=[],
624
+ search_results=[],
625
+ semantic_index=None,
626
+ ranked_results=[],
627
+ verified_facts=[],
628
+ answer="",
629
+ confidence_score=0.0,
630
+ error_log=[],
631
+ cache_hits=0,
632
+ processing_time=0.0,
633
+ user_context={},
634
+ iteration=0
635
+ )
636
+
637
+ # Run the workflow
638
+ start_time = time.time()
639
+
640
+ try:
641
+ final_state = await app.ainvoke(initial_state)
642
+
643
+ # Calculate processing time
644
+ processing_time = time.time() - start_time
645
+
646
+ # Print results
647
+ print(f"Query: {final_state['query']}")
648
+ print(f"Intent: {final_state['intent']}")
649
+ print(f"Expanded Queries: {final_state['expanded_queries']}")
650
+ print(f"Total Results: {len(final_state['search_results'])}")
651
+ print(f"Top Results: {len(final_state['ranked_results'])}")
652
+ print(f"Verified Facts: {len(final_state['verified_facts'])}")
653
+ print(f"Confidence Score: {final_state['confidence_score']}")
654
+ print(f"Processing Time: {processing_time:.2f}s")
655
+
656
+ if final_state['error_log']:
657
+ print(f"Errors: {final_state['error_log']}")
658
+
659
+ print(f"\nAnswer:\n{final_state['answer']}")
660
+
661
+ except Exception as e:
662
+ print(f"Error running workflow: {str(e)}")
663
+
664
+ # Directly await main() instead of using asyncio.run()
665
+ if __name__ == "__main__":
666
+ import nest_asyncio
667
+ nest_asyncio.apply()
668
+ asyncio.run(main()) # Keep asyncio.run() but apply nest_asyncio for Colab compatibilityh