Spaces:
Sleeping
Sleeping
| """ | |
| RLHF-Enhanced LangGraph Workflows for FinRyver | |
| Integrates reward model and feedback collection into existing workflows | |
| """ | |
| from typing import TypedDict, Dict, Any, List, Annotated, Optional | |
| import time | |
| import uuid | |
| import os | |
| import logging | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import HumanMessage, AIMessage, BaseMessage | |
| # Import existing tools and RLHF components | |
| from agents.simple_tools import ( | |
| generate_notes_full_pipeline_from_path, | |
| generate_balance_sheet, | |
| generate_pnl_statement, | |
| generate_cash_flow_statement, | |
| generate_llm_notes, | |
| ) | |
| from agents.feedback_manager import FeedbackManager | |
| from agents.reward_model import TextBasedRewardModel | |
| logger = logging.getLogger(__name__) | |
class RLHFFinancialAgentState(TypedDict):
    """Workflow state shared between graph nodes, extended with RLHF fields.

    The first group of keys carries the base workflow data; the second
    group holds the RLHF bookkeeping values, each of which may be ``None``.
    """

    # --- Base workflow fields ---
    # Message list; the "History" string is plain Annotated metadata.
    messages: Annotated[List[BaseMessage], "History"]
    file_path: str
    result: Dict[str, Any]
    status: str
    # NOTE(review): presumably time.time() stamps -- confirm against the nodes
    # that populate them.
    start_time: float
    end_time: float
    error: str

    # --- RLHF-specific fields ---
    statement_id: Optional[str]
    predicted_quality: Optional[float]
    confidence_score: Optional[float]
    candidates_generated: Optional[List[Dict[str, Any]]]
    # NOTE(review): presumably an index into candidates_generated -- confirm.
    best_candidate_index: Optional[int]
    feedback_collected: Optional[bool]
class RLHFWorkflowManager:
    """Entry point for RLHF operations backed by text-based feedback.

    Owns one ``FeedbackManager`` and one ``TextBasedRewardModel``; both
    public methods delegate to the reward model.
    """

    def __init__(self):
        # Construction order is kept: feedback manager first, then reward model.
        self.feedback_manager = FeedbackManager()
        self.reward_model = TextBasedRewardModel()

    def collect_feedback(self, feedback_data: Dict[str, Any]) -> Dict[str, Any]:
        """Hand ``feedback_data`` to the reward model and return its response.

        Args:
            feedback_data: Feedback payload, passed through unchanged.

        Returns:
            Whatever ``reward_model.collect_feedback`` returns.
        """
        outcome = self.reward_model.collect_feedback(feedback_data)
        return outcome

    def get_feedback_patterns(self) -> Dict[str, Any]:
        """Return the feedback patterns reported by the reward model.

        Returns:
            Whatever ``reward_model.get_feedback_patterns`` returns.
        """
        patterns = self.reward_model.get_feedback_patterns()
        return patterns
def get_rlhf_manager() -> RLHFWorkflowManager:
    """Build and return an RLHF workflow manager.

    A new ``RLHFWorkflowManager`` is constructed on every call; nothing is
    cached between calls.
    """
    manager = RLHFWorkflowManager()
    return manager
def run_rlhf_workflow(file_path: str, kind: str, user_api_key: Optional[str] = None) -> Dict[str, Any]:
    """Return a placeholder error result for the RLHF-enhanced workflow.

    The RLHF logic is not implemented for this endpoint yet, so nothing is
    read or generated; the function only reports that fact along with the
    request parameters.

    Args:
        file_path: Input file path; echoed back unchanged.
        kind: Requested workflow kind; echoed back unchanged.
        user_api_key: Optional caller-supplied API key. It is never echoed
            back verbatim.

    Returns:
        A dict with the keys ``"status"``, ``"error"``, ``"file_path"``,
        ``"kind"`` and ``"user_api_key"``. The last one is ``None`` when no
        key was supplied and the marker ``"[REDACTED]"`` otherwise.
    """
    # The result of this function is an error payload that may be logged or
    # returned to a client, so the raw credential must not appear in it.
    # The key is kept in the dict so the result shape stays the same.
    redacted_key = None if user_api_key is None else "[REDACTED]"

    # For now, just return a basic structure.
    # This can be enhanced later with actual RLHF logic.
    return {
        "status": "error",
        "error": "RLHF workflow not implemented for this endpoint",
        "file_path": file_path,
        "kind": kind,
        "user_api_key": redacted_key,
    }