Sahil Garg committed
Commit c172f37 · 1 Parent(s): 41cb3f5

initial RLHF applied
Browse files:
- .gitignore +8 -1
- RLHF_GUIDE.md +167 -0
- agents/feedback_manager.py +248 -0
- agents/reward_model.py +307 -0
- agents/rlhf_routes.py +352 -0
- agents/rlhf_workflows.py +267 -0
- app.py +80 -14
- requirements.txt +7 -0
.gitignore CHANGED

@@ -20,4 +20,11 @@ app/__pycache__/
 pnlbs/__pycache__/
 AGENT_GUIDE.md
 docker-compose.dev.yml
-file_cleanup.py
+file_cleanup.py
+agents/langgraph_routes.py
+
+# RLHF related data
+data/feedback/
+data/models/
+*.pkl
+*.joblib
RLHF_GUIDE.md ADDED

@@ -0,0 +1,167 @@
# RLHF (Reinforcement Learning from Human Feedback) Features

## Overview

FinRyver now includes RLHF capabilities that allow the system to learn from human feedback and improve the quality of generated financial statements over time.

## Key Components

### 1. **Enhanced Workflows**
- RLHF-enhanced versions of all financial statement generation workflows
- Multiple-candidate generation and selection using a reward model
- Quality prediction and confidence scoring

### 2. **Feedback Collection System**
- Web-based review interface for human feedback
- Structured feedback forms with technical and quality metrics
- Storage and management of feedback data

### 3. **Reward Model**
- Machine learning model that predicts statement quality
- Trained on human feedback data
- Automatic retraining when sufficient new feedback is available

## Usage

### Basic Financial Statement Generation

**Standard workflow (existing functionality):**
```bash
curl -X POST "http://localhost:8000/notes" \
  -F "file=@trial_balance.xlsx"
```

**RLHF-enhanced workflow:**
```bash
curl -X POST "http://localhost:8000/notes?use_rlhf=true" \
  -F "file=@trial_balance.xlsx"
```

The RLHF-enhanced workflow will:
1. Generate multiple candidates (if the reward model is trained)
2. Use the reward model to select the best candidate
3. Provide quality predictions and confidence scores
4. Store the result for potential human feedback
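Steps 1 and 2 reduce to best-of-n sampling: score every candidate with the reward model and keep the argmax. A minimal sketch — the helper name is ours, and `len` below is only a toy stand-in for the real reward function (`FinancialRewardModel.predict_reward`):

```python
from typing import Callable, List

def select_best_candidate(candidates: List[str],
                          reward_fn: Callable[[str], float]) -> str:
    """Best-of-n selection: return the candidate with the highest predicted reward."""
    return max(candidates, key=reward_fn)

# Toy reward function for illustration only: longer drafts score higher.
drafts = ["short draft", "a considerably more detailed draft"]
best = select_best_candidate(drafts, reward_fn=len)
```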
### Response Headers

When using RLHF workflows, additional metadata is included in response headers:

- `X-RLHF-Statement-ID`: Unique ID for the generated statement
- `X-RLHF-Quality-Score`: Predicted quality score (1-5)
- `X-RLHF-Confidence`: Model confidence in the prediction
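A client can pick this metadata out of the headers without touching the response body. A minimal sketch, where the helper name and example values are ours, not part of the API:

```python
from typing import Any, Dict, Optional

def parse_rlhf_headers(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
    """Extract RLHF metadata from response headers; None for non-RLHF responses."""
    if "X-RLHF-Statement-ID" not in headers:
        return None  # standard workflow: no RLHF headers present
    return {
        "statement_id": headers["X-RLHF-Statement-ID"],
        "quality_score": float(headers["X-RLHF-Quality-Score"]),  # 1-5 scale
        "confidence": float(headers["X-RLHF-Confidence"]),
    }
```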
### Feedback Collection

#### 1. Get Statements Needing Review
```bash
curl "http://localhost:8000/rlhf/pending-reviews"
```

#### 2. Review Interface
Visit: `http://localhost:8000/rlhf/review/{statement_id}`

This provides an HTML form for structured feedback collection.

#### 3. Submit Feedback Programmatically
```bash
curl -X POST "http://localhost:8000/rlhf/feedback" \
  -F "statement_id=123e4567-e89b-12d3-a456-426614174000" \
  -F "calculation_accuracy=4" \
  -F "account_classification=5" \
  -F "statement_balance=4" \
  -F "accounting_standards=4" \
  -F "regulatory_compliance=5" \
  -F "completeness=3" \
  -F "professional_presentation=4" \
  -F "would_accept_for_audit=true" \
  -F "specific_errors=Minor formatting issues" \
  -F "improvement_suggestions=Add more detailed notes"
```
### Monitoring and Statistics

#### Get Feedback Statistics
```bash
curl "http://localhost:8000/rlhf/stats"
```

Returns:
- Total feedback collected
- Average quality scores
- Audit approval rates
- Model training status
- Feature importance
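The stats payload also carries a recent-trend label. The classification thresholds, mirroring `_calculate_trend` in `agents/feedback_manager.py`, can be summarized as:

```python
def classify_trend(older_avg: float, recent_avg: float) -> str:
    """Label score movement between the older and newer halves of the feedback."""
    improvement = recent_avg - older_avg
    if improvement > 0.1:
        return "improving"
    if abs(improvement) <= 0.1:
        return "stable"
    return "declining"
```

Movements of 0.1 points or less on the 1-5 scale are treated as noise and reported as stable.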
#### Get Model Information
```bash
curl "http://localhost:8000/rlhf/model-info"
```

#### Manual Model Retraining
```bash
curl -X POST "http://localhost:8000/rlhf/retrain"
```
## Feedback Metrics

### Technical Accuracy (1-5 scale)
- **Calculation Accuracy**: Mathematical correctness
- **Account Classification**: Proper categorization of accounts
- **Statement Balance**: Internal consistency and reconciliation

### Compliance (1-5 scale)
- **Accounting Standards**: GAAP/IFRS compliance
- **Regulatory Compliance**: Meeting regulatory requirements

### Quality (1-5 scale)
- **Completeness**: All necessary items included
- **Professional Presentation**: Formatting and language quality

### Qualitative Feedback
- **Specific Errors**: Detailed error descriptions
- **Missing Items**: Items that should be included
- **Improvement Suggestions**: Recommendations for enhancement
- **Audit Acceptance**: Binary approval for professional use
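The overall score stored with each review is simply the mean of whichever 1-5 metrics the reviewer supplied, as implemented by `_compute_overall_score` in `agents/feedback_manager.py`:

```python
METRIC_KEYS = [
    "calculation_accuracy", "account_classification", "statement_balance",
    "accounting_standards", "regulatory_compliance",
    "completeness", "professional_presentation",
]

def overall_score(feedback: dict) -> float:
    """Average of the supplied 1-5 metrics; 0.0 when none were given."""
    values = [feedback[k] for k in METRIC_KEYS if feedback.get(k) is not None]
    return sum(values) / len(values) if values else 0.0
```

For the example feedback submitted above (4, 5, 4, 4, 5, 3, 4) this yields about 4.14.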
## Training Process

1. **Initial Phase**: System operates with default models
2. **Feedback Collection**: Human experts review generated statements
3. **Model Training**: When 20+ feedback samples are available, the reward model is trained
4. **Enhanced Generation**: RLHF workflows use the trained model for better results
5. **Continuous Learning**: Model retrains automatically with new feedback
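The retraining trigger in steps 3 and 5 amounts to a simple threshold check. A sketch under stated assumptions: the helper name and the `feedback_at_last_training` bookkeeping are illustrative, and the 20-sample threshold comes from this guide (the code lowers its own minimum for testing):

```python
MIN_FEEDBACK_SAMPLES = 20  # threshold stated in this guide

def should_retrain(total_feedback: int, feedback_at_last_training: int) -> bool:
    """Retrain once enough feedback exists and new samples arrived since the last run."""
    return (total_feedback >= MIN_FEEDBACK_SAMPLES
            and total_feedback > feedback_at_last_training)
```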
## Benefits

- **Quality Improvement**: Statements become more accurate over time
- **Domain Adaptation**: System learns specific requirements and preferences
- **Consistency**: Reduces variability in output quality
- **Professional Standards**: Aligns with human expert expectations

## Implementation Notes

- RLHF features are optional and backward-compatible
- Existing workflows continue to work unchanged
- Feedback data is stored locally and can be exported for analysis
- Models can be backed up and restored
- Multiple reward models can be maintained for different statement types

## File Structure

```
data/
├── feedback/
│   ├── human_feedback.json          # Collected feedback data
│   └── generated_statements.json    # Statement metadata
└── models/
    ├── reward_model.pkl             # Trained reward model
    ├── feature_names.json           # Model feature definitions
    └── model_stats.json             # Training statistics
```

## Security and Privacy

- Feedback data is stored locally
- No external transmission of financial data
- Anonymous feedback collection supported
- Data can be cleaned/anonymized before training
agents/feedback_manager.py ADDED

@@ -0,0 +1,248 @@
"""
RLHF Feedback Management System for FinRyver
Handles collection, storage, and management of human feedback on financial statements
"""
import json
import os
import time
import uuid
from typing import Dict, Any, List, Optional
import logging

logger = logging.getLogger(__name__)

class FeedbackManager:
    """Manages human feedback collection for RLHF training"""

    def __init__(self, feedback_dir: str = "data/feedback"):
        self.feedback_dir = feedback_dir
        self.feedback_db = os.path.join(feedback_dir, "human_feedback.json")
        self.statements_db = os.path.join(feedback_dir, "generated_statements.json")
        os.makedirs(feedback_dir, exist_ok=True)

    def store_generated_statement(self, statement_data: Dict[str, Any]) -> str:
        """Store generated statement for later feedback collection"""
        statement_id = str(uuid.uuid4())
        statement_record = {
            "statement_id": statement_id,
            "timestamp": time.time(),
            "statement_type": statement_data.get("type", "unknown"),
            "file_path": statement_data.get("file_path"),
            "output_path": statement_data.get("output_path"),
            "generation_time": statement_data.get("generation_time", 0),
            "metadata": statement_data.get("metadata", {})
        }

        # Load existing statements
        statements = self._load_statements()
        statements.append(statement_record)

        # Save updated statements
        with open(self.statements_db, "w") as f:
            json.dump(statements, f, indent=2)

        logger.info(f"Stored statement {statement_id} for feedback collection")
        return statement_id

    def store_feedback(self, feedback: Dict[str, Any]) -> str:
        """Store human feedback for RLHF training"""
        feedback_id = str(uuid.uuid4())
        feedback_record = {
            "feedback_id": feedback_id,
            "statement_id": feedback.get("statement_id"),
            "timestamp": time.time(),
            "reviewer_id": feedback.get("reviewer_id", "anonymous"),

            # Technical accuracy metrics
            "calculation_accuracy": feedback.get("calculation_accuracy"),
            "account_classification": feedback.get("account_classification"),
            "statement_balance": feedback.get("statement_balance"),

            # Compliance metrics
            "accounting_standards": feedback.get("accounting_standards"),
            "regulatory_compliance": feedback.get("regulatory_compliance"),

            # Quality metrics
            "completeness": feedback.get("completeness"),
            "professional_presentation": feedback.get("professional_presentation"),

            # Overall quality score (computed)
            "overall_score": self._compute_overall_score(feedback),

            # Qualitative feedback
            "specific_errors": feedback.get("specific_errors", ""),
            "missing_items": feedback.get("missing_items", ""),
            "improvement_suggestions": feedback.get("improvement_suggestions", ""),
            "would_accept_for_audit": feedback.get("would_accept_for_audit", False),

            # Additional context
            "statement_type": feedback.get("statement_type"),
            "complexity_level": feedback.get("complexity_level", "medium")
        }

        # Load existing feedback
        all_feedback = self._load_feedback()
        all_feedback.append(feedback_record)

        # Save updated feedback
        with open(self.feedback_db, "w") as f:
            json.dump(all_feedback, f, indent=2)

        logger.info(f"Stored feedback {feedback_id} for statement {feedback.get('statement_id')}")
        return feedback_id

    def get_training_data(self, min_feedback_count: int = 2) -> List[Dict[str, Any]]:
        """Get feedback data suitable for RLHF training"""
        feedback_data = self._load_feedback()

        if len(feedback_data) < min_feedback_count:
            logger.warning(f"Only {len(feedback_data)} feedback samples available, need at least {min_feedback_count}")
            return []

        # Filter and prepare training data
        training_data = []
        for feedback in feedback_data:
            if feedback.get("overall_score") is not None:
                training_sample = {
                    "statement_id": feedback["statement_id"],
                    "statement_type": feedback["statement_type"],
                    "reward_score": feedback["overall_score"],
                    "binary_approval": feedback["would_accept_for_audit"],
                    "technical_metrics": {
                        "calculation_accuracy": feedback.get("calculation_accuracy"),
                        "account_classification": feedback.get("account_classification"),
                        "statement_balance": feedback.get("statement_balance")
                    },
                    "quality_metrics": {
                        "completeness": feedback.get("completeness"),
                        "professional_presentation": feedback.get("professional_presentation"),
                        "accounting_standards": feedback.get("accounting_standards")
                    },
                    "feedback_text": {
                        "errors": feedback.get("specific_errors", ""),
                        "missing": feedback.get("missing_items", ""),
                        "suggestions": feedback.get("improvement_suggestions", "")
                    }
                }
                training_data.append(training_sample)

        return training_data

    def get_statement_for_review(self, statement_id: str) -> Optional[Dict[str, Any]]:
        """Get statement data for human review"""
        statements = self._load_statements()
        for statement in statements:
            if statement["statement_id"] == statement_id:
                return statement
        return None

    def get_pending_reviews(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get statements that need human review"""
        statements = self._load_statements()
        feedback_data = self._load_feedback()

        # Get statement IDs that already have feedback
        reviewed_ids = {fb["statement_id"] for fb in feedback_data}

        # Return statements without feedback
        pending = [s for s in statements if s["statement_id"] not in reviewed_ids]
        return pending[-limit:]  # Return most recent

    def get_feedback_stats(self) -> Dict[str, Any]:
        """Get statistics about collected feedback"""
        feedback_data = self._load_feedback()
        statements = self._load_statements()

        if not feedback_data:
            return {"total_feedback": 0, "total_statements": len(statements)}

        # Calculate statistics
        scores = [fb["overall_score"] for fb in feedback_data if fb.get("overall_score")]
        audit_approvals = [fb["would_accept_for_audit"] for fb in feedback_data]

        stats = {
            "total_feedback": len(feedback_data),
            "total_statements": len(statements),
            "avg_overall_score": sum(scores) / len(scores) if scores else 0,
            "audit_approval_rate": sum(audit_approvals) / len(audit_approvals) if audit_approvals else 0,
            "feedback_by_type": {},
            "recent_trend": self._calculate_trend()
        }

        # Group by statement type
        for fb in feedback_data:
            stmt_type = fb.get("statement_type", "unknown")
            if stmt_type not in stats["feedback_by_type"]:
                stats["feedback_by_type"][stmt_type] = {"count": 0, "avg_score": 0}
            stats["feedback_by_type"][stmt_type]["count"] += 1

        return stats

    def _load_feedback(self) -> List[Dict[str, Any]]:
        """Load feedback from storage"""
        if os.path.exists(self.feedback_db):
            try:
                with open(self.feedback_db, "r") as f:
                    return json.load(f)
            except (json.JSONDecodeError, FileNotFoundError):
                logger.warning("Could not load feedback database, starting fresh")
        return []

    def _load_statements(self) -> List[Dict[str, Any]]:
        """Load statements from storage"""
        if os.path.exists(self.statements_db):
            try:
                with open(self.statements_db, "r") as f:
                    return json.load(f)
            except (json.JSONDecodeError, FileNotFoundError):
                logger.warning("Could not load statements database, starting fresh")
        return []

    def _compute_overall_score(self, feedback: Dict[str, Any]) -> float:
        """Compute overall quality score from individual metrics"""
        metrics = [
            feedback.get("calculation_accuracy"),
            feedback.get("account_classification"),
            feedback.get("statement_balance"),
            feedback.get("accounting_standards"),
            feedback.get("regulatory_compliance"),
            feedback.get("completeness"),
            feedback.get("professional_presentation")
        ]

        # Filter out None values
        valid_metrics = [m for m in metrics if m is not None]

        if not valid_metrics:
            return 0.0

        return sum(valid_metrics) / len(valid_metrics)

    def _calculate_trend(self) -> Dict[str, float]:
        """Calculate recent feedback trend"""
        feedback_data = self._load_feedback()

        if len(feedback_data) < 5:
            return {"trend": "insufficient_data"}

        # Sort by timestamp
        sorted_feedback = sorted(feedback_data, key=lambda x: x.get("timestamp", 0))

        # Compare recent vs older feedback
        mid_point = len(sorted_feedback) // 2
        older_scores = [fb["overall_score"] for fb in sorted_feedback[:mid_point] if fb.get("overall_score")]
        recent_scores = [fb["overall_score"] for fb in sorted_feedback[mid_point:] if fb.get("overall_score")]

        if older_scores and recent_scores:
            older_avg = sum(older_scores) / len(older_scores)
            recent_avg = sum(recent_scores) / len(recent_scores)
            improvement = recent_avg - older_avg

            return {
                "older_average": older_avg,
                "recent_average": recent_avg,
                "improvement": improvement,
                "trend": "improving" if improvement > 0.1 else "stable" if abs(improvement) <= 0.1 else "declining"
            }

        return {"trend": "insufficient_data"}
agents/reward_model.py ADDED

@@ -0,0 +1,307 @@
"""
RLHF Reward Model for FinRyver
Predicts quality scores for generated financial statements based on human feedback
"""
import json
import os
import logging
from typing import Dict, Any, List, Optional, Tuple
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
import time

logger = logging.getLogger(__name__)

class FinancialRewardModel:
    """
    Reward model that predicts quality scores for financial statements
    Uses traditional ML initially, can be upgraded to transformer-based models
    """

    def __init__(self, model_dir: str = "data/models"):
        self.model_dir = model_dir
        self.model_path = os.path.join(model_dir, "reward_model.pkl")
        self.feature_names_path = os.path.join(model_dir, "feature_names.json")
        self.model_stats_path = os.path.join(model_dir, "model_stats.json")

        os.makedirs(model_dir, exist_ok=True)

        # Initialize model
        self.model = RandomForestRegressor(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        )

        self.feature_names = []
        self.is_trained = False
        self.model_version = "1.0"

        # Load existing model if available
        self._load_model()

    def extract_features(self, statement_data: Dict[str, Any], statement_content: str = "") -> np.ndarray:
        """Extract features from statement data for reward prediction"""
        features = []

        # Basic metadata features
        features.append(len(statement_content))  # Content length
        features.append(statement_data.get("generation_time", 0))  # Generation time
        features.append(1 if statement_data.get("statement_type") == "notes" else 0)
        features.append(1 if statement_data.get("statement_type") == "balance_sheet" else 0)
        features.append(1 if statement_data.get("statement_type") == "pnl" else 0)
        features.append(1 if statement_data.get("statement_type") == "cash_flow" else 0)

        # Content-based features (simple heuristics)
        if statement_content:
            features.append(statement_content.count("$"))  # Number of monetary values
            features.append(statement_content.count("\n"))  # Number of lines
            features.append(len(statement_content.split()))  # Word count
            features.append(statement_content.count("."))  # Number of sentences
            features.append(statement_content.count(","))  # Number of commas (complexity indicator)

            # Financial keywords
            financial_keywords = ["asset", "liability", "equity", "revenue", "expense", "cash", "account"]
            keyword_count = sum(statement_content.lower().count(keyword) for keyword in financial_keywords)
            features.append(keyword_count)

            # Professional language indicators
            professional_words = ["accordance", "pursuant", "whereas", "therefore", "respective"]
            professional_count = sum(statement_content.lower().count(word) for word in professional_words)
            features.append(professional_count)
        else:
            # Default values if no content available
            features.extend([0] * 7)

        # File-based features (if available)
        metadata = statement_data.get("metadata", {})
        features.append(metadata.get("file_size", 0))
        features.append(metadata.get("num_accounts", 0))
        features.append(metadata.get("complexity_score", 0))

        # Ensure we have consistent feature names
        if not self.feature_names:
            self.feature_names = [
                "content_length", "generation_time", "is_notes", "is_balance_sheet",
                "is_pnl", "is_cash_flow", "monetary_values", "line_count",
                "word_count", "sentence_count", "comma_count", "financial_keywords",
                "professional_words", "file_size", "num_accounts", "complexity_score"
            ]

        return np.array(features).reshape(1, -1)

    def train_reward_model(self, training_data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Train reward model from human feedback data"""
        if len(training_data) < 2:  # Lowered from 10 to 2 for testing
            logger.warning(f"Insufficient training data: {len(training_data)} samples")
            return {"error": "insufficient_data", "sample_count": len(training_data)}

        # Prepare training data
        X = []
        y = []

        for sample in training_data:
            # Create dummy statement data for feature extraction
            statement_data = {
                "statement_type": sample.get("statement_type", "unknown"),
                "generation_time": sample.get("generation_time", 0),
                "metadata": sample.get("metadata", {})
            }

            # Extract features
            features = self.extract_features(statement_data, "")
            X.append(features.flatten())
            y.append(sample["reward_score"])

        X = np.array(X)
        y = np.array(y)

        # Split data
        if len(X) > 20:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
        else:
            X_train, X_test, y_train, y_test = X, X, y, y

        # Train model
        logger.info(f"Training reward model with {len(X_train)} samples")
        self.model.fit(X_train, y_train)

        # Evaluate model
        train_pred = self.model.predict(X_train)
        test_pred = self.model.predict(X_test)

        metrics = {
            "train_mse": mean_squared_error(y_train, train_pred),
            "test_mse": mean_squared_error(y_test, test_pred),
            "train_r2": r2_score(y_train, train_pred),
            "test_r2": r2_score(y_test, test_pred),
            "sample_count": len(training_data),
            "feature_importance": dict(zip(self.feature_names, self.model.feature_importances_))
        }

        self.is_trained = True

        # Save model
        self._save_model(metrics)

        logger.info(f"Reward model trained. R2 score: {metrics['test_r2']:.3f}")
        return metrics

    def predict_reward(self, statement_data: Dict[str, Any], statement_content: str = "") -> float:
        """Predict reward score for a generated financial statement"""
        if not self.is_trained:
            logger.warning("Reward model not trained, returning default score")
            return 3.0  # Default neutral score

        try:
            features = self.extract_features(statement_data, statement_content)
            reward = self.model.predict(features)[0]

            # Clamp to valid range [1, 5]
            reward = max(1.0, min(5.0, reward))

            return float(reward)

        except Exception as e:
            logger.error(f"Error predicting reward: {e}")
            return 3.0  # Default score on error

    def predict_with_confidence(self, statement_data: Dict[str, Any], statement_content: str = "") -> Tuple[float, float]:
        """Predict reward with confidence interval"""
        if not self.is_trained:
            return 3.0, 0.0

        try:
            features = self.extract_features(statement_data, statement_content)

            # For Random Forest, we can get predictions from all trees
            tree_predictions = [tree.predict(features)[0] for tree in self.model.estimators_]

            reward = np.mean(tree_predictions)
            confidence = 1.0 / (1.0 + np.std(tree_predictions))  # Higher std = lower confidence

            reward = max(1.0, min(5.0, reward))

            return float(reward), float(confidence)

        except Exception as e:
            logger.error(f"Error predicting reward with confidence: {e}")
            return 3.0, 0.0

    def get_feature_importance(self) -> Dict[str, float]:
        """Get feature importance from trained model"""
        if not self.is_trained:
            return {}

        return dict(zip(self.feature_names, self.model.feature_importances_))

    def get_model_stats(self) -> Dict[str, Any]:
        """Get model training statistics"""
        if os.path.exists(self.model_stats_path):
            try:
                with open(self.model_stats_path, "r") as f:
                    return json.load(f)
            except (json.JSONDecodeError, OSError):
                pass
        return {"status": "not_trained"}

    def _save_model(self, training_stats: Dict[str, Any]):
        """Save trained model and metadata"""
        try:
            # Save model
            joblib.dump(self.model, self.model_path)

            # Save feature names
            with open(self.feature_names_path, "w") as f:
                json.dump(self.feature_names, f)

            # Save training stats
            stats = {
                "model_version": self.model_version,
                "training_timestamp": time.time(),
                "is_trained": True,
                **training_stats
            }

            with open(self.model_stats_path, "w") as f:
                json.dump(stats, f, indent=2)

            logger.info("Reward model saved successfully")

        except Exception as e:
            logger.error(f"Error saving model: {e}")

    def _load_model(self):
        """Load existing trained model"""
        try:
            if os.path.exists(self.model_path) and os.path.exists(self.feature_names_path):
                self.model = joblib.load(self.model_path)

                with open(self.feature_names_path, "r") as f:
|
| 247 |
+
self.feature_names = json.load(f)
|
| 248 |
+
|
| 249 |
+
self.is_trained = True
|
| 250 |
+
logger.info("Existing reward model loaded successfully")
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
logger.warning(f"Could not load existing model: {e}")
|
| 254 |
+
self.is_trained = False
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class RLHFTrainer:
|
| 258 |
+
"""Coordinates RLHF training pipeline"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, feedback_manager, reward_model):
|
| 261 |
+
self.feedback_manager = feedback_manager
|
| 262 |
+
self.reward_model = reward_model
|
| 263 |
+
self.min_feedback_threshold = 2 # Lowered for testing (was 20)
|
| 264 |
+
|
| 265 |
+
def should_retrain(self) -> bool:
|
| 266 |
+
"""Determine if model should be retrained"""
|
| 267 |
+
stats = self.feedback_manager.get_feedback_stats()
|
| 268 |
+
|
| 269 |
+
# Check if we have enough new feedback
|
| 270 |
+
total_feedback = stats.get("total_feedback", 0)
|
| 271 |
+
|
| 272 |
+
# Get last training count
|
| 273 |
+
model_stats = self.reward_model.get_model_stats()
|
| 274 |
+
last_training_count = model_stats.get("sample_count", 0)
|
| 275 |
+
|
| 276 |
+
new_feedback_count = total_feedback - last_training_count
|
| 277 |
+
|
| 278 |
+
return (total_feedback >= self.min_feedback_threshold and
|
| 279 |
+
new_feedback_count >= 2) # At least 2 new samples (was 10)
|
| 280 |
+
|
| 281 |
+
def retrain_model(self) -> Dict[str, Any]:
|
| 282 |
+
"""Retrain reward model with latest feedback"""
|
| 283 |
+
training_data = self.feedback_manager.get_training_data()
|
| 284 |
+
|
| 285 |
+
if len(training_data) < self.min_feedback_threshold:
|
| 286 |
+
return {
|
| 287 |
+
"status": "insufficient_data",
|
| 288 |
+
"current_count": len(training_data),
|
| 289 |
+
"required_count": self.min_feedback_threshold
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
# Train model
|
| 293 |
+
metrics = self.reward_model.train_reward_model(training_data)
|
| 294 |
+
|
| 295 |
+
return {
|
| 296 |
+
"status": "success",
|
| 297 |
+
"training_metrics": metrics,
|
| 298 |
+
"timestamp": time.time()
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
def periodic_training_check(self) -> Dict[str, Any]:
|
| 302 |
+
"""Check if retraining is needed and perform if necessary"""
|
| 303 |
+
if self.should_retrain():
|
| 304 |
+
logger.info("Initiating automatic model retraining")
|
| 305 |
+
return self.retrain_model()
|
| 306 |
+
else:
|
| 307 |
+
return {"status": "no_retraining_needed"}
|
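The confidence estimate in `predict_with_confidence` comes purely from the spread of the per-tree predictions: the mean is clamped to the 1-5 rating range, and confidence is `1 / (1 + std)`, so unanimous trees give a confidence near 1 and disagreeing trees push it toward 0. A minimal standalone sketch of just that aggregation (the hard-coded score lists are illustrative, not FinRyver data):

```python
from statistics import mean, pstdev

def combine_tree_predictions(tree_preds):
    """Clamped mean of per-tree scores plus a spread-based confidence,
    mirroring the 1 / (1 + std) formula used in predict_with_confidence."""
    reward = min(5.0, max(1.0, mean(tree_preds)))
    confidence = 1.0 / (1.0 + pstdev(tree_preds))  # higher spread -> lower confidence
    return reward, confidence

# Trees that agree closely -> confidence near 1
close = combine_tree_predictions([4.1, 4.0, 3.9])

# Trees that disagree -> similar mean, much lower confidence
spread = combine_tree_predictions([1.0, 4.0, 5.5])
```

Because confidence is bounded in (0, 1], the `weighted_score = reward * confidence` used later for candidate selection can only discount a candidate, never inflate it.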
agents/rlhf_routes.py
ADDED
@@ -0,0 +1,352 @@
"""
RLHF Feedback Collection Routes for FinRyver
Handles human feedback collection for financial statement quality
"""
from fastapi import APIRouter, HTTPException, Form, Query, Request
from fastapi.responses import JSONResponse, HTMLResponse
from typing import Optional, Dict, Any
import logging
from agents.feedback_manager import FeedbackManager
from agents.reward_model import FinancialRewardModel, RLHFTrainer
from agents.rlhf_workflows import get_rlhf_manager

logger = logging.getLogger(__name__)

# Create RLHF router
rlhf_router = APIRouter(prefix="/rlhf", tags=["RLHF Feedback"])

# Initialize components
feedback_manager = FeedbackManager()
reward_model = FinancialRewardModel()
trainer = RLHFTrainer(feedback_manager, reward_model)


@rlhf_router.post("/feedback")
async def collect_feedback(
    statement_id: str = Form(...),
    reviewer_id: str = Form("anonymous"),

    # Technical accuracy metrics (1-5 scale)
    calculation_accuracy: float = Form(..., ge=1, le=5),
    account_classification: float = Form(..., ge=1, le=5),
    statement_balance: float = Form(..., ge=1, le=5),

    # Compliance metrics (1-5 scale)
    accounting_standards: float = Form(..., ge=1, le=5),
    regulatory_compliance: float = Form(..., ge=1, le=5),

    # Quality metrics (1-5 scale)
    completeness: float = Form(..., ge=1, le=5),
    professional_presentation: float = Form(..., ge=1, le=5),

    # Qualitative feedback
    specific_errors: str = Form(""),
    missing_items: str = Form(""),
    improvement_suggestions: str = Form(""),

    # Binary approval
    would_accept_for_audit: bool = Form(False),

    # Additional context
    complexity_level: str = Form("medium")  # low, medium, high
):
    """
    Collect detailed human feedback on generated financial statements.
    This feedback is used to train and improve the AI models.
    """
    try:
        # Get statement info
        statement_info = feedback_manager.get_statement_for_review(statement_id)
        if not statement_info:
            raise HTTPException(status_code=404, detail="Statement not found")

        # Prepare feedback data
        feedback_data = {
            "statement_id": statement_id,
            "reviewer_id": reviewer_id,
            "calculation_accuracy": calculation_accuracy,
            "account_classification": account_classification,
            "statement_balance": statement_balance,
            "accounting_standards": accounting_standards,
            "regulatory_compliance": regulatory_compliance,
            "completeness": completeness,
            "professional_presentation": professional_presentation,
            "specific_errors": specific_errors,
            "missing_items": missing_items,
            "improvement_suggestions": improvement_suggestions,
            "would_accept_for_audit": would_accept_for_audit,
            "statement_type": statement_info.get("statement_type"),
            "complexity_level": complexity_level
        }

        # Store feedback
        feedback_id = feedback_manager.store_feedback(feedback_data)

        # Check whether the model should be retrained
        retrain_result = trainer.periodic_training_check()

        return {
            "status": "success",
            "feedback_id": feedback_id,
            "message": "Feedback collected successfully",
            "model_retrain_status": retrain_result.get("status"),
            "overall_score": feedback_manager._compute_overall_score(feedback_data)
        }

    except HTTPException:
        # Re-raise so the 404 above is not converted into a 500
        raise
    except Exception as e:
        logger.error(f"Error collecting feedback: {e}")
        raise HTTPException(status_code=500, detail=f"Error collecting feedback: {str(e)}")


@rlhf_router.get("/review/{statement_id}")
async def get_review_interface(statement_id: str):
    """
    Get a review interface for human feedback collection.
    Returns an HTML form for statement review.
    """
    try:
        statement_info = feedback_manager.get_statement_for_review(statement_id)
        if not statement_info:
            raise HTTPException(status_code=404, detail="Statement not found")

        # Generate HTML review form
        html_content = generate_review_html(statement_id, statement_info)
        return HTMLResponse(content=html_content)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting review interface: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@rlhf_router.get("/pending-reviews")
async def get_pending_reviews(limit: int = Query(10, ge=1, le=50)):
    """
    Get statements that need human review
    """
    try:
        pending_statements = feedback_manager.get_pending_reviews(limit)
        return {
            "status": "success",
            "pending_reviews": pending_statements,
            "count": len(pending_statements)
        }
    except Exception as e:
        logger.error(f"Error getting pending reviews: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@rlhf_router.get("/stats")
async def get_feedback_stats():
    """
    Get feedback and model training statistics
    """
    try:
        feedback_stats = feedback_manager.get_feedback_stats()
        model_stats = reward_model.get_model_stats()
        feature_importance = reward_model.get_feature_importance()

        return {
            "status": "success",
            "feedback_stats": feedback_stats,
            "model_stats": model_stats,
            "feature_importance": feature_importance,
            "model_trained": reward_model.is_trained
        }
    except Exception as e:
        logger.error(f"Error getting stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@rlhf_router.post("/retrain")
async def manual_retrain():
    """
    Manually trigger model retraining
    """
    try:
        result = trainer.retrain_model()
        return {
            "status": "success",
            "retrain_result": result
        }
    except Exception as e:
        logger.error(f"Error during manual retrain: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@rlhf_router.get("/model-info")
async def get_model_info():
    """
    Get information about the current reward model
    """
    try:
        return {
            "status": "success",
            "model_trained": reward_model.is_trained,
            "model_version": reward_model.model_version,
            "feature_count": len(reward_model.feature_names),
            "feature_names": reward_model.feature_names,
            "model_stats": reward_model.get_model_stats()
        }
    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        raise HTTPException(status_code=500, detail=str(e))


def generate_review_html(statement_id: str, statement_info: Dict) -> str:
    """Generate the HTML form for statement review"""
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>FinRyver - Statement Review</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .form-group {{ margin: 15px 0; }}
            label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
            input, select, textarea {{ width: 100%; padding: 8px; margin-bottom: 10px; }}
            .rating {{ display: flex; gap: 10px; }}
            .rating input {{ width: auto; }}
            button {{ background-color: #007bff; color: white; padding: 10px 20px; border: none; cursor: pointer; }}
            .statement-info {{ background-color: #f8f9fa; padding: 15px; margin-bottom: 20px; border-radius: 5px; }}
        </style>
    </head>
    <body>
        <h1>Financial Statement Review</h1>

        <div class="statement-info">
            <h3>Statement Information</h3>
            <p><strong>Statement ID:</strong> {statement_id}</p>
            <p><strong>Type:</strong> {statement_info.get('statement_type', 'Unknown')}</p>
            <p><strong>Generated:</strong> {statement_info.get('timestamp', 'Unknown')}</p>
            <p><strong>File:</strong> {statement_info.get('file_path', 'Unknown')}</p>
        </div>

        <form action="/rlhf/feedback" method="post">
            <input type="hidden" name="statement_id" value="{statement_id}">

            <div class="form-group">
                <label>Reviewer ID (optional):</label>
                <input type="text" name="reviewer_id" placeholder="Enter your identifier">
            </div>

            <h3>Technical Accuracy (1-5 scale)</h3>

            <div class="form-group">
                <label>Calculation Accuracy:</label>
                <select name="calculation_accuracy" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Major calculation errors</option>
                    <option value="2">2 - Some calculation errors</option>
                    <option value="3">3 - Minor calculation issues</option>
                    <option value="4">4 - Mostly accurate calculations</option>
                    <option value="5">5 - All calculations correct</option>
                </select>
            </div>

            <div class="form-group">
                <label>Account Classification:</label>
                <select name="account_classification" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Major classification errors</option>
                    <option value="2">2 - Some classification errors</option>
                    <option value="3">3 - Minor classification issues</option>
                    <option value="4">4 - Mostly correct classification</option>
                    <option value="5">5 - Perfect classification</option>
                </select>
            </div>

            <div class="form-group">
                <label>Statement Balance/Reconciliation:</label>
                <select name="statement_balance" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Does not balance</option>
                    <option value="2">2 - Major balance issues</option>
                    <option value="3">3 - Minor balance issues</option>
                    <option value="4">4 - Mostly balanced</option>
                    <option value="5">5 - Perfect balance</option>
                </select>
            </div>

            <h3>Compliance &amp; Standards (1-5 scale)</h3>

            <div class="form-group">
                <label>Accounting Standards Compliance:</label>
                <select name="accounting_standards" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Major compliance issues</option>
                    <option value="2">2 - Some compliance issues</option>
                    <option value="3">3 - Minor compliance issues</option>
                    <option value="4">4 - Mostly compliant</option>
                    <option value="5">5 - Fully compliant</option>
                </select>
            </div>

            <div class="form-group">
                <label>Regulatory Compliance:</label>
                <select name="regulatory_compliance" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Major regulatory issues</option>
                    <option value="2">2 - Some regulatory issues</option>
                    <option value="3">3 - Minor regulatory issues</option>
                    <option value="4">4 - Mostly compliant</option>
                    <option value="5">5 - Fully compliant</option>
                </select>
            </div>

            <h3>Quality &amp; Presentation (1-5 scale)</h3>

            <div class="form-group">
                <label>Completeness:</label>
                <select name="completeness" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Major items missing</option>
                    <option value="2">2 - Some items missing</option>
                    <option value="3">3 - Minor items missing</option>
                    <option value="4">4 - Mostly complete</option>
                    <option value="5">5 - Complete</option>
                </select>
            </div>

            <div class="form-group">
                <label>Professional Presentation:</label>
                <select name="professional_presentation" required>
                    <option value="">Select rating</option>
                    <option value="1">1 - Unprofessional</option>
                    <option value="2">2 - Below standard</option>
                    <option value="3">3 - Adequate</option>
                    <option value="4">4 - Good presentation</option>
                    <option value="5">5 - Excellent presentation</option>
                </select>
            </div>

            <h3>Detailed Feedback</h3>

            <div class="form-group">
                <label>Specific Errors (if any):</label>
                <textarea name="specific_errors" rows="3" placeholder="Describe any specific errors found..."></textarea>
            </div>

            <div class="form-group">
                <label>Missing Items (if any):</label>
                <textarea name="missing_items" rows="3" placeholder="List any missing items or information..."></textarea>
            </div>

            <div class="form-group">
                <label>Improvement Suggestions:</label>
                <textarea name="improvement_suggestions" rows="3" placeholder="Suggest improvements..."></textarea>
            </div>

            <div class="form-group">
                <label>Complexity Level:</label>
                <select name="complexity_level">
                    <option value="low">Low</option>
                    <option value="medium" selected>Medium</option>
                    <option value="high">High</option>
                </select>
            </div>

            <div class="form-group">
                <label>
                    <input type="checkbox" name="would_accept_for_audit" value="true">
                    Would accept this statement for audit/compliance purposes
                </label>
            </div>

            <button type="submit">Submit Feedback</button>
        </form>
    </body>
    </html>
    """
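Clients that prefer scripting over the HTML form can POST the same form fields directly to `/rlhf/feedback`. A hedged sketch of building that payload client-side (the field names match the `Form(...)` parameters above; `build_feedback_payload` and its plain 1-5 range check are illustrative helpers, not part of FinRyver):

```python
# Form fields validated server-side with Form(..., ge=1, le=5)
RATING_FIELDS = [
    "calculation_accuracy", "account_classification", "statement_balance",
    "accounting_standards", "regulatory_compliance",
    "completeness", "professional_presentation",
]

def build_feedback_payload(statement_id, ratings, reviewer_id="anonymous", **extras):
    """Assemble the form payload, mirroring the server's 1-5 range checks locally."""
    missing = [f for f in RATING_FIELDS if f not in ratings]
    if missing:
        raise ValueError(f"missing ratings: {missing}")
    out_of_range = {f: v for f, v in ratings.items() if not 1 <= v <= 5}
    if out_of_range:
        raise ValueError(f"ratings outside 1-5: {out_of_range}")
    return {"statement_id": statement_id, "reviewer_id": reviewer_id, **ratings, **extras}

payload = build_feedback_payload(
    "abc-123",
    {f: 4 for f in RATING_FIELDS},
    improvement_suggestions="Add comparative figures",
)
# Send as form data once the app is running, e.g.:
# requests.post(f"{BASE_URL}/rlhf/feedback", data=payload)
```

Validating before sending avoids a round trip that would otherwise end in a 422 from FastAPI's own parameter validation.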
agents/rlhf_workflows.py
ADDED
@@ -0,0 +1,267 @@
"""
RLHF-Enhanced LangGraph Workflows for FinRyver
Integrates reward model and feedback collection into existing workflows
"""
from typing import TypedDict, Dict, Any, List, Annotated, Optional
import time
import uuid
import os
import logging
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage

# Import existing tools and RLHF components
from agents.simple_tools import (
    generate_notes_full_pipeline_from_path,
    generate_balance_sheet,
    generate_pnl_statement,
    generate_cash_flow_statement,
)
from agents.feedback_manager import FeedbackManager
from agents.reward_model import FinancialRewardModel, RLHFTrainer

logger = logging.getLogger(__name__)


class RLHFFinancialAgentState(TypedDict):
    """Enhanced state with RLHF capabilities"""
    messages: Annotated[List[BaseMessage], "History"]
    file_path: str
    result: Dict[str, Any]
    status: str
    start_time: float
    end_time: float
    error: str

    # RLHF-specific fields
    statement_id: Optional[str]
    predicted_quality: Optional[float]
    confidence_score: Optional[float]
    candidates_generated: Optional[List[Dict[str, Any]]]
    best_candidate_index: Optional[int]
    feedback_collected: Optional[bool]


class RLHFWorkflowManager:
    """Manages RLHF-enhanced workflows"""

    def __init__(self):
        self.feedback_manager = FeedbackManager()
        self.reward_model = FinancialRewardModel()
        self.trainer = RLHFTrainer(self.feedback_manager, self.reward_model)

        # Check for model retraining on initialization
        self._check_and_retrain()

    def _check_and_retrain(self):
        """Check whether the model needs retraining"""
        try:
            result = self.trainer.periodic_training_check()
            if result.get("status") == "success":
                logger.info("Reward model retrained successfully")
        except Exception as e:
            logger.error(f"Error during model retraining check: {e}")

    def make_rlhf_workflow(self, tool_func, statement_type: str):
        """Create an RLHF-enhanced workflow"""

        def rlhf_node(state: RLHFFinancialAgentState) -> RLHFFinancialAgentState:
            state["start_time"] = time.time()
            state["statement_id"] = str(uuid.uuid4())

            try:
                # Generate multiple candidates if the reward model is trained
                if self.reward_model.is_trained:
                    candidates = self._generate_candidates(tool_func, state, num_candidates=3)
                    state["candidates_generated"] = candidates

                    # Select the best candidate using the reward model
                    best_candidate, best_index = self._select_best_candidate(
                        candidates, statement_type, state["file_path"]
                    )

                    state["result"] = best_candidate
                    state["best_candidate_index"] = best_index

                else:
                    # Single generation if there is no trained model
                    result = tool_func.invoke({"file_path": state["file_path"]})
                    state["result"] = result
                    state["candidates_generated"] = [result]
                    state["best_candidate_index"] = 0

                # Predict quality score
                if state["result"].get("status") == "success":
                    predicted_quality, confidence = self._predict_quality(
                        state["result"], statement_type, state["file_path"]
                    )
                    state["predicted_quality"] = predicted_quality
                    state["confidence_score"] = confidence
                    state["status"] = "success"

                    # Store statement for potential feedback
                    self._store_for_feedback(state, statement_type)

                else:
                    state["status"] = "error"
                    state["error"] = state["result"].get("error", "Unknown error")

            except Exception as e:
                state["status"] = "error"
                state["error"] = str(e)
                logger.error(f"Error in RLHF workflow: {e}")

            state["end_time"] = time.time()
            return state

        # Create workflow graph
        wf = StateGraph(RLHFFinancialAgentState)
        wf.add_node("rlhf_run", rlhf_node)
        wf.set_entry_point("rlhf_run")
        wf.add_edge("rlhf_run", END)
        return wf.compile()

    def _generate_candidates(self, tool_func, state: RLHFFinancialAgentState, num_candidates: int = 3) -> List[Dict[str, Any]]:
        """Generate multiple candidates for comparison"""
        candidates = []

        for i in range(num_candidates):
            try:
                result = tool_func.invoke({"file_path": state["file_path"]})
                candidates.append({
                    "index": i,
                    "result": result,
                    "timestamp": time.time()
                })
            except Exception as e:
                logger.warning(f"Failed to generate candidate {i}: {e}")
                candidates.append({
                    "index": i,
                    "result": {"status": "error", "error": str(e)},
                    "timestamp": time.time()
                })

        return candidates

    def _select_best_candidate(self, candidates: List[Dict[str, Any]], statement_type: str, file_path: str) -> tuple:
        """Select the best candidate using the reward model"""
        best_candidate = None
        best_score = -1
        best_index = 0

        for candidate in candidates:
            if candidate["result"].get("status") == "success":
                # Create statement data for reward prediction
                statement_data = {
                    "statement_type": statement_type,
                    "file_path": file_path,
                    "generation_time": 0,  # Could be calculated from timestamps
                    "metadata": {}
                }

                # Predict reward
                predicted_reward, confidence = self.reward_model.predict_with_confidence(
                    statement_data, ""
                )

                # Weight by confidence
                weighted_score = predicted_reward * confidence

                if weighted_score > best_score:
                    best_score = weighted_score
                    best_candidate = candidate["result"]
                    best_index = candidate["index"]

        # Fall back to the first successful candidate
        if best_candidate is None:
            for candidate in candidates:
                if candidate["result"].get("status") == "success":
                    best_candidate = candidate["result"]
                    best_index = candidate["index"]
                    break

        # Final fallback
        if best_candidate is None and candidates:
            best_candidate = candidates[0]["result"]
            best_index = 0

        return best_candidate, best_index

    def _predict_quality(self, result: Dict[str, Any], statement_type: str, file_path: str) -> tuple:
        """Predict a quality score for the generated statement"""
        statement_data = {
            "statement_type": statement_type,
            "file_path": file_path,
            "generation_time": 0,
            "metadata": {}
        }

        return self.reward_model.predict_with_confidence(statement_data, "")

    def _store_for_feedback(self, state: RLHFFinancialAgentState, statement_type: str):
        """Store the generated statement for feedback collection"""
        try:
            statement_data = {
                "type": statement_type,
                "file_path": state["file_path"],
                "output_path": state["result"].get("output_path"),
                "generation_time": state["end_time"] - state["start_time"],
                "predicted_quality": state.get("predicted_quality"),
                "confidence_score": state.get("confidence_score"),
                "metadata": {
                    "candidates_count": len(state.get("candidates_generated", [])),
                    "best_candidate_index": state.get("best_candidate_index"),
                    "workflow_version": "rlhf_v1"
                }
            }

            stored_id = self.feedback_manager.store_generated_statement(statement_data)
            state["statement_id"] = stored_id

        except Exception as e:
            logger.error(f"Error storing statement for feedback: {e}")


# Global RLHF manager instance
rlhf_manager = RLHFWorkflowManager()

# RLHF-enhanced workflows
rlhf_workflows = {
    "notes": rlhf_manager.make_rlhf_workflow(generate_notes_full_pipeline_from_path, "notes"),
    "pnl": rlhf_manager.make_rlhf_workflow(generate_pnl_statement, "pnl"),
    "bs": rlhf_manager.make_rlhf_workflow(generate_balance_sheet, "balance_sheet"),
    "cf": rlhf_manager.make_rlhf_workflow(generate_cash_flow_statement, "cash_flow"),
}


def run_rlhf_workflow(file_path: str, kind: str) -> Dict[str, Any]:
    """Run an RLHF-enhanced workflow"""
    state = RLHFFinancialAgentState(
        messages=[HumanMessage(content=f"Run RLHF {kind} for {file_path}")],
        file_path=file_path,
        result={},
        status="",
        start_time=0,
        end_time=0,
        error="",
        statement_id=None,
        predicted_quality=None,
        confidence_score=None,
        candidates_generated=None,
        best_candidate_index=None,
        feedback_collected=False
    )

    final_state = rlhf_workflows[kind].invoke(state)

    # Add RLHF metadata to result
    if final_state["status"] == "success":
        final_state["result"]["rlhf_metadata"] = {
            "statement_id": final_state.get("statement_id"),
|
| 257 |
+
"predicted_quality": final_state.get("predicted_quality"),
|
| 258 |
+
"confidence_score": final_state.get("confidence_score"),
|
| 259 |
+
"candidates_generated": len(final_state.get("candidates_generated", [])),
|
| 260 |
+
"model_used": "rlhf_enhanced"
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
return final_state
|
| 264 |
+
|
| 265 |
+
def get_rlhf_manager() -> RLHFWorkflowManager:
|
| 266 |
+
"""Get global RLHF manager instance"""
|
| 267 |
+
return rlhf_manager
|
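The candidate-selection rule in `_select_best_candidate` (highest reward weighted by confidence, with two fallbacks) can be exercised in isolation. Below is a minimal sketch; the `select_best` helper and the flat candidate dicts are hypothetical stand-ins for the method and its inputs, assuming each candidate carries a reward, a confidence, and a status:

```python
def select_best(candidates):
    """Pick the candidate with the highest reward * confidence,
    falling back to the first successful one, then the first overall."""
    best, best_index, best_score = None, None, float("-inf")
    for cand in candidates:
        score = cand["reward"] * cand["confidence"]
        # Only successful candidates compete on weighted score
        if cand["status"] == "success" and score > best_score:
            best_score, best, best_index = score, cand, cand["index"]
    # Fallback: first successful candidate
    if best is None:
        for cand in candidates:
            if cand["status"] == "success":
                best, best_index = cand, cand["index"]
                break
    # Final fallback: first candidate overall
    if best is None and candidates:
        best, best_index = candidates[0], 0
    return best, best_index
```

Note that a high raw reward with low confidence can lose to a moderate reward the model is sure about, which is the point of the weighting.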
app.py
CHANGED
```python
from fastapi import FastAPI, APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import FileResponse
import os
import shutil
import logging
from agents.langgraph import run_workflow
from agents.rlhf_workflows import run_rlhf_workflow
from agents.rlhf_routes import rlhf_router

# Configure logging for the application
logging.basicConfig(level=logging.INFO)
# ... (unchanged lines elided by the diff)

app = FastAPI(
    title="Financial Notes Generator API",
    description="API for generating financial notes, balance sheets, cash flow statements, and P&L reports with RLHF capabilities.",
    version="1.0.0"
)

# Include RLHF routes
app.include_router(rlhf_router)

@app.on_event("startup")
async def startup_event():
    logger.info("Financial Notes Generator API has started.")
# ... (unchanged lines elided by the diff)

@router.post("/notes")
async def notes_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    file_path = f"data/input/{file.filename}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Choose workflow based on RLHF preference
    if use_rlhf:
        result = run_rlhf_workflow(file_path, "notes")
    else:
        result = run_workflow(file_path, "notes")

    if result["status"] == "success":
        response = FileResponse(result["result"]["output_xlsx_path"], filename=os.path.basename(result["result"]["output_xlsx_path"]))

        # Add RLHF metadata to headers if available
        if "rlhf_metadata" in result.get("result", {}):
            rlhf_data = result["result"]["rlhf_metadata"]
            response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
            response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
            response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))

        return response
    raise HTTPException(status_code=500, detail=result["error"])

@router.post("/pnl")
async def pnl_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    file_path = f"data/input/{file.filename}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Choose workflow based on RLHF preference
    if use_rlhf:
        result = run_rlhf_workflow(file_path, "pnl")
    else:
        result = run_workflow(file_path, "pnl")

    if result["status"] == "success":
        response = FileResponse(result["result"].get("output_path", "data/pnl_statement.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/pnl_statement.xlsx")))

        # Add RLHF metadata to headers if available
        if "rlhf_metadata" in result.get("result", {}):
            rlhf_data = result["result"]["rlhf_metadata"]
            response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
            response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
            response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))

        return response
    raise HTTPException(status_code=500, detail=result["error"])

@router.post("/bs")
async def bs_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    file_path = f"data/input/{file.filename}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Choose workflow based on RLHF preference
    if use_rlhf:
        result = run_rlhf_workflow(file_path, "bs")
    else:
        result = run_workflow(file_path, "bs")

    if result["status"] == "success":
        # Use first xlsx file in output dir if present
        output_file = result["result"].get("output_path")
        # ... (unchanged lines elided by the diff)
            output_file = os.path.join(output_dir, xlsx_files[0])
        else:
            raise HTTPException(status_code=500, detail="No balance sheet Excel file produced")

        response = FileResponse(output_file, filename=os.path.basename(output_file))

        # Add RLHF metadata to headers if available
        if "rlhf_metadata" in result.get("result", {}):
            rlhf_data = result["result"]["rlhf_metadata"]
            response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
            response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
            response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))

        return response
    else:
        raise HTTPException(status_code=500, detail=result["error"])

@router.post("/cf")
async def cf_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    file_path = f"data/input/{file.filename}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Choose workflow based on RLHF preference
    if use_rlhf:
        result = run_rlhf_workflow(file_path, "cf")
    else:
        result = run_workflow(file_path, "cf")

    if result["status"] == "success":
        response = FileResponse(result["result"].get("output_path", "data/cash_flow_statements.xlsx"), filename=os.path.basename(result["result"].get("output_path", "data/cash_flow_statements.xlsx")))

        # Add RLHF metadata to headers if available
        if "rlhf_metadata" in result.get("result", {}):
            rlhf_data = result["result"]["rlhf_metadata"]
            response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
            response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
            response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))

        return response
    raise HTTPException(status_code=500, detail=result["error"])

app.include_router(router)
```
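Since the routes surface the RLHF metadata only as stringified `X-RLHF-*` response headers, a client has to parse it back out. A minimal sketch, assuming only what the routes above set (the `parse_rlhf_headers` helper is hypothetical; missing metadata arrives as an empty string or the string `"None"`):

```python
def parse_rlhf_headers(headers: dict) -> dict:
    """Recover RLHF metadata from X-RLHF-* response headers (illustrative)."""
    def _float(value):
        # Headers are strings; "" and "None" both mean the score was absent
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    return {
        "statement_id": headers.get("X-RLHF-Statement-ID") or None,
        "predicted_quality": _float(headers.get("X-RLHF-Quality-Score")),
        "confidence": _float(headers.get("X-RLHF-Confidence")),
    }
```

The statement ID is what a client would pass back to the feedback endpoints in `agents/rlhf_routes.py` when rating a generated statement.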
requirements.txt
CHANGED
```text
langchain-openai
langchain-community
langchain-core

#langgraph
langgraph

# RLHF dependencies
scikit-learn
numpy
joblib
```