Sahil Garg committed on
Commit c172f37 · 1 parent: 41cb3f5

initial RLHF applied
.gitignore CHANGED
@@ -20,4 +20,11 @@ app/__pycache__/
 pnlbs/__pycache__/
 AGENT_GUIDE.md
 docker-compose.dev.yml
-file_cleanup.py
+file_cleanup.py
+agents/langgraph_routes.py
+
+# RLHF related data
+data/feedback/
+data/models/
+*.pkl
+*.joblib
RLHF_GUIDE.md ADDED
@@ -0,0 +1,167 @@
+# RLHF (Reinforcement Learning from Human Feedback) Features
+
+## Overview
+
+FinRyver now includes RLHF capabilities that allow the system to learn from human feedback and improve the quality of generated financial statements over time.
+
+## Key Components
+
+### 1. **Enhanced Workflows**
+- RLHF-enhanced versions of all financial statement generation workflows
+- Multiple candidate generation and selection using reward models
+- Quality prediction and confidence scoring
+
+### 2. **Feedback Collection System**
+- Web-based review interface for human feedback
+- Structured feedback forms with technical and quality metrics
+- Storage and management of feedback data
+
+### 3. **Reward Model**
+- Machine learning model that predicts statement quality
+- Trained on human feedback data
+- Automatic retraining when sufficient new feedback is available
+
+## Usage
+
+### Basic Financial Statement Generation
+
+**Standard workflow (existing functionality):**
+```bash
+curl -X POST "http://localhost:8000/notes" \
+  -F "file=@trial_balance.xlsx"
+```
+
+**RLHF-enhanced workflow:**
+```bash
+curl -X POST "http://localhost:8000/notes?use_rlhf=true" \
+  -F "file=@trial_balance.xlsx"
+```
+
+The RLHF-enhanced workflow will:
+1. Generate multiple candidates (if a reward model is trained)
+2. Use the reward model to select the best candidate
+3. Provide quality predictions and confidence scores
+4. Store the result for potential human feedback
+
+### Response Headers
+
+When using RLHF workflows, additional metadata is included in response headers:
+- `X-RLHF-Statement-ID`: Unique ID for the generated statement
+- `X-RLHF-Quality-Score`: Predicted quality score (1-5)
+- `X-RLHF-Confidence`: Model confidence in the prediction
+
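Clients consuming these headers need to convert the string values back to numbers. A small parsing helper (hypothetical, not part of FinRyver itself) might look like:

```python
def extract_rlhf_metadata(headers: dict) -> dict:
    """Parse RLHF metadata headers from a /notes?use_rlhf=true response."""
    return {
        "statement_id": headers.get("X-RLHF-Statement-ID"),
        "quality_score": float(headers.get("X-RLHF-Quality-Score", 0.0)),
        "confidence": float(headers.get("X-RLHF-Confidence", 0.0)),
    }

# Example with a captured response's headers:
meta = extract_rlhf_metadata({
    "X-RLHF-Statement-ID": "123e4567-e89b-12d3-a456-426614174000",
    "X-RLHF-Quality-Score": "4.2",
    "X-RLHF-Confidence": "0.87",
})
```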
+### Feedback Collection
+
+#### 1. Get Statements Needing Review
+```bash
+curl "http://localhost:8000/rlhf/pending-reviews"
+```
+
+#### 2. Review Interface
+Visit: `http://localhost:8000/rlhf/review/{statement_id}`
+
+This provides an HTML form for structured feedback collection.
+
+#### 3. Submit Feedback Programmatically
+```bash
+curl -X POST "http://localhost:8000/rlhf/feedback" \
+  -F "statement_id=123e4567-e89b-12d3-a456-426614174000" \
+  -F "calculation_accuracy=4" \
+  -F "account_classification=5" \
+  -F "statement_balance=4" \
+  -F "accounting_standards=4" \
+  -F "regulatory_compliance=5" \
+  -F "completeness=3" \
+  -F "professional_presentation=4" \
+  -F "would_accept_for_audit=true" \
+  -F "specific_errors=Minor formatting issues" \
+  -F "improvement_suggestions=Add more detailed notes"
+```
+
+### Monitoring and Statistics
+
+#### Get Feedback Statistics
+```bash
+curl "http://localhost:8000/rlhf/stats"
+```
+
+Returns:
+- Total feedback collected
+- Average quality scores
+- Audit approval rates
+- Model training status
+- Feature importance
+
+#### Get Model Information
+```bash
+curl "http://localhost:8000/rlhf/model-info"
+```
+
+#### Manual Model Retraining
+```bash
+curl -X POST "http://localhost:8000/rlhf/retrain"
+```
+
+## Feedback Metrics
+
+### Technical Accuracy (1-5 scale)
+- **Calculation Accuracy**: Mathematical correctness
+- **Account Classification**: Proper categorization of accounts
+- **Statement Balance**: Internal consistency and reconciliation
+
+### Compliance (1-5 scale)
+- **Accounting Standards**: GAAP/IFRS compliance
+- **Regulatory Compliance**: Meeting regulatory requirements
+
+### Quality (1-5 scale)
+- **Completeness**: All necessary items included
+- **Professional Presentation**: Formatting and language quality
+
+### Qualitative Feedback
+- **Specific Errors**: Detailed error descriptions
+- **Missing Items**: Items that should be included
+- **Improvement Suggestions**: Recommendations for enhancement
+- **Audit Acceptance**: Binary approval for professional use
+
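The overall quality score stored with each review is the mean of whichever of the seven 1-5 metrics were supplied (this mirrors `_compute_overall_score` in `agents/feedback_manager.py`). A standalone sketch:

```python
def compute_overall_score(feedback: dict) -> float:
    """Mean of the seven 1-5 metrics, ignoring any that were not provided."""
    metric_names = [
        "calculation_accuracy", "account_classification", "statement_balance",
        "accounting_standards", "regulatory_compliance",
        "completeness", "professional_presentation",
    ]
    values = [feedback[m] for m in metric_names if feedback.get(m) is not None]
    return sum(values) / len(values) if values else 0.0

# The scores from the curl example above average to (4+5+4+4+5+3+4)/7 ≈ 4.14
score = compute_overall_score({
    "calculation_accuracy": 4, "account_classification": 5,
    "statement_balance": 4, "accounting_standards": 4,
    "regulatory_compliance": 5, "completeness": 3,
    "professional_presentation": 4,
})
```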
+## Training Process
+
+1. **Initial Phase**: System operates with default models
+2. **Feedback Collection**: Human experts review generated statements
+3. **Model Training**: When 20+ feedback samples are available, the reward model is trained
+4. **Enhanced Generation**: RLHF workflows use the trained model for better results
+5. **Continuous Learning**: Model retrains automatically with new feedback
+
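Step 3's trigger can be expressed compactly. This sketch mirrors `RLHFTrainer.should_retrain` with the production thresholds described here (the shipped code temporarily lowers both thresholds to 2 for testing):

```python
def should_retrain(total_feedback: int, last_training_count: int,
                   min_total: int = 20, min_new: int = 10) -> bool:
    """Retrain once enough total feedback exists AND enough new samples
    have arrived since the last training run."""
    new_feedback = total_feedback - last_training_count
    return total_feedback >= min_total and new_feedback >= min_new

# 25 reviews collected, 10 used in the last training run -> 15 new samples
retrain_now = should_retrain(total_feedback=25, last_training_count=10)
```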
+## Benefits
+
+- **Quality Improvement**: Statements become more accurate over time
+- **Domain Adaptation**: System learns specific requirements and preferences
+- **Consistency**: Reduces variability in output quality
+- **Professional Standards**: Aligns with human expert expectations
+
+## Implementation Notes
+
+- RLHF features are optional and backward-compatible
+- Existing workflows continue to work unchanged
+- Feedback data is stored locally and can be exported for analysis
+- Models can be backed up and restored
+- Multiple reward models can be maintained for different statement types
+
+## File Structure
+
+```
+data/
+├── feedback/
+│   ├── human_feedback.json          # Collected feedback data
+│   └── generated_statements.json    # Statement metadata
+└── models/
+    ├── reward_model.pkl             # Trained reward model
+    ├── feature_names.json           # Model feature definitions
+    └── model_stats.json             # Training statistics
+```
+
+## Security and Privacy
+
+- Feedback data is stored locally
+- No external transmission of financial data
+- Anonymous feedback collection supported
+- Data can be cleaned/anonymized before training
agents/feedback_manager.py ADDED
@@ -0,0 +1,248 @@
+"""
+RLHF Feedback Management System for FinRyver
+Handles collection, storage, and management of human feedback on financial statements
+"""
+import json
+import os
+import time
+import uuid
+from typing import Dict, Any, List, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+class FeedbackManager:
+    """Manages human feedback collection for RLHF training"""
+
+    def __init__(self, feedback_dir: str = "data/feedback"):
+        self.feedback_dir = feedback_dir
+        self.feedback_db = os.path.join(feedback_dir, "human_feedback.json")
+        self.statements_db = os.path.join(feedback_dir, "generated_statements.json")
+        os.makedirs(feedback_dir, exist_ok=True)
+
+    def store_generated_statement(self, statement_data: Dict[str, Any]) -> str:
+        """Store a generated statement for later feedback collection"""
+        statement_id = str(uuid.uuid4())
+        statement_record = {
+            "statement_id": statement_id,
+            "timestamp": time.time(),
+            "statement_type": statement_data.get("type", "unknown"),
+            "file_path": statement_data.get("file_path"),
+            "output_path": statement_data.get("output_path"),
+            "generation_time": statement_data.get("generation_time", 0),
+            "metadata": statement_data.get("metadata", {})
+        }
+
+        # Load existing statements
+        statements = self._load_statements()
+        statements.append(statement_record)
+
+        # Save updated statements
+        with open(self.statements_db, "w") as f:
+            json.dump(statements, f, indent=2)
+
+        logger.info(f"Stored statement {statement_id} for feedback collection")
+        return statement_id
+
+    def store_feedback(self, feedback: Dict[str, Any]) -> str:
+        """Store human feedback for RLHF training"""
+        feedback_id = str(uuid.uuid4())
+        feedback_record = {
+            "feedback_id": feedback_id,
+            "statement_id": feedback.get("statement_id"),
+            "timestamp": time.time(),
+            "reviewer_id": feedback.get("reviewer_id", "anonymous"),
+
+            # Technical accuracy metrics
+            "calculation_accuracy": feedback.get("calculation_accuracy"),
+            "account_classification": feedback.get("account_classification"),
+            "statement_balance": feedback.get("statement_balance"),
+
+            # Compliance metrics
+            "accounting_standards": feedback.get("accounting_standards"),
+            "regulatory_compliance": feedback.get("regulatory_compliance"),
+
+            # Quality metrics
+            "completeness": feedback.get("completeness"),
+            "professional_presentation": feedback.get("professional_presentation"),
+
+            # Overall quality score (computed)
+            "overall_score": self._compute_overall_score(feedback),
+
+            # Qualitative feedback
+            "specific_errors": feedback.get("specific_errors", ""),
+            "missing_items": feedback.get("missing_items", ""),
+            "improvement_suggestions": feedback.get("improvement_suggestions", ""),
+            "would_accept_for_audit": feedback.get("would_accept_for_audit", False),
+
+            # Additional context
+            "statement_type": feedback.get("statement_type"),
+            "complexity_level": feedback.get("complexity_level", "medium")
+        }
+
+        # Load existing feedback
+        all_feedback = self._load_feedback()
+        all_feedback.append(feedback_record)
+
+        # Save updated feedback
+        with open(self.feedback_db, "w") as f:
+            json.dump(all_feedback, f, indent=2)
+
+        logger.info(f"Stored feedback {feedback_id} for statement {feedback.get('statement_id')}")
+        return feedback_id
+
+    def get_training_data(self, min_feedback_count: int = 2) -> List[Dict[str, Any]]:
+        """Get feedback data suitable for RLHF training"""
+        feedback_data = self._load_feedback()
+
+        if len(feedback_data) < min_feedback_count:
+            logger.warning(f"Only {len(feedback_data)} feedback samples available, need at least {min_feedback_count}")
+            return []
+
+        # Filter and prepare training data
+        training_data = []
+        for feedback in feedback_data:
+            if feedback.get("overall_score") is not None:
+                training_sample = {
+                    "statement_id": feedback["statement_id"],
+                    "statement_type": feedback["statement_type"],
+                    "reward_score": feedback["overall_score"],
+                    "binary_approval": feedback["would_accept_for_audit"],
+                    "technical_metrics": {
+                        "calculation_accuracy": feedback.get("calculation_accuracy"),
+                        "account_classification": feedback.get("account_classification"),
+                        "statement_balance": feedback.get("statement_balance")
+                    },
+                    "quality_metrics": {
+                        "completeness": feedback.get("completeness"),
+                        "professional_presentation": feedback.get("professional_presentation"),
+                        "accounting_standards": feedback.get("accounting_standards")
+                    },
+                    "feedback_text": {
+                        "errors": feedback.get("specific_errors", ""),
+                        "missing": feedback.get("missing_items", ""),
+                        "suggestions": feedback.get("improvement_suggestions", "")
+                    }
+                }
+                training_data.append(training_sample)
+
+        return training_data
+
+    def get_statement_for_review(self, statement_id: str) -> Optional[Dict[str, Any]]:
+        """Get statement data for human review"""
+        statements = self._load_statements()
+        for statement in statements:
+            if statement["statement_id"] == statement_id:
+                return statement
+        return None
+
+    def get_pending_reviews(self, limit: int = 10) -> List[Dict[str, Any]]:
+        """Get statements that need human review"""
+        statements = self._load_statements()
+        feedback_data = self._load_feedback()
+
+        # Get statement IDs that already have feedback
+        reviewed_ids = {fb["statement_id"] for fb in feedback_data}
+
+        # Return statements without feedback
+        pending = [s for s in statements if s["statement_id"] not in reviewed_ids]
+        return pending[-limit:]  # Return most recent
+
+    def get_feedback_stats(self) -> Dict[str, Any]:
+        """Get statistics about collected feedback"""
+        feedback_data = self._load_feedback()
+        statements = self._load_statements()
+
+        if not feedback_data:
+            return {"total_feedback": 0, "total_statements": len(statements)}
+
+        # Calculate statistics
+        scores = [fb["overall_score"] for fb in feedback_data if fb.get("overall_score")]
+        audit_approvals = [fb["would_accept_for_audit"] for fb in feedback_data]
+
+        stats = {
+            "total_feedback": len(feedback_data),
+            "total_statements": len(statements),
+            "avg_overall_score": sum(scores) / len(scores) if scores else 0,
+            "audit_approval_rate": sum(audit_approvals) / len(audit_approvals) if audit_approvals else 0,
+            "feedback_by_type": {},
+            "recent_trend": self._calculate_trend()
+        }
+
+        # Group by statement type, keeping a running mean per type
+        for fb in feedback_data:
+            stmt_type = fb.get("statement_type", "unknown")
+            if stmt_type not in stats["feedback_by_type"]:
+                stats["feedback_by_type"][stmt_type] = {"count": 0, "avg_score": 0}
+            entry = stats["feedback_by_type"][stmt_type]
+            n = entry["count"]
+            entry["count"] = n + 1
+            entry["avg_score"] = (entry["avg_score"] * n + (fb.get("overall_score") or 0.0)) / (n + 1)
+
+        return stats
+
+    def _load_feedback(self) -> List[Dict[str, Any]]:
+        """Load feedback from storage"""
+        if os.path.exists(self.feedback_db):
+            try:
+                with open(self.feedback_db, "r") as f:
+                    return json.load(f)
+            except (json.JSONDecodeError, FileNotFoundError):
+                logger.warning("Could not load feedback database, starting fresh")
+        return []
+
+    def _load_statements(self) -> List[Dict[str, Any]]:
+        """Load statements from storage"""
+        if os.path.exists(self.statements_db):
+            try:
+                with open(self.statements_db, "r") as f:
+                    return json.load(f)
+            except (json.JSONDecodeError, FileNotFoundError):
+                logger.warning("Could not load statements database, starting fresh")
+        return []
+
+    def _compute_overall_score(self, feedback: Dict[str, Any]) -> float:
+        """Compute overall quality score from individual metrics"""
+        metrics = [
+            feedback.get("calculation_accuracy"),
+            feedback.get("account_classification"),
+            feedback.get("statement_balance"),
+            feedback.get("accounting_standards"),
+            feedback.get("regulatory_compliance"),
+            feedback.get("completeness"),
+            feedback.get("professional_presentation")
+        ]
+
+        # Filter out None values
+        valid_metrics = [m for m in metrics if m is not None]
+
+        if not valid_metrics:
+            return 0.0
+
+        return sum(valid_metrics) / len(valid_metrics)
+
+    def _calculate_trend(self) -> Dict[str, float]:
+        """Calculate recent feedback trend"""
+        feedback_data = self._load_feedback()
+
+        if len(feedback_data) < 5:
+            return {"trend": "insufficient_data"}
+
+        # Sort by timestamp
+        sorted_feedback = sorted(feedback_data, key=lambda x: x.get("timestamp", 0))
+
+        # Compare recent vs older feedback
+        mid_point = len(sorted_feedback) // 2
+        older_scores = [fb["overall_score"] for fb in sorted_feedback[:mid_point] if fb.get("overall_score")]
+        recent_scores = [fb["overall_score"] for fb in sorted_feedback[mid_point:] if fb.get("overall_score")]
+
+        if older_scores and recent_scores:
+            older_avg = sum(older_scores) / len(older_scores)
+            recent_avg = sum(recent_scores) / len(recent_scores)
+            improvement = recent_avg - older_avg
+
+            return {
+                "older_average": older_avg,
+                "recent_average": recent_avg,
+                "improvement": improvement,
+                "trend": "improving" if improvement > 0.1 else "stable" if abs(improvement) <= 0.1 else "declining"
+            }
+
+        return {"trend": "insufficient_data"}
agents/reward_model.py ADDED
@@ -0,0 +1,307 @@
+"""
+RLHF Reward Model for FinRyver
+Predicts quality scores for generated financial statements based on human feedback
+"""
+import json
+import os
+import logging
+import time
+from typing import Dict, Any, List, Optional, Tuple
+import numpy as np
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_squared_error, r2_score
+import joblib
+
+logger = logging.getLogger(__name__)
+
+class FinancialRewardModel:
+    """
+    Reward model that predicts quality scores for financial statements.
+    Uses traditional ML initially; can be upgraded to transformer-based models.
+    """
+
+    def __init__(self, model_dir: str = "data/models"):
+        self.model_dir = model_dir
+        self.model_path = os.path.join(model_dir, "reward_model.pkl")
+        self.feature_names_path = os.path.join(model_dir, "feature_names.json")
+        self.model_stats_path = os.path.join(model_dir, "model_stats.json")
+
+        os.makedirs(model_dir, exist_ok=True)
+
+        # Initialize model
+        self.model = RandomForestRegressor(
+            n_estimators=100,
+            max_depth=10,
+            random_state=42,
+            n_jobs=-1
+        )
+
+        self.feature_names = []
+        self.is_trained = False
+        self.model_version = "1.0"
+
+        # Load existing model if available
+        self._load_model()
+
+    def extract_features(self, statement_data: Dict[str, Any], statement_content: str = "") -> np.ndarray:
+        """Extract features from statement data for reward prediction"""
+        features = []
+
+        # Basic metadata features
+        features.append(len(statement_content))  # Content length
+        features.append(statement_data.get("generation_time", 0))  # Generation time
+        features.append(1 if statement_data.get("statement_type") == "notes" else 0)
+        features.append(1 if statement_data.get("statement_type") == "balance_sheet" else 0)
+        features.append(1 if statement_data.get("statement_type") == "pnl" else 0)
+        features.append(1 if statement_data.get("statement_type") == "cash_flow" else 0)
+
+        # Content-based features (simple heuristics)
+        if statement_content:
+            features.append(statement_content.count("$"))  # Number of monetary values
+            features.append(statement_content.count("\n"))  # Number of lines
+            features.append(len(statement_content.split()))  # Word count
+            features.append(statement_content.count("."))  # Number of sentences
+            features.append(statement_content.count(","))  # Number of commas (complexity indicator)
+
+            # Financial keywords
+            financial_keywords = ["asset", "liability", "equity", "revenue", "expense", "cash", "account"]
+            keyword_count = sum(statement_content.lower().count(keyword) for keyword in financial_keywords)
+            features.append(keyword_count)
+
+            # Professional language indicators
+            professional_words = ["accordance", "pursuant", "whereas", "therefore", "respective"]
+            professional_count = sum(statement_content.lower().count(word) for word in professional_words)
+            features.append(professional_count)
+        else:
+            # Default values if no content available
+            features.extend([0] * 7)
+
+        # File-based features (if available)
+        metadata = statement_data.get("metadata", {})
+        features.append(metadata.get("file_size", 0))
+        features.append(metadata.get("num_accounts", 0))
+        features.append(metadata.get("complexity_score", 0))
+
+        # Ensure we have consistent feature names
+        if not self.feature_names:
+            self.feature_names = [
+                "content_length", "generation_time", "is_notes", "is_balance_sheet",
+                "is_pnl", "is_cash_flow", "monetary_values", "line_count",
+                "word_count", "sentence_count", "comma_count", "financial_keywords",
+                "professional_words", "file_size", "num_accounts", "complexity_score"
+            ]
+
+        return np.array(features).reshape(1, -1)
+
+    def train_reward_model(self, training_data: List[Dict[str, Any]]) -> Dict[str, float]:
+        """Train reward model from human feedback data"""
+        if len(training_data) < 2:  # Lowered from 10 to 2 for testing
+            logger.warning(f"Insufficient training data: {len(training_data)} samples")
+            return {"error": "insufficient_data", "sample_count": len(training_data)}
+
+        # Prepare training data
+        X = []
+        y = []
+
+        for sample in training_data:
+            # Create dummy statement data for feature extraction
+            statement_data = {
+                "statement_type": sample.get("statement_type", "unknown"),
+                "generation_time": sample.get("generation_time", 0),
+                "metadata": sample.get("metadata", {})
+            }
+
+            # Extract features
+            features = self.extract_features(statement_data, "")
+            X.append(features.flatten())
+            y.append(sample["reward_score"])
+
+        X = np.array(X)
+        y = np.array(y)
+
+        # Split data
+        if len(X) > 20:
+            X_train, X_test, y_train, y_test = train_test_split(
+                X, y, test_size=0.2, random_state=42
+            )
+        else:
+            X_train, X_test, y_train, y_test = X, X, y, y
+
+        # Train model
+        logger.info(f"Training reward model with {len(X_train)} samples")
+        self.model.fit(X_train, y_train)
+
+        # Evaluate model
+        train_pred = self.model.predict(X_train)
+        test_pred = self.model.predict(X_test)
+
+        metrics = {
+            "train_mse": mean_squared_error(y_train, train_pred),
+            "test_mse": mean_squared_error(y_test, test_pred),
+            "train_r2": r2_score(y_train, train_pred),
+            "test_r2": r2_score(y_test, test_pred),
+            "sample_count": len(training_data),
+            "feature_importance": dict(zip(self.feature_names, self.model.feature_importances_))
+        }
+
+        self.is_trained = True
+
+        # Save model
+        self._save_model(metrics)
+
+        logger.info(f"Reward model trained. R2 score: {metrics['test_r2']:.3f}")
+        return metrics
+
+    def predict_reward(self, statement_data: Dict[str, Any], statement_content: str = "") -> float:
+        """Predict reward score for a generated financial statement"""
+        if not self.is_trained:
+            logger.warning("Reward model not trained, returning default score")
+            return 3.0  # Default neutral score
+
+        try:
+            features = self.extract_features(statement_data, statement_content)
+            reward = self.model.predict(features)[0]
+
+            # Clamp to valid range [1, 5]
+            reward = max(1.0, min(5.0, reward))
+
+            return float(reward)
+
+        except Exception as e:
+            logger.error(f"Error predicting reward: {e}")
+            return 3.0  # Default score on error
+
+    def predict_with_confidence(self, statement_data: Dict[str, Any], statement_content: str = "") -> Tuple[float, float]:
+        """Predict reward with confidence interval"""
+        if not self.is_trained:
+            return 3.0, 0.0
+
+        try:
+            features = self.extract_features(statement_data, statement_content)
+
+            # For Random Forest, we can get a prediction from every tree
+            tree_predictions = [tree.predict(features)[0] for tree in self.model.estimators_]
+
+            reward = np.mean(tree_predictions)
+            confidence = 1.0 / (1.0 + np.std(tree_predictions))  # Higher std = lower confidence
+
+            reward = max(1.0, min(5.0, reward))
+
+            return float(reward), float(confidence)
+
+        except Exception as e:
+            logger.error(f"Error predicting reward with confidence: {e}")
+            return 3.0, 0.0
+
+    def get_feature_importance(self) -> Dict[str, float]:
+        """Get feature importance from trained model"""
+        if not self.is_trained:
+            return {}
+
+        return dict(zip(self.feature_names, self.model.feature_importances_))
+
+    def get_model_stats(self) -> Dict[str, Any]:
+        """Get model training statistics"""
+        if os.path.exists(self.model_stats_path):
+            try:
+                with open(self.model_stats_path, "r") as f:
+                    return json.load(f)
+            except (json.JSONDecodeError, OSError):
+                pass
+        return {"status": "not_trained"}
+
+    def _save_model(self, training_stats: Dict[str, Any]):
+        """Save trained model and metadata"""
+        try:
+            # Save model
+            joblib.dump(self.model, self.model_path)
+
+            # Save feature names
+            with open(self.feature_names_path, "w") as f:
+                json.dump(self.feature_names, f)
+
+            # Save training stats
+            stats = {
+                "model_version": self.model_version,
+                "training_timestamp": time.time(),
+                "is_trained": True,
+                **training_stats
+            }
+
+            with open(self.model_stats_path, "w") as f:
+                json.dump(stats, f, indent=2)
+
+            logger.info("Reward model saved successfully")
+
+        except Exception as e:
+            logger.error(f"Error saving model: {e}")
+
+    def _load_model(self):
+        """Load existing trained model"""
+        try:
+            if os.path.exists(self.model_path) and os.path.exists(self.feature_names_path):
+                self.model = joblib.load(self.model_path)
+
+                with open(self.feature_names_path, "r") as f:
+                    self.feature_names = json.load(f)
+
+                self.is_trained = True
+                logger.info("Existing reward model loaded successfully")
+
+        except Exception as e:
+            logger.warning(f"Could not load existing model: {e}")
+            self.is_trained = False
+
+
+class RLHFTrainer:
+    """Coordinates the RLHF training pipeline"""
+
+    def __init__(self, feedback_manager, reward_model):
+        self.feedback_manager = feedback_manager
+        self.reward_model = reward_model
+        self.min_feedback_threshold = 2  # Lowered for testing (was 20)
+
+    def should_retrain(self) -> bool:
+        """Determine if the model should be retrained"""
+        stats = self.feedback_manager.get_feedback_stats()
+
+        # Check if we have enough new feedback
+        total_feedback = stats.get("total_feedback", 0)
+
+        # Get last training count
+        model_stats = self.reward_model.get_model_stats()
+        last_training_count = model_stats.get("sample_count", 0)
+
+        new_feedback_count = total_feedback - last_training_count
+
+        return (total_feedback >= self.min_feedback_threshold and
+                new_feedback_count >= 2)  # At least 2 new samples (was 10)
+
+    def retrain_model(self) -> Dict[str, Any]:
+        """Retrain the reward model with the latest feedback"""
+        training_data = self.feedback_manager.get_training_data()
+
+        if len(training_data) < self.min_feedback_threshold:
+            return {
+                "status": "insufficient_data",
+                "current_count": len(training_data),
+                "required_count": self.min_feedback_threshold
+            }
+
+        # Train model
+        metrics = self.reward_model.train_reward_model(training_data)
+
+        return {
+            "status": "success",
+            "training_metrics": metrics,
+            "timestamp": time.time()
+        }
+
+    def periodic_training_check(self) -> Dict[str, Any]:
+        """Check if retraining is needed and perform it if necessary"""
+        if self.should_retrain():
+            logger.info("Initiating automatic model retraining")
+            return self.retrain_model()
+        else:
+            return {"status": "no_retraining_needed"}
agents/rlhf_routes.py ADDED
@@ -0,0 +1,352 @@
1
+ """
2
+ RLHF Feedback Collection Routes for FinRyver
3
+ Handles human feedback collection for financial statement quality
4
+ """
5
+ from fastapi import APIRouter, HTTPException, Form, Query, Request
6
+ from fastapi.responses import JSONResponse, HTMLResponse
7
+ from typing import Optional, Dict, Any
8
+ import logging
9
+ from agents.feedback_manager import FeedbackManager
+ from agents.reward_model import FinancialRewardModel, RLHFTrainer
+ from agents.rlhf_workflows import get_rlhf_manager
+
+ logger = logging.getLogger(__name__)
+
+ # Create RLHF router
+ rlhf_router = APIRouter(prefix="/rlhf", tags=["RLHF Feedback"])
+
+ # Initialize components
+ feedback_manager = FeedbackManager()
+ reward_model = FinancialRewardModel()
+ trainer = RLHFTrainer(feedback_manager, reward_model)
+
+ @rlhf_router.post("/feedback")
+ async def collect_feedback(
+     statement_id: str = Form(...),
+     reviewer_id: str = Form("anonymous"),
+
+     # Technical accuracy metrics (1-5 scale)
+     calculation_accuracy: float = Form(..., ge=1, le=5),
+     account_classification: float = Form(..., ge=1, le=5),
+     statement_balance: float = Form(..., ge=1, le=5),
+
+     # Compliance metrics (1-5 scale)
+     accounting_standards: float = Form(..., ge=1, le=5),
+     regulatory_compliance: float = Form(..., ge=1, le=5),
+
+     # Quality metrics (1-5 scale)
+     completeness: float = Form(..., ge=1, le=5),
+     professional_presentation: float = Form(..., ge=1, le=5),
+
+     # Qualitative feedback
+     specific_errors: str = Form(""),
+     missing_items: str = Form(""),
+     improvement_suggestions: str = Form(""),
+
+     # Binary approval
+     would_accept_for_audit: bool = Form(False),
+
+     # Additional context
+     complexity_level: str = Form("medium")  # low, medium, high
+ ):
+     """
+     Collect detailed human feedback on generated financial statements.
+     This feedback is used to train and improve the AI models.
+     """
+     try:
+         # Get statement info
+         statement_info = feedback_manager.get_statement_for_review(statement_id)
+         if not statement_info:
+             raise HTTPException(status_code=404, detail="Statement not found")
+
+         # Prepare feedback data
+         feedback_data = {
+             "statement_id": statement_id,
+             "reviewer_id": reviewer_id,
+             "calculation_accuracy": calculation_accuracy,
+             "account_classification": account_classification,
+             "statement_balance": statement_balance,
+             "accounting_standards": accounting_standards,
+             "regulatory_compliance": regulatory_compliance,
+             "completeness": completeness,
+             "professional_presentation": professional_presentation,
+             "specific_errors": specific_errors,
+             "missing_items": missing_items,
+             "improvement_suggestions": improvement_suggestions,
+             "would_accept_for_audit": would_accept_for_audit,
+             "statement_type": statement_info.get("statement_type"),
+             "complexity_level": complexity_level
+         }
+
+         # Store feedback
+         feedback_id = feedback_manager.store_feedback(feedback_data)
+
+         # Check if the model should be retrained
+         retrain_result = trainer.periodic_training_check()
+
+         return {
+             "status": "success",
+             "feedback_id": feedback_id,
+             "message": "Feedback collected successfully",
+             "model_retrain_status": retrain_result.get("status"),
+             "overall_score": feedback_manager._compute_overall_score(feedback_data)
+         }
+
+     except HTTPException:
+         # Re-raise HTTP errors (e.g. the 404 above) instead of converting them to 500s
+         raise
+     except Exception as e:
+         logger.error(f"Error collecting feedback: {e}")
+         raise HTTPException(status_code=500, detail=f"Error collecting feedback: {str(e)}")
+
+ @rlhf_router.get("/review/{statement_id}")
+ async def get_review_interface(statement_id: str):
+     """
+     Get a review interface for human feedback collection.
+     Returns an HTML form for statement review.
+     """
+     try:
+         statement_info = feedback_manager.get_statement_for_review(statement_id)
+         if not statement_info:
+             raise HTTPException(status_code=404, detail="Statement not found")
+
+         # Generate HTML review form
+         html_content = generate_review_html(statement_id, statement_info)
+         return HTMLResponse(content=html_content)
+
+     except HTTPException:
+         # Re-raise HTTP errors (e.g. the 404 above) instead of converting them to 500s
+         raise
+     except Exception as e:
+         logger.error(f"Error getting review interface: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @rlhf_router.get("/pending-reviews")
+ async def get_pending_reviews(limit: int = Query(10, ge=1, le=50)):
+     """
+     Get statements that need human review.
+     """
+     try:
+         pending_statements = feedback_manager.get_pending_reviews(limit)
+         return {
+             "status": "success",
+             "pending_reviews": pending_statements,
+             "count": len(pending_statements)
+         }
+     except Exception as e:
+         logger.error(f"Error getting pending reviews: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @rlhf_router.get("/stats")
+ async def get_feedback_stats():
+     """
+     Get feedback and model training statistics.
+     """
+     try:
+         feedback_stats = feedback_manager.get_feedback_stats()
+         model_stats = reward_model.get_model_stats()
+         feature_importance = reward_model.get_feature_importance()
+
+         return {
+             "status": "success",
+             "feedback_stats": feedback_stats,
+             "model_stats": model_stats,
+             "feature_importance": feature_importance,
+             "model_trained": reward_model.is_trained
+         }
+     except Exception as e:
+         logger.error(f"Error getting stats: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @rlhf_router.post("/retrain")
+ async def manual_retrain():
+     """
+     Manually trigger model retraining.
+     """
+     try:
+         result = trainer.retrain_model()
+         return {
+             "status": "success",
+             "retrain_result": result
+         }
+     except Exception as e:
+         logger.error(f"Error during manual retrain: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @rlhf_router.get("/model-info")
+ async def get_model_info():
+     """
+     Get information about the current reward model.
+     """
+     try:
+         return {
+             "status": "success",
+             "model_trained": reward_model.is_trained,
+             "model_version": reward_model.model_version,
+             "feature_count": len(reward_model.feature_names),
+             "feature_names": reward_model.feature_names,
+             "model_stats": reward_model.get_model_stats()
+         }
+     except Exception as e:
+         logger.error(f"Error getting model info: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ def generate_review_html(statement_id: str, statement_info: Dict) -> str:
+     """Generate the HTML form for statement review"""
+     return f"""
+     <!DOCTYPE html>
+     <html>
+     <head>
+         <title>FinRyver - Statement Review</title>
+         <style>
+             body {{ font-family: Arial, sans-serif; margin: 40px; }}
+             .form-group {{ margin: 15px 0; }}
+             label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
+             input, select, textarea {{ width: 100%; padding: 8px; margin-bottom: 10px; }}
+             .rating {{ display: flex; gap: 10px; }}
+             .rating input {{ width: auto; }}
+             button {{ background-color: #007bff; color: white; padding: 10px 20px; border: none; cursor: pointer; }}
+             .statement-info {{ background-color: #f8f9fa; padding: 15px; margin-bottom: 20px; border-radius: 5px; }}
+         </style>
+     </head>
+     <body>
+         <h1>Financial Statement Review</h1>
+
+         <div class="statement-info">
+             <h3>Statement Information</h3>
+             <p><strong>Statement ID:</strong> {statement_id}</p>
+             <p><strong>Type:</strong> {statement_info.get('statement_type', 'Unknown')}</p>
+             <p><strong>Generated:</strong> {statement_info.get('timestamp', 'Unknown')}</p>
+             <p><strong>File:</strong> {statement_info.get('file_path', 'Unknown')}</p>
+         </div>
+
+         <form action="/rlhf/feedback" method="post">
+             <input type="hidden" name="statement_id" value="{statement_id}">
+
+             <div class="form-group">
+                 <label>Reviewer ID (optional):</label>
+                 <input type="text" name="reviewer_id" placeholder="Enter your identifier">
+             </div>
+
+             <h3>Technical Accuracy (1-5 scale)</h3>
+
+             <div class="form-group">
+                 <label>Calculation Accuracy:</label>
+                 <select name="calculation_accuracy" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Major calculation errors</option>
+                     <option value="2">2 - Some calculation errors</option>
+                     <option value="3">3 - Minor calculation issues</option>
+                     <option value="4">4 - Mostly accurate calculations</option>
+                     <option value="5">5 - All calculations correct</option>
+                 </select>
+             </div>
+
+             <div class="form-group">
+                 <label>Account Classification:</label>
+                 <select name="account_classification" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Major classification errors</option>
+                     <option value="2">2 - Some classification errors</option>
+                     <option value="3">3 - Minor classification issues</option>
+                     <option value="4">4 - Mostly correct classification</option>
+                     <option value="5">5 - Perfect classification</option>
+                 </select>
+             </div>
+
+             <div class="form-group">
+                 <label>Statement Balance/Reconciliation:</label>
+                 <select name="statement_balance" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Does not balance</option>
+                     <option value="2">2 - Major balance issues</option>
+                     <option value="3">3 - Minor balance issues</option>
+                     <option value="4">4 - Mostly balanced</option>
+                     <option value="5">5 - Perfect balance</option>
+                 </select>
+             </div>
+
+             <h3>Compliance &amp; Standards (1-5 scale)</h3>
+
+             <div class="form-group">
+                 <label>Accounting Standards Compliance:</label>
+                 <select name="accounting_standards" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Major compliance issues</option>
+                     <option value="2">2 - Some compliance issues</option>
+                     <option value="3">3 - Minor compliance issues</option>
+                     <option value="4">4 - Mostly compliant</option>
+                     <option value="5">5 - Fully compliant</option>
+                 </select>
+             </div>
+
+             <div class="form-group">
+                 <label>Regulatory Compliance:</label>
+                 <select name="regulatory_compliance" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Major regulatory issues</option>
+                     <option value="2">2 - Some regulatory issues</option>
+                     <option value="3">3 - Minor regulatory issues</option>
+                     <option value="4">4 - Mostly compliant</option>
+                     <option value="5">5 - Fully compliant</option>
+                 </select>
+             </div>
+
+             <h3>Quality &amp; Presentation (1-5 scale)</h3>
+
+             <div class="form-group">
+                 <label>Completeness:</label>
+                 <select name="completeness" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Major items missing</option>
+                     <option value="2">2 - Some items missing</option>
+                     <option value="3">3 - Minor items missing</option>
+                     <option value="4">4 - Mostly complete</option>
+                     <option value="5">5 - Complete</option>
+                 </select>
+             </div>
+
+             <div class="form-group">
+                 <label>Professional Presentation:</label>
+                 <select name="professional_presentation" required>
+                     <option value="">Select rating</option>
+                     <option value="1">1 - Unprofessional</option>
+                     <option value="2">2 - Below standard</option>
+                     <option value="3">3 - Adequate</option>
+                     <option value="4">4 - Good presentation</option>
+                     <option value="5">5 - Excellent presentation</option>
+                 </select>
+             </div>
+
+             <h3>Detailed Feedback</h3>
+
+             <div class="form-group">
+                 <label>Specific Errors (if any):</label>
+                 <textarea name="specific_errors" rows="3" placeholder="Describe any specific errors found..."></textarea>
+             </div>
+
+             <div class="form-group">
+                 <label>Missing Items (if any):</label>
+                 <textarea name="missing_items" rows="3" placeholder="List any missing items or information..."></textarea>
+             </div>
+
+             <div class="form-group">
+                 <label>Improvement Suggestions:</label>
+                 <textarea name="improvement_suggestions" rows="3" placeholder="Suggest improvements..."></textarea>
+             </div>
+
+             <div class="form-group">
+                 <label>Complexity Level:</label>
+                 <select name="complexity_level">
+                     <option value="low">Low</option>
+                     <option value="medium" selected>Medium</option>
+                     <option value="high">High</option>
+                 </select>
+             </div>
+
+             <div class="form-group">
+                 <label>
+                     <input type="checkbox" name="would_accept_for_audit" value="true">
+                     Would accept this statement for audit/compliance purposes
+                 </label>
+             </div>
+
+             <button type="submit">Submit Feedback</button>
+         </form>
+     </body>
+     </html>
+     """
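The `/feedback` endpoint above returns an `overall_score` computed by `feedback_manager._compute_overall_score`, whose implementation is not part of this commit. A minimal sketch of what such a helper could look like, assuming the score is simply the mean of the seven 1-5 ratings (the function name, key list, and equal weighting are assumptions, not the project's API):

```python
# Hypothetical sketch: the real FeedbackManager._compute_overall_score
# may weight metrics differently or fold in the audit-approval flag.
RATING_KEYS = [
    "calculation_accuracy",
    "account_classification",
    "statement_balance",
    "accounting_standards",
    "regulatory_compliance",
    "completeness",
    "professional_presentation",
]

def compute_overall_score(feedback_data: dict) -> float:
    """Average the seven 1-5 ratings into a single quality score."""
    ratings = [float(feedback_data[key]) for key in RATING_KEYS]
    return sum(ratings) / len(ratings)

feedback = {key: 4.0 for key in RATING_KEYS}
feedback["calculation_accuracy"] = 5.0
print(compute_overall_score(feedback))  # mean of one 5.0 and six 4.0s
```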
agents/rlhf_workflows.py ADDED
@@ -0,0 +1,267 @@
+ """
+ RLHF-Enhanced LangGraph Workflows for FinRyver
+ Integrates the reward model and feedback collection into the existing workflows.
+ """
+ from typing import TypedDict, Dict, Any, List, Annotated, Optional
+ import time
+ import uuid
+ import os
+ import logging
+ from langgraph.graph import StateGraph, END
+ from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
+
+ # Import existing tools and RLHF components
+ from agents.simple_tools import (
+     generate_notes_full_pipeline_from_path,
+     generate_balance_sheet,
+     generate_pnl_statement,
+     generate_cash_flow_statement,
+ )
+ from agents.feedback_manager import FeedbackManager
+ from agents.reward_model import FinancialRewardModel, RLHFTrainer
+
+ logger = logging.getLogger(__name__)
+
+ class RLHFFinancialAgentState(TypedDict):
+     """Enhanced state with RLHF capabilities"""
+     messages: Annotated[List[BaseMessage], "History"]
+     file_path: str
+     result: Dict[str, Any]
+     status: str
+     start_time: float
+     end_time: float
+     error: str
+
+     # RLHF-specific fields
+     statement_id: Optional[str]
+     predicted_quality: Optional[float]
+     confidence_score: Optional[float]
+     candidates_generated: Optional[List[Dict[str, Any]]]
+     best_candidate_index: Optional[int]
+     feedback_collected: Optional[bool]
+
+ class RLHFWorkflowManager:
+     """Manages RLHF-enhanced workflows"""
+
+     def __init__(self):
+         self.feedback_manager = FeedbackManager()
+         self.reward_model = FinancialRewardModel()
+         self.trainer = RLHFTrainer(self.feedback_manager, self.reward_model)
+
+         # Check for model retraining on initialization
+         self._check_and_retrain()
+
+     def _check_and_retrain(self):
+         """Check whether the model needs retraining"""
+         try:
+             result = self.trainer.periodic_training_check()
+             if result.get("status") == "success":
+                 logger.info("Reward model retrained successfully")
+         except Exception as e:
+             logger.error(f"Error during model retraining check: {e}")
+
+     def make_rlhf_workflow(self, tool_func, statement_type: str):
+         """Create an RLHF-enhanced workflow"""
+
+         def rlhf_node(state: RLHFFinancialAgentState) -> RLHFFinancialAgentState:
+             state["start_time"] = time.time()
+             state["statement_id"] = str(uuid.uuid4())
+
+             try:
+                 # Generate multiple candidates if the reward model is trained
+                 if self.reward_model.is_trained:
+                     candidates = self._generate_candidates(tool_func, state, num_candidates=3)
+                     state["candidates_generated"] = candidates
+
+                     # Select the best candidate using the reward model
+                     best_candidate, best_index = self._select_best_candidate(
+                         candidates, statement_type, state["file_path"]
+                     )
+
+                     state["result"] = best_candidate
+                     state["best_candidate_index"] = best_index
+
+                 else:
+                     # Single generation if no trained model
+                     result = tool_func.invoke({"file_path": state["file_path"]})
+                     state["result"] = result
+                     state["candidates_generated"] = [result]
+                     state["best_candidate_index"] = 0
+
+                 # Predict quality score
+                 if state["result"].get("status") == "success":
+                     predicted_quality, confidence = self._predict_quality(
+                         state["result"], statement_type, state["file_path"]
+                     )
+                     state["predicted_quality"] = predicted_quality
+                     state["confidence_score"] = confidence
+                     state["status"] = "success"
+
+                     # Store statement for potential feedback
+                     self._store_for_feedback(state, statement_type)
+
+                 else:
+                     state["status"] = "error"
+                     state["error"] = state["result"].get("error", "Unknown error")
+
+             except Exception as e:
+                 state["status"] = "error"
+                 state["error"] = str(e)
+                 logger.error(f"Error in RLHF workflow: {e}")
+
+             state["end_time"] = time.time()
+             return state
+
+         # Create workflow graph
+         wf = StateGraph(RLHFFinancialAgentState)
+         wf.add_node("rlhf_run", rlhf_node)
+         wf.set_entry_point("rlhf_run")
+         wf.add_edge("rlhf_run", END)
+         return wf.compile()
+
+     def _generate_candidates(self, tool_func, state: RLHFFinancialAgentState, num_candidates: int = 3) -> List[Dict[str, Any]]:
+         """Generate multiple candidates for comparison"""
+         candidates = []
+
+         for i in range(num_candidates):
+             try:
+                 result = tool_func.invoke({"file_path": state["file_path"]})
+                 candidates.append({
+                     "index": i,
+                     "result": result,
+                     "timestamp": time.time()
+                 })
+             except Exception as e:
+                 logger.warning(f"Failed to generate candidate {i}: {e}")
+                 candidates.append({
+                     "index": i,
+                     "result": {"status": "error", "error": str(e)},
+                     "timestamp": time.time()
+                 })
+
+         return candidates
+
+     def _select_best_candidate(self, candidates: List[Dict[str, Any]], statement_type: str, file_path: str) -> tuple:
+         """Select the best candidate using the reward model"""
+         best_candidate = None
+         best_score = -1
+         best_index = 0
+
+         for candidate in candidates:
+             if candidate["result"].get("status") == "success":
+                 # Create statement data for reward prediction
+                 statement_data = {
+                     "statement_type": statement_type,
+                     "file_path": file_path,
+                     "generation_time": 0,  # Could be calculated from timestamps
+                     "metadata": {}
+                 }
+
+                 # Predict reward
+                 predicted_reward, confidence = self.reward_model.predict_with_confidence(
+                     statement_data, ""
+                 )
+
+                 # Weight by confidence
+                 weighted_score = predicted_reward * confidence
+
+                 if weighted_score > best_score:
+                     best_score = weighted_score
+                     best_candidate = candidate["result"]
+                     best_index = candidate["index"]
+
+         # Fall back to the first successful candidate
+         if best_candidate is None:
+             for candidate in candidates:
+                 if candidate["result"].get("status") == "success":
+                     best_candidate = candidate["result"]
+                     best_index = candidate["index"]
+                     break
+
+         # Final fallback
+         if best_candidate is None and candidates:
+             best_candidate = candidates[0]["result"]
+             best_index = 0
+
+         return best_candidate, best_index
+
+     def _predict_quality(self, result: Dict[str, Any], statement_type: str, file_path: str) -> tuple:
+         """Predict a quality score for the generated statement"""
+         statement_data = {
+             "statement_type": statement_type,
+             "file_path": file_path,
+             "generation_time": 0,
+             "metadata": {}
+         }
+
+         return self.reward_model.predict_with_confidence(statement_data, "")
+
+     def _store_for_feedback(self, state: RLHFFinancialAgentState, statement_type: str):
+         """Store the generated statement for feedback collection"""
+         try:
+             statement_data = {
+                 "type": statement_type,
+                 "file_path": state["file_path"],
+                 "output_path": state["result"].get("output_path"),
+                 # end_time is not set yet at this point, so measure from start_time directly
+                 "generation_time": time.time() - state["start_time"],
+                 "predicted_quality": state.get("predicted_quality"),
+                 "confidence_score": state.get("confidence_score"),
+                 "metadata": {
+                     "candidates_count": len(state.get("candidates_generated", [])),
+                     "best_candidate_index": state.get("best_candidate_index"),
+                     "workflow_version": "rlhf_v1"
+                 }
+             }
+
+             stored_id = self.feedback_manager.store_generated_statement(statement_data)
+             state["statement_id"] = stored_id
+
+         except Exception as e:
+             logger.error(f"Error storing statement for feedback: {e}")
+
+ # Global RLHF manager instance
+ rlhf_manager = RLHFWorkflowManager()
+
+ # RLHF-enhanced workflows
+ rlhf_workflows = {
+     "notes": rlhf_manager.make_rlhf_workflow(generate_notes_full_pipeline_from_path, "notes"),
+     "pnl": rlhf_manager.make_rlhf_workflow(generate_pnl_statement, "pnl"),
+     "bs": rlhf_manager.make_rlhf_workflow(generate_balance_sheet, "balance_sheet"),
+     "cf": rlhf_manager.make_rlhf_workflow(generate_cash_flow_statement, "cash_flow"),
+ }
+
+ def run_rlhf_workflow(file_path: str, kind: str) -> Dict[str, Any]:
+     """Run an RLHF-enhanced workflow"""
+     state = RLHFFinancialAgentState(
+         messages=[HumanMessage(content=f"Run RLHF {kind} for {file_path}")],
+         file_path=file_path,
+         result={},
+         status="",
+         start_time=0,
+         end_time=0,
+         error="",
+         statement_id=None,
+         predicted_quality=None,
+         confidence_score=None,
+         candidates_generated=None,
+         best_candidate_index=None,
+         feedback_collected=False
+     )
+
+     final_state = rlhf_workflows[kind].invoke(state)
+
+     # Add RLHF metadata to the result
+     if final_state["status"] == "success":
+         final_state["result"]["rlhf_metadata"] = {
+             "statement_id": final_state.get("statement_id"),
+             "predicted_quality": final_state.get("predicted_quality"),
+             "confidence_score": final_state.get("confidence_score"),
+             "candidates_generated": len(final_state.get("candidates_generated", [])),
+             "model_used": "rlhf_enhanced"
+         }
+
+     return final_state
+
+ def get_rlhf_manager() -> RLHFWorkflowManager:
+     """Get the global RLHF manager instance"""
+     return rlhf_manager
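The best-of-n selection in `_select_best_candidate` above reduces to ranking candidates by confidence-weighted predicted reward, with a fallback to the first successful candidate. A self-contained sketch of that ranking rule (the `reward` and `confidence` fields here are illustrative stand-ins for calls to `reward_model.predict_with_confidence`):

```python
def select_best(candidates):
    """Pick the candidate with the highest reward * confidence;
    fall back to the first successful one, then to candidates[0]."""
    best, best_score, best_index = None, float("-inf"), 0
    for cand in candidates:
        if cand["result"].get("status") != "success":
            continue  # only successful generations are scored
        weighted = cand["reward"] * cand["confidence"]  # confidence-weighted score
        if weighted > best_score:
            best, best_score, best_index = cand["result"], weighted, cand["index"]
    if best is None:  # no scored winner: first successful candidate
        for cand in candidates:
            if cand["result"].get("status") == "success":
                return cand["result"], cand["index"]
    if best is None and candidates:  # final fallback
        return candidates[0]["result"], 0
    return best, best_index

candidates = [
    {"index": 0, "result": {"status": "success", "id": "a"}, "reward": 4.0, "confidence": 0.5},
    {"index": 1, "result": {"status": "success", "id": "b"}, "reward": 3.5, "confidence": 0.9},
    {"index": 2, "result": {"status": "error"}},
]
best, idx = select_best(candidates)
print(idx)  # candidate 1 wins: 3.5 * 0.9 = 3.15 beats 4.0 * 0.5 = 2.0
```

Weighting by confidence means a mediocre prediction the model is sure about can outrank a high prediction it is unsure about, which is the trade-off the workflow code above makes as well.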
app.py CHANGED
@@ -1,9 +1,11 @@
- from fastapi import FastAPI, APIRouter, UploadFile, File, HTTPException
+ from fastapi import FastAPI, APIRouter, UploadFile, File, HTTPException, Query
  from fastapi.responses import FileResponse
  import os
  import shutil
  import logging
  from agents.langgraph import run_workflow
+ from agents.rlhf_workflows import run_rlhf_workflow
+ from agents.rlhf_routes import rlhf_router

  # Configure logging for the application
  logging.basicConfig(level=logging.INFO)
@@ -13,9 +15,12 @@ logger = logging.getLogger("financial_notes_api")

  app = FastAPI(
      title="Financial Notes Generator API",
-     description="API for generating financial notes, balance sheets, cash flow statements, and P&L reports.",
+     description="API for generating financial notes, balance sheets, cash flow statements, and P&L reports with RLHF capabilities.",
      version="1.0.0"
  )
+
+ # Include RLHF routes
+ app.include_router(rlhf_router)
  @app.on_event("startup")
  async def startup_event():
      logger.info("Financial Notes Generator API has started.")
@@ -81,34 +86,70 @@ async def llm_generate_and_excel(


  @router.post("/notes")
- async def notes_route(file: UploadFile = File(...)):
+ async def notes_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
      file_path = f"data/input/{file.filename}"
      os.makedirs("data/input", exist_ok=True)
      with open(file_path, "wb") as buffer:
          shutil.copyfileobj(file.file, buffer)
-     result = run_workflow(file_path, "notes")
+
+     # Choose workflow based on RLHF preference
+     if use_rlhf:
+         result = run_rlhf_workflow(file_path, "notes")
+     else:
+         result = run_workflow(file_path, "notes")
+
      if result["status"] == "success":
-         return FileResponse(result["result"]["output_xlsx_path"], filename=os.path.basename(result["result"]["output_xlsx_path"]))
+         response = FileResponse(result["result"]["output_xlsx_path"], filename=os.path.basename(result["result"]["output_xlsx_path"]))
+
+         # Add RLHF metadata to headers if available
+         if "rlhf_metadata" in result.get("result", {}):
+             rlhf_data = result["result"]["rlhf_metadata"]
+             response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
+             response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
+             response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
+
+         return response
      raise HTTPException(status_code=500, detail=result["error"])

  @router.post("/pnl")
- async def pnl_route(file: UploadFile = File(...)):
+ async def pnl_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
      file_path = f"data/input/{file.filename}"
      os.makedirs("data/input", exist_ok=True)
      with open(file_path, "wb") as buffer:
          shutil.copyfileobj(file.file, buffer)
-     result = run_workflow(file_path, "pnl")
+
+     # Choose workflow based on RLHF preference
+     if use_rlhf:
+         result = run_rlhf_workflow(file_path, "pnl")
+     else:
+         result = run_workflow(file_path, "pnl")
+
      if result["status"] == "success":
-         return FileResponse(result["result"].get("output_path", "data/pnl_statement.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/pnl_statement.xlsx")))
+         response = FileResponse(result["result"].get("output_path", "data/pnl_statement.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/pnl_statement.xlsx")))
+
+         # Add RLHF metadata to headers if available
+         if "rlhf_metadata" in result.get("result", {}):
+             rlhf_data = result["result"]["rlhf_metadata"]
+             response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
+             response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
+             response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
+
+         return response
      raise HTTPException(status_code=500, detail=result["error"])

  @router.post("/bs")
- async def bs_route(file: UploadFile = File(...)):
+ async def bs_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
      file_path = f"data/input/{file.filename}"
      os.makedirs("data/input", exist_ok=True)
      with open(file_path, "wb") as buffer:
          shutil.copyfileobj(file.file, buffer)
-     result = run_workflow(file_path, "bs")
+
+     # Choose workflow based on RLHF preference
+     if use_rlhf:
+         result = run_rlhf_workflow(file_path, "bs")
+     else:
+         result = run_workflow(file_path, "bs")
+
      if result["status"] == "success":
          # Use first xlsx file in output dir if present
          output_file = result["result"].get("output_path")
@@ -120,19 +161,44 @@ async def bs_route(file: UploadFile = File(...)):
              output_file = os.path.join(output_dir, xlsx_files[0])
          else:
              raise HTTPException(status_code=500, detail="No balance sheet Excel file produced")
-         return FileResponse(output_file, filename=os.path.basename(output_file))
+
+         response = FileResponse(output_file, filename=os.path.basename(output_file))
+
+         # Add RLHF metadata to headers if available
+         if "rlhf_metadata" in result.get("result", {}):
+             rlhf_data = result["result"]["rlhf_metadata"]
+             response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
+             response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
+             response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
+
+         return response
      else:
          raise HTTPException(status_code=500, detail=result["error"])

  @router.post("/cf")
- async def cf_route(file: UploadFile = File(...)):
+ async def cf_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
      file_path = f"data/input/{file.filename}"
      os.makedirs("data/input", exist_ok=True)
      with open(file_path, "wb") as buffer:
          shutil.copyfileobj(file.file, buffer)
-     result = run_workflow(file_path, "cf")
+
+     # Choose workflow based on RLHF preference
+     if use_rlhf:
+         result = run_rlhf_workflow(file_path, "cf")
+     else:
+         result = run_workflow(file_path, "cf")
+
      if result["status"] == "success":
-         return FileResponse(result["result"].get("output_path", "data/cash_flow_statements.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/cash_flow_statements.xlsx")))
+         response = FileResponse(result["result"].get("output_path", "data/cash_flow_statements.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/cash_flow_statements.xlsx")))
+
+         # Add RLHF metadata to headers if available
+         if "rlhf_metadata" in result.get("result", {}):
+             rlhf_data = result["result"]["rlhf_metadata"]
+             response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
+             response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
+             response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
+
+         return response
      raise HTTPException(status_code=500, detail=result["error"])
  app.include_router(router)
requirements.txt CHANGED
@@ -13,4 +13,11 @@ langchain
  langchain-openai
  langchain-community
  langchain-core
+
+ #langgraph
  langgraph
+
+ # RLHF dependencies
+ scikit-learn
+ numpy
+ joblib
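The new scikit-learn/joblib dependencies support the automatic retraining path (`trainer.periodic_training_check`, invoked after every feedback submission and at workflow-manager startup). That check amounts to retraining once enough new feedback has accumulated; a minimal sketch of the trigger logic, with an assumed threshold of 10 new samples (the threshold, class name, and return shape are assumptions, not the project's API):

```python
class PeriodicTrainingCheck:
    """Retrain only when at least `threshold` new feedback items
    have arrived since the last training run."""

    def __init__(self, threshold: int = 10):
        self.threshold = threshold
        self.trained_on = 0  # feedback count at the last retrain

    def check(self, total_feedback: int) -> dict:
        new_samples = total_feedback - self.trained_on
        if new_samples < self.threshold:
            return {"status": "skipped", "new_samples": new_samples}
        # A real implementation would refit the reward model here
        # and persist it (e.g. via joblib) before updating the counter.
        self.trained_on = total_feedback
        return {"status": "success", "trained_on": total_feedback}

checker = PeriodicTrainingCheck(threshold=10)
print(checker.check(4)["status"])   # not enough feedback yet
print(checker.check(12)["status"])  # 12 new samples: retrain
print(checker.check(15)["status"])  # only 3 new since the last retrain
```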