| |
| """Research workflow for company information gathering and filtering.""" |
|
|
| |
| import asyncio |
| import json |
| import logging |
| from typing import Any, Dict, cast |
|
|
| |
| import dspy |
| from langgraph.graph import StateGraph |
|
|
| |
| from job_writing_agent.agents.output_schema import ( |
| CompanyResearchDataSummarizationSchema, |
| ) |
| from job_writing_agent.classes.classes import ResearchState, CompanyResearchData |
| from job_writing_agent.tools.SearchTool import ( |
| TavilyResearchTool, |
| filter_research_results_by_relevance, |
| ) |
| from job_writing_agent.utils.llm_provider_factory import LLMFactory |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| MAX_RETRIES = 3 |
| RETRY_DELAY = 2 |
| QUERY_TIMEOUT = 30 |
| EVAL_TIMEOUT = 15 |
|
|
|
|
def validate_research_inputs(state: ResearchState) -> tuple[bool, str, str]:
    """
    Check that the research state carries a usable company name and job description.

    Args:
        state: Current research workflow state

    Returns:
        Tuple of (is_valid, company_name, job_description); both strings are
        stripped on success and empty on failure.
    """
    try:
        research_data = state.company_research_data or CompanyResearchData()
        name = research_data.company_name
        description = research_data.job_description

        # Reject missing or whitespace-only values for either field.
        for label, value in (
            ("Company name", name),
            ("Job description", description),
        ):
            if not value or not value.strip():
                logger.error(f"{label} is missing or empty")
                return False, "", ""

        return True, name.strip(), description.strip()

    except (TypeError, AttributeError) as exc:
        logger.error(f"Invalid state structure: {exc}")
        return False, "", ""
|
|
|
|
def parse_dspy_queries_with_fallback(
    raw_queries: dict[str, Any], company_name: str
) -> dict[str, str]:
    """
    Extract search queries from raw DSPy output, falling back to defaults.

    Accepts either an already-decoded mapping or a JSON string under the
    "search_queries" key. Returns a dict of query_id -> query_string.
    """
    try:
        if isinstance(raw_queries, dict) and "search_queries" in raw_queries:
            payload = raw_queries["search_queries"]

            # A JSON-encoded string must be decoded before inspection.
            if isinstance(payload, str):
                try:
                    payload = json.loads(payload)
                except json.JSONDecodeError as exc:
                    logger.warning(f"JSON decode failed: {exc}. Using fallback queries.")
                    return get_fallback_queries(company_name)

            if isinstance(payload, dict):
                # Keep string values as-is; for non-empty lists take the
                # first element. Anything else is dropped.
                cleaned = {
                    key: (value if isinstance(value, str) else str(value[0]))
                    for key, value in payload.items()
                    if isinstance(value, str)
                    or (isinstance(value, list) and len(value) > 0)
                }
                if cleaned:
                    return cleaned

        logger.warning("Could not parse DSPy queries. Using fallback.")
        return get_fallback_queries(company_name)

    except Exception as exc:
        logger.error(f"Error parsing DSPy queries: {exc}. Using fallback.")
        return get_fallback_queries(company_name)
|
|
|
|
def get_fallback_queries(company_name: str) -> dict[str, str]:
    """Build the default trio of search queries used when DSPy generation fails."""
    topics = (
        "company culture and values",
        "recent news and achievements",
        "mission statement and goals",
    )
    return {
        f"query{index}": f"{company_name} {topic}"
        for index, topic in enumerate(topics, start=1)
    }
