# Standard library imports
import asyncio
import logging
import os
from pathlib import Path

# Third-party imports
import dspy
from dotenv import load_dotenv
from langchain_tavily import TavilySearch
from openevals.llm import create_async_llm_as_judge
from openevals.prompts import RAG_HELPFULNESS_PROMPT, RAG_RETRIEVAL_RELEVANCE_PROMPT

# Local imports
from ..agents.output_schema import TavilySearchQueries
from ..classes.classes import ResearchState
from ..utils.llm_provider_factory import LLMFactory

logger = logging.getLogger(__name__)

env_path = Path(__file__).parent / ".env"
load_dotenv(dotenv_path=env_path, override=True)

# Safe environment variable access with validation
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    logger.error("OPENROUTER_API_KEY environment variable not set")
    raise ValueError("OPENROUTER_API_KEY environment variable is required")

class TavilyResearchTool:
def __init__(
self,
job_description,
company_name,
max_results=5,
model_name="google/gemma-3-2google/gemma-3-27b-it:free7b-it:free",
):
# Create LLM inside __init__ (lazy initialization)
llm_provider = LLMFactory()
self.dspy_llm = llm_provider.create_dspy(
model=model_name, provider="openrouter", temperature=0.3
)
self.job_description = job_description
self.company_name = company_name
self.tavily_searchtool = TavilySearch(max_results=max_results)
def create_tavily_queries(self):
"""
Generate search queries for TavilySearch based on the job description and company name.
Returns:
dict: A dictionary containing the generated search queries.
"""
tavily_query_generator = dspy.ChainOfThought(TavilySearchQueries)
with dspy.context(lm=self.dspy_llm, adapter=dspy.JSONAdapter()):
response = tavily_query_generator(
job_description=self.job_description, company_name=self.company_name
)
return response
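
    # Usage sketch. The exact output field names are defined by the
    # TavilySearchQueries schema in output_schema; `queries` below is an
    # assumed, illustrative field name:
    #
    #     tool = TavilyResearchTool(job_description=jd, company_name="Acme Corp")
    #     prediction = tool.create_tavily_queries()
    #     results = tool.tavily_search_company(prediction.queries)
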
    def tavily_search_company(self, queries: dict[str, str]) -> list[list[str]]:
        """
        Execute Tavily searches for multiple queries.

        Args:
            queries: Dictionary mapping query identifiers to query strings

        Returns:
            List of result-content lists, one per successfully executed query
            (empty queries and failed searches are skipped)
        """
query_results: list[list[str]] = []
for query_key in queries:
try:
query_string = queries.get(query_key, "")
if not query_string:
logger.warning(f"Empty query for key: {query_key}")
continue
search_query_response = self.tavily_searchtool.invoke(
{"query": query_string}
)
# Safe dictionary access for response
results = search_query_response.get("results", [])
query_results.append(
[res.get("content", "") for res in results if isinstance(res, dict)]
)
except Exception as e:
logger.error(
f"Failed to perform company research using TavilySearchTool. Error: {e}"
)
continue
return query_results
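
    # Shape sketch of the return value for two queries (illustrative content):
    #     [
    #         ["content of result 1", "content of result 2"],  # query 1 results
    #         ["content of result 1"],                          # query 2 results
    #     ]
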
def get_relevance_evaluator():
"""
Create an LLM-as-judge evaluator for relevance filtering.
Creates the LLM on-demand (lazy initialization) to avoid startup delays.
"""
# Create LLM inside function (lazy initialization)
llm_provider = LLMFactory()
llm_structured = llm_provider.create_langchain(
"llama3.1-8b", provider="cerebras", temperature=0.3
)
return create_async_llm_as_judge(
judge=llm_structured,
prompt=RAG_RETRIEVAL_RELEVANCE_PROMPT,
feedback_key="retrieval_relevance",
)
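
# Usage sketch: the evaluator returned above is async and yields a dict-like
# result whose "score" field is read downstream (see
# filter_research_results_by_relevance); `query` and `retrieved_text` are
# placeholder variables:
#
#     evaluator = get_relevance_evaluator()
#     result = await evaluator(inputs=query, context=retrieved_text)
#     is_relevant = bool(result.get("score", False))
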
def get_helpfulness_evaluator():
"""
Create an LLM-as-judge evaluator for helpfulness filtering.
Creates the LLM on-demand (lazy initialization) to avoid startup delays.
"""
# Create LLM inside function (lazy initialization)
llm_provider = LLMFactory()
llm_structured = llm_provider.create_langchain(
"llama3.1-8b", provider="cerebras", temperature=0.3
)
return create_async_llm_as_judge(
judge=llm_structured,
prompt=RAG_HELPFULNESS_PROMPT
+ '\nReturn "true" if the answer is helpful, and "false" otherwise.',
feedback_key="helpfulness",
)
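
# The helpfulness evaluator is awaited the same way, judging an answer against
# its question; a sketch with placeholder variables:
#
#     evaluator = get_helpfulness_evaluator()
#     result = await evaluator(inputs=question, outputs=answer)
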
async def filter_research_results_by_relevance(state: ResearchState) -> ResearchState:
"""
Filter search results to keep only relevant company information.
Uses LLM-as-judge to evaluate if each result set is relevant to its query.
Irrelevant results are REMOVED from the final output.
"""
try:
state["current_node"] = "filter_research_results_by_relevance"
# Extract and validate required state fields once
company_research_data = state.get("company_research_data", {})
raw_search_results = company_research_data.get("tavily_search", [])
search_queries_used = state.get("attempted_search_queries", [])
# Validate data types
if not isinstance(raw_search_results, list):
logger.warning(f"Invalid search results type: {type(raw_search_results)}")
return state
if not isinstance(search_queries_used, list):
logger.warning(f"Invalid queries type: {type(search_queries_used)}")
search_queries_used = []
# Early exit if no results
if len(raw_search_results) == 0:
logger.info("No search results to filter.")
# Update using the extracted variable
company_research_data["tavily_search"] = []
state["company_research_data"] = company_research_data
return state
logger.info(
f"Starting relevance filtering for {len(raw_search_results)} result sets..."
)
# Track filtering statistics
results_kept = []
results_removed_count = 0
evaluation_errors_count = 0
# Limit concurrent evaluations to prevent rate limiting
        concurrency_limiter = asyncio.Semaphore(2)

        # Create the relevance evaluator once and share it across all evaluations
        llm_relevance_judge = get_relevance_evaluator()
async def evaluate_result_set_relevance(
search_result_content, original_query: str
):
"""
Evaluate if a search result set is relevant to its query.
Returns:
tuple: (search_result_content, is_relevant: bool, error: str|None)
"""
async with concurrency_limiter:
try:
# Skip empty result sets
if not search_result_content:
logger.debug(
f"Skipping empty result set for query: {original_query[:50]}..."
)
return (None, False, "empty")
# Evaluate with timeout protection
evaluation_task = llm_relevance_judge(
inputs=original_query, context=search_result_content
)
evaluation_result = await asyncio.wait_for(
evaluation_task, timeout=15
)
# Extract relevance score (True = relevant, False = not relevant)
is_result_relevant = bool(evaluation_result.get("score", False))
if is_result_relevant:
logger.debug(
f"KEPT: Result relevant for query: {original_query[:60]}..."
)
return (search_result_content, True, None)
else:
logger.debug(
f"REMOVED: Result not relevant for query: {original_query[:60]}..."
)
return (None, False, None)
except asyncio.TimeoutError:
logger.warning(
f"Evaluation timed out for query: {original_query[:60]}... (KEEPING result)"
)
# Keep the result on timeout to avoid losing potentially useful data
return (search_result_content, True, "timeout")
except Exception as e:
logger.error(
f"Evaluation failed for query: {original_query[:60]}... - {e} (KEEPING result)"
)
return (search_result_content, True, f"error:{str(e)}")
# Create evaluation tasks for all result sets
evaluation_tasks = []
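        # Note: zip() truncates to the shorter sequence, so any result set
        # without a matching recorded query is silently dropped from the output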
for result_set, query in zip(raw_search_results, search_queries_used):
task = evaluate_result_set_relevance(result_set, query)
evaluation_tasks.append(task)
# Execute all evaluations concurrently
all_evaluation_results = await asyncio.gather(
*evaluation_tasks, return_exceptions=True
)
# Process evaluation results and separate kept vs removed
for eval_result in all_evaluation_results:
# Handle exceptions from gather
if isinstance(eval_result, Exception):
logger.error(f"Evaluation task failed with exception: {eval_result}")
evaluation_errors_count += 1
continue
# Type guard: eval_result is now guaranteed to be a tuple
if not isinstance(eval_result, tuple) or len(eval_result) != 3:
logger.error(
f"Unexpected evaluation result format: {type(eval_result)}"
)
evaluation_errors_count += 1
continue
result_content, is_relevant, error = eval_result
# Track errors
if error:
evaluation_errors_count += 1
# Keep relevant results, discard irrelevant ones
if result_content is not None and is_relevant:
results_kept.append(result_content)
else:
results_removed_count += 1
# Update company_research_data with ONLY the relevant results
company_research_data["tavily_search"] = results_kept
state["company_research_data"] = company_research_data
# Log filtering summary
total_evaluated = len(raw_search_results)
kept_count = len(results_kept)
removed_count = results_removed_count
logger.info(
f"Relevance filtering complete: "
f"KEPT {kept_count} | REMOVED {removed_count} | TOTAL {total_evaluated} "
f"({evaluation_errors_count} evaluation errors)"
)
return state
except Exception as e:
logger.error(f"Critical error in relevance filtering: {e}", exc_info=True)
# On critical error, return original state unchanged
return state
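
# End-to-end sketch. Assumes OPENROUTER_API_KEY and the Tavily/Cerebras
# credentials are configured, and that ResearchState carries the
# "company_research_data" and "attempted_search_queries" keys used above;
# the `queries` field name on the dspy prediction is illustrative:
#
#     import asyncio
#
#     tool = TavilyResearchTool(job_description=jd, company_name="Acme Corp")
#     queries = tool.create_tavily_queries().queries  # hypothetical field name
#     state: ResearchState = {
#         "company_research_data": {
#             "tavily_search": tool.tavily_search_company(queries)
#         },
#         "attempted_search_queries": list(queries.values()),
#     }
#     state = asyncio.run(filter_research_results_by_relevance(state))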