Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Text-Based RLHF Reward Model for FinRyver | |
| Focuses on collecting and analyzing specific feedback content instead of predicting quality scores | |
| """ | |
| import json | |
| import os | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| import time | |
# Module-scoped logger; records are tagged with this module's dotted name.
logger = logging.getLogger(name=__name__)
class TextBasedRewardModel:
    """
    Simple reward model that collects and analyzes text-based feedback.

    Feedback entries are kept in memory in ``self.feedback_data`` and
    persisted as a JSON list to ``<model_dir>/feedback_data.json`` after
    every accepted :meth:`collect_feedback` call.
    """

    # Free-text fields; at least one must be non-blank for feedback to be accepted.
    _TEXT_FIELDS = ("specific_errors", "missing_items", "improvement_suggestions")

    def __init__(self, model_dir: str = "data/models"):
        """Create the model and load any previously saved feedback.

        Args:
            model_dir: Directory that holds ``feedback_data.json``; it is
                created if it does not exist.
        """
        self.model_dir = model_dir
        self.feedback_data_path = os.path.join(model_dir, "feedback_data.json")
        os.makedirs(model_dir, exist_ok=True)
        self.feedback_data: List[Dict[str, Any]] = []
        # Nothing in this class sets this to True; kept for interface compatibility.
        self.is_trained = False
        self.model_version = "2.0-text-based"
        # Load existing feedback data if available
        self._load_feedback_data()

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Normalize a free-text field: ``None`` -> ``""``, anything else -> stripped ``str``."""
        if value is None:
            return ""
        return value.strip() if isinstance(value, str) else str(value).strip()

    def collect_feedback(self, feedback_data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate, store and persist one text-based feedback entry.

        Args:
            feedback_data: Reviewer input. At least one of ``specific_errors``,
                ``missing_items`` or ``improvement_suggestions`` must contain
                non-whitespace text. Missing or ``None`` text fields are
                treated as empty; stored text is stripped of surrounding
                whitespace.

        Returns:
            ``{"error": "No text feedback provided"}`` when every text field is
            blank; otherwise ``{"status": "success", "feedback_stored": True,
            "total_feedback": <count>}``.
        """
        # Normalizing first means None / non-str values can't crash .strip()
        # and whitespace-only text is neither accepted nor stored.
        texts = {field: self._clean_text(feedback_data.get(field)) for field in self._TEXT_FIELDS}
        if not any(texts.values()):
            return {"error": "No text feedback provided"}

        # Store feedback
        feedback_entry = {
            "timestamp": time.time(),
            "statement_id": feedback_data.get("statement_id"),
            "reviewer_id": feedback_data.get("reviewer_id", "anonymous"),
            "statement_type": feedback_data.get("statement_type"),
            "specific_errors": texts["specific_errors"],
            "missing_items": texts["missing_items"],
            "improvement_suggestions": texts["improvement_suggestions"],
            "would_accept_for_audit": feedback_data.get("would_accept_for_audit", False),
            "complexity_level": feedback_data.get("complexity_level", "medium"),
        }
        self.feedback_data.append(feedback_entry)
        self._save_feedback_data()
        return {
            "status": "success",
            "feedback_stored": True,
            "total_feedback": len(self.feedback_data),
        }

    def get_feedback_patterns(self) -> Dict[str, Any]:
        """Summarize all collected feedback.

        Returns:
            ``{"error": "No feedback data available"}`` when nothing has been
            collected; otherwise a dict with ``total_feedback``,
            ``statement_types`` (count per type, missing types counted as
            ``"unknown"``), ``common_issues`` (non-blank ``specific_errors``
            and ``missing_items`` texts), ``improvement_suggestions``
            (non-blank suggestion texts) and ``acceptance_rate`` (fraction of
            entries with a truthy ``would_accept_for_audit``).
        """
        if not self.feedback_data:
            return {"error": "No feedback data available"}

        statement_counts: Dict[str, int] = {}
        common_issues: List[str] = []
        suggestions: List[str] = []
        acceptance_count = 0
        for feedback in self.feedback_data:
            # `or` also maps an explicitly stored None (the default when the
            # caller omitted statement_type) to "unknown".
            stmt_type = feedback.get("statement_type") or "unknown"
            statement_counts[stmt_type] = statement_counts.get(stmt_type, 0) + 1
            if feedback.get("would_accept_for_audit"):
                acceptance_count += 1
            # Re-clean here so entries loaded from older files (which may hold
            # whitespace-only text) are filtered the same way as new ones.
            for key in ("specific_errors", "missing_items"):
                issue = self._clean_text(feedback.get(key))
                if issue:
                    common_issues.append(issue)
            suggestion = self._clean_text(feedback.get("improvement_suggestions"))
            if suggestion:
                suggestions.append(suggestion)

        return {
            "total_feedback": len(self.feedback_data),
            "statement_types": statement_counts,
            "common_issues": common_issues,
            "improvement_suggestions": suggestions,
            "acceptance_rate": acceptance_count / len(self.feedback_data),
        }

    def get_recent_feedback(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recently collected entries, oldest first.

        A non-positive ``limit`` returns an empty list.
        """
        # Guard needed: lst[-0:] is the WHOLE list, and a negative limit
        # would turn into a forward slice from the start.
        if limit <= 0:
            return []
        return self.feedback_data[-limit:]

    def _save_feedback_data(self) -> None:
        """Persist ``self.feedback_data`` to disk (best effort; errors are logged).

        Writes to a sibling ``.tmp`` file and swaps it in with ``os.replace``.
        A crash mid-write therefore cannot leave a truncated JSON file behind.
        The loader would discard such a file, and the next save would then
        overwrite all previously stored feedback.
        """
        tmp_path = self.feedback_data_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.feedback_data, f, indent=2)
            os.replace(tmp_path, self.feedback_data_path)
        except (OSError, TypeError, ValueError) as e:
            # TypeError/ValueError: an entry holds a value json cannot serialize.
            logger.error("Error saving feedback data: %s", e)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the temp file may not exist

    def _load_feedback_data(self) -> None:
        """Load previously saved feedback from disk into ``self.feedback_data``.

        A missing file leaves the list empty. Unreadable content, invalid
        JSON, or a top-level value that is not a list is logged, and the
        list is reset to empty so the instance stays usable. Non-dict items
        inside the list are dropped.
        """
        if not os.path.exists(self.feedback_data_path):
            return
        try:
            with open(self.feedback_data_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            logger.warning("Error loading feedback data: %s", e)
            self.feedback_data = []
            return

        if not isinstance(data, list):
            logger.warning(
                "Error loading feedback data: expected a JSON list, got %s",
                type(data).__name__,
            )
            self.feedback_data = []
            return

        entries = [entry for entry in data if isinstance(entry, dict)]
        if len(entries) != len(data):
            logger.warning("Skipped %d malformed feedback entries", len(data) - len(entries))
        self.feedback_data = entries
        logger.info("Loaded %d feedback entries", len(self.feedback_data))
# Backward compatibility alias: the old name is bound to the very same class
# object, so code still using FinancialRewardModel gets TextBasedRewardModel.
FinancialRewardModel = TextBasedRewardModel