|
|
|
|
def company_research_data_summary(state: ResearchState) -> dict[str, Any]:
    """
    Summarize the filtered research data into a concise summary.

    Replaces the raw tavily_search results with a summarized version using LLM.

    Args:
        state: Current research state with search results

    Returns:
        Updated state with research summary
    """
    try:
        # Shallow-copy the state's attribute dict so the returned mapping
        # carries every existing field forward, with the node marker updated.
        updated_state = {
            **state.__dict__,
            "current_node": "company_research_data_summary",
        }

        company_research_data = state.company_research_data or CompanyResearchData()
        tavily_search_data = company_research_data.tavily_search

        # Nothing to summarize: skip the LLM call and return the state
        # unchanged apart from current_node.
        if not tavily_search_data or len(tavily_search_data) == 0:
            logger.warning("No research data to summarize. Skipping summarization.")
            return updated_state

        logger.info(f"Summarizing {len(tavily_search_data)} research result sets...")

        # Chain-of-thought module built from the summarization signature.
        company_research_data_summarization = dspy.ChainOfThought(
            CompanyResearchDataSummarizationSchema
        )

        # NOTE(review): model/provider/temperature are hard-coded here —
        # confirm whether these should come from configuration.
        llm_provider = LLMFactory()
        llm = llm_provider.create_dspy(
            model="openai/gpt-oss-20b:free",
            provider="openrouter",
            temperature=0.3,
        )

        # dspy.context scopes this LM (with a JSON adapter, presumably for
        # structured output) to just this call rather than configuring it
        # globally.
        with dspy.context(lm=llm, adapter=dspy.JSONAdapter()):
            response = company_research_data_summarization(
                company_research_data=company_research_data
            )

        # The response may expose the summary as an attribute or as a dict
        # key; anything else is treated as a failure and the state is
        # returned without a summary.
        summary_json_str = ""
        if hasattr(response, "company_research_data_summary"):
            summary_json_str = response.company_research_data_summary
        elif isinstance(response, dict):
            summary_json_str = response.get("company_research_data_summary", "")
        else:
            logger.error(
                f"Unexpected response format from summarization: {type(response)}"
            )
            return updated_state

        # Rebuild CompanyResearchData with the new summary attached instead
        # of mutating the existing instance.
        updated_company_research_data = {**company_research_data.__dict__}
        updated_company_research_data["company_research_data_summary"] = (
            summary_json_str
        )
        updated_state["company_research_data"] = CompanyResearchData(
            **updated_company_research_data
        )

        return updated_state

    except Exception as e:
        logger.error(f"Error in company_research_data_summary: {e}", exc_info=True)
        # Best-effort node: on any failure, report only the node marker so
        # the workflow can continue without a summary.
        return {"current_node": "company_research_data_summary"}
|
|
|
|
async def research_company_with_retry(state: ResearchState) -> dict[str, Any]:
    """
    Research company with retry logic and timeouts.

    Generates Tavily search queries for the company, runs the searches, and
    retries up to MAX_RETRIES times with increasing back-off on errors or
    timeouts.

    Args:
        state: Current research workflow state.

    Returns:
        Partial state update containing company_research_data (with
        tavily_search populated), attempted_search_queries, and current_node.
    """
    state.current_node = "research_company"

    is_valid, company_name, job_description = validate_research_inputs(state)

    # Invalid inputs: short-circuit with empty search results rather than
    # attempting any network calls.
    if not is_valid:
        logger.error("Invalid inputs for research. Skipping research phase.")
        cr = state.company_research_data or CompanyResearchData()
        return {
            "company_research_data": cr.model_copy(update={"tavily_search": []}),
            "attempted_search_queries": [],
            "current_node": "research_company",
        }

    logger.info(f"Researching company: {company_name}")

    for attempt in range(MAX_RETRIES):
        try:
            tavily_search = TavilyResearchTool(
                job_description=job_description, company_name=company_name
            )

            # Query generation is synchronous, so run it in a worker thread
            # to keep the event loop responsive.
            queries_task = asyncio.create_task(
                asyncio.to_thread(tavily_search.create_tavily_queries)
            )

            try:
                raw_queries = await asyncio.wait_for(
                    queries_task, timeout=QUERY_TIMEOUT
                )
            except asyncio.TimeoutError:
                # NOTE(review): wait_for cancels the task, but a to_thread
                # worker cannot be interrupted mid-execution — the thread may
                # keep running in the background.
                logger.warning(
                    f"Query generation timed out (attempt {attempt + 1}/{MAX_RETRIES})"
                )
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                    continue
                else:
                    # Re-raised TimeoutError is handled by the outer except,
                    # which returns empty results on the final attempt.
                    raise

            # Normalize the DSPy result to a plain dict, whichever shape the
            # library returned (pydantic-like .dict(), plain object, dict,
            # or something dict()-convertible).
            if hasattr(raw_queries, "dict"):
                raw_queries_dict = cast(Dict[str, Any], raw_queries.dict())
            elif hasattr(raw_queries, "__dict__"):
                raw_queries_dict = cast(Dict[str, Any], raw_queries.__dict__)
            elif isinstance(raw_queries, dict):
                raw_queries_dict = cast(Dict[str, Any], raw_queries)
            else:
                raw_queries_dict = cast(Dict[str, Any], dict(raw_queries))

            queries = parse_dspy_queries_with_fallback(raw_queries_dict, company_name)

            if not queries:
                logger.warning("No valid queries generated")
                queries = get_fallback_queries(company_name)

            logger.info(
                f"Generated {len(queries)} search queries: {list(queries.keys())}"
            )

            # Run all searches in a worker thread; budget one QUERY_TIMEOUT
            # per query for the combined call.
            search_task = asyncio.create_task(
                asyncio.to_thread(tavily_search.tavily_search_company, queries)
            )

            try:
                search_results = await asyncio.wait_for(
                    search_task, timeout=QUERY_TIMEOUT * len(queries)
                )
            except asyncio.TimeoutError:
                logger.warning(
                    f"Search timed out (attempt {attempt + 1}/{MAX_RETRIES})"
                )
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                    continue
                else:
                    raise

            # Defend against unexpected return types from the search tool.
            if not isinstance(search_results, list):
                logger.warning(f"Invalid search results type: {type(search_results)}")
                search_results = []

            if len(search_results) == 0:
                logger.warning("No search results returned")

            # Success: attach results immutably via model_copy.
            cr = state.company_research_data or CompanyResearchData()
            return {
                "company_research_data": cr.model_copy(update={"tavily_search": search_results}),
                "attempted_search_queries": list(queries.values()),
                "current_node": "research_company",
            }

        except Exception as e:
            logger.error(
                f"Error in research_company (attempt {attempt + 1}/{MAX_RETRIES}): {e}",
                exc_info=True,
            )

            # Linear back-off: delay grows with each failed attempt.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAY * (attempt + 1))
            else:
                logger.error("All retry attempts exhausted. Using empty results.")
                cr = state.company_research_data or CompanyResearchData()
                return {
                    "company_research_data": cr.model_copy(update={"tavily_search": []}),
                    "attempted_search_queries": [],
                    "current_node": "research_company",
                }

    # Defensive fallback: every loop path above returns or retries, so this
    # should be unreachable, but keep the node's contract intact regardless.
    cr = state.company_research_data or CompanyResearchData()
    return {
        "company_research_data": cr,
        "attempted_search_queries": [],
        "current_node": "research_company",
    }
|
|
|
|
| |
# Build the research subgraph over ResearchState:
# research_company -> relevance_filter -> company_research_data_summary.
research_subgraph = StateGraph(ResearchState)

# Register the three pipeline nodes.
research_subgraph.add_node("research_company", research_company_with_retry)
research_subgraph.add_node("relevance_filter", filter_research_results_by_relevance)
research_subgraph.add_node(
    "company_research_data_summary", company_research_data_summary
)

# Entry and exit points of the subgraph.
research_subgraph.set_entry_point("research_company")
research_subgraph.set_finish_point("company_research_data_summary")

# Linear edges: search results are filtered, then summarized.
research_subgraph.add_edge("research_company", "relevance_filter")
research_subgraph.add_edge("relevance_filter", "company_research_data_summary")

# Compiled workflow exported for composition into the parent graph.
background_research_workflow = research_subgraph.compile()
|
|