finryver-dev / agents /reward_model.py
Sahil Garg
udf added to /notes-llm along with RLHF
6611563
raw
history blame
4.99 kB
"""
Enhanced Text-Based RLHF Reward Model for FinRyver
Focuses on collecting and analyzing specific feedback content instead of predicting quality scores
"""
import json
import os
import logging
from typing import Dict, Any, List, Optional
import time
# Module-level logger named after this module; the reward model below uses it
# to report feedback load/save results and failures.
logger = logging.getLogger(__name__)
class TextBasedRewardModel:
    """
    Simple reward model that collects and analyzes text-based feedback.

    Feedback entries are held in memory in ``self.feedback_data`` (a list of
    dicts) and mirrored to ``<model_dir>/feedback_data.json`` each time
    ``collect_feedback`` accepts a new entry. Existing data in that file is
    loaded when the model is constructed.
    """

    # Free-text fields; at least one must be non-blank for feedback to be accepted.
    _TEXT_FIELDS = ("specific_errors", "missing_items", "improvement_suggestions")

    def __init__(self, model_dir: str = "data/models"):
        """Create the model directory if needed and load any stored feedback.

        Args:
            model_dir: Directory that holds ``feedback_data.json``.
        """
        self.model_dir = model_dir
        self.feedback_data_path = os.path.join(model_dir, "feedback_data.json")
        os.makedirs(model_dir, exist_ok=True)
        self.feedback_data: List[Dict[str, Any]] = []
        self.is_trained = False
        self.model_version = "2.0-text-based"
        # Load existing feedback data if available
        self._load_feedback_data()

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string; ``None`` becomes the empty string."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def collect_feedback(self, feedback_data: Dict[str, Any]) -> Dict[str, Any]:
        """Collect and store text-based feedback.

        Args:
            feedback_data: Raw feedback. At least one of ``specific_errors``,
                ``missing_items`` or ``improvement_suggestions`` must contain
                non-whitespace text. A value of ``None`` is treated as empty
                and other non-string values are converted with ``str()``
                instead of raising.

        Returns:
            ``{"error": "No text feedback provided"}`` when every text field
            is blank; otherwise a dict with ``status``, ``feedback_stored``
            (whether the data was written to disk) and ``total_feedback``
            (number of entries now held in memory).
        """
        texts = {
            field: self._as_text(feedback_data.get(field))
            for field in self._TEXT_FIELDS
        }
        # Validate that we have text feedback
        if not any(text.strip() for text in texts.values()):
            return {"error": "No text feedback provided"}

        # Store feedback
        feedback_entry = {
            "timestamp": time.time(),
            "statement_id": feedback_data.get("statement_id"),
            "reviewer_id": feedback_data.get("reviewer_id", "anonymous"),
            "statement_type": feedback_data.get("statement_type"),
            "specific_errors": texts["specific_errors"],
            "missing_items": texts["missing_items"],
            "improvement_suggestions": texts["improvement_suggestions"],
            "would_accept_for_audit": feedback_data.get("would_accept_for_audit", False),
            "complexity_level": feedback_data.get("complexity_level", "medium"),
        }
        self.feedback_data.append(feedback_entry)
        # The entry stays in memory even if the write fails; report honestly
        # whether it actually reached disk.
        saved = self._save_feedback_data()
        return {
            "status": "success",
            "feedback_stored": saved,
            "total_feedback": len(self.feedback_data),
        }

    def get_feedback_patterns(self) -> Dict[str, Any]:
        """Get patterns and insights from collected feedback.

        Returns:
            ``{"error": "No feedback data available"}`` when nothing has been
            collected; otherwise a dict with ``total_feedback``,
            ``statement_types`` (count per statement type), ``common_issues``
            (all non-empty ``specific_errors`` and ``missing_items`` texts),
            ``improvement_suggestions`` and ``acceptance_rate`` (fraction of
            entries with a truthy ``would_accept_for_audit``).
        """
        if not self.feedback_data:
            return {"error": "No feedback data available"}

        statement_counts: Dict[str, int] = {}
        common_issues: List[Any] = []
        suggestions: List[Any] = []
        acceptance_count = 0

        for feedback in self.feedback_data:
            # Entries always carry the key (often as None), so a .get() default
            # would never apply; fall back on any falsy value instead.
            stmt_type = feedback.get("statement_type") or "unknown"
            statement_counts[stmt_type] = statement_counts.get(stmt_type, 0) + 1
            if feedback.get("would_accept_for_audit"):
                acceptance_count += 1
            # Collect common issues
            if feedback.get("specific_errors"):
                common_issues.append(feedback["specific_errors"])
            if feedback.get("missing_items"):
                common_issues.append(feedback["missing_items"])
            if feedback.get("improvement_suggestions"):
                suggestions.append(feedback["improvement_suggestions"])

        total = len(self.feedback_data)
        return {
            "total_feedback": total,
            "statement_types": statement_counts,
            "common_issues": common_issues,
            "improvement_suggestions": suggestions,
            # total > 0 is guaranteed by the early return above.
            "acceptance_rate": acceptance_count / total,
        }

    def get_recent_feedback(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get the most recent feedback entries, oldest first.

        Args:
            limit: Maximum number of entries to return. Zero or a negative
                value yields an empty list.
        """
        # Guard explicitly: lst[-0:] is the whole list, not an empty one.
        if limit <= 0:
            return []
        return self.feedback_data[-limit:]

    def _save_feedback_data(self) -> bool:
        """Save feedback data to disk.

        The data is written to a sibling ``.tmp`` file and then moved over the
        real file, so a failed or interrupted write cannot truncate feedback
        that was saved earlier.

        Returns:
            True if the file was written, False if an error was logged.
        """
        tmp_path = self.feedback_data_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.feedback_data, f, indent=2)
            os.replace(tmp_path, self.feedback_data_path)
            return True
        except Exception as e:
            # Best-effort persistence: log and carry on rather than crash.
            logger.error("Error saving feedback data: %s", e)
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            return False

    def _load_feedback_data(self):
        """Load feedback data from disk.

        Leaves ``self.feedback_data`` untouched when the file does not exist,
        and resets it to an empty list when the file cannot be read, is not
        valid JSON, or does not contain a JSON list.
        """
        if not os.path.exists(self.feedback_data_path):
            return
        try:
            with open(self.feedback_data_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except Exception as e:
            logger.warning("Error loading feedback data: %s", e)
            self.feedback_data = []
            return
        if not isinstance(loaded, list):
            # Anything but a list would break the later append() and slicing.
            logger.warning(
                "Error loading feedback data: expected a JSON list, got %s",
                type(loaded).__name__,
            )
            self.feedback_data = []
            return
        self.feedback_data = loaded
        logger.info("Loaded %d feedback entries", len(self.feedback_data))
# Backward compatibility alias: FinancialRewardModel is bound to the very same
# class object as TextBasedRewardModel (presumably an earlier public name —
# confirm against callers before removing).
FinancialRewardModel = TextBasedRewardModel