Spaces:
Sleeping
Sleeping
| """ | |
| RLHF Feedback Management System for FinRyver | |
| Handles collection, storage, and management of human feedback on financial statements | |
| """ | |
| import json | |
| import os | |
| import time | |
| import uuid | |
| from typing import Dict, Any, List, Optional | |
| import logging | |
logger = logging.getLogger(__name__)


class FeedbackManager:
    """Manages human feedback collection for RLHF training.

    State lives in two JSON files inside ``feedback_dir``:

    * ``generated_statements.json`` - a list of generated-statement records
      (see :meth:`store_generated_statement`);
    * ``human_feedback.json`` - a list of reviewer feedback records linked to
      statements by ``statement_id`` (see :meth:`store_feedback`).

    Every store call loads the whole list, appends one record and rewrites the
    file, so cost grows with the number of records. No locking is performed,
    so concurrent writers can lose each other's updates.
    """

    def __init__(self, feedback_dir: str = "data/feedback"):
        """Point the manager at ``feedback_dir``, creating it if needed."""
        self.feedback_dir = feedback_dir
        self.feedback_db = os.path.join(feedback_dir, "human_feedback.json")
        self.statements_db = os.path.join(feedback_dir, "generated_statements.json")
        os.makedirs(feedback_dir, exist_ok=True)

    def store_generated_statement(self, statement_data: Dict[str, Any]) -> str:
        """Record a generated statement so it can later be reviewed.

        Args:
            statement_data: Description of the statement. Recognised keys are
                ``type`` (default ``"unknown"``), ``file_path``, ``output_path``,
                ``generation_time`` (default ``0``) and ``metadata``
                (default ``{}``).

        Returns:
            The newly assigned statement id (a UUID4 string).

        Raises:
            TypeError: if a value is not JSON serialisable. The existing
                statements file is left untouched in that case.
        """
        statement_id = str(uuid.uuid4())
        statement_record = {
            "statement_id": statement_id,
            "timestamp": time.time(),
            "statement_type": statement_data.get("type", "unknown"),
            "file_path": statement_data.get("file_path"),
            "output_path": statement_data.get("output_path"),
            "generation_time": statement_data.get("generation_time", 0),
            "metadata": statement_data.get("metadata", {}),
        }
        statements = self._load_statements()
        statements.append(statement_record)
        self._write_json(self.statements_db, statements)
        logger.info("Stored statement %s for feedback collection", statement_id)
        return statement_id

    def store_feedback(self, feedback: Dict[str, Any]) -> str:
        """Record one piece of human feedback on a generated statement.

        Args:
            feedback: Reviewer input. Recognised keys are ``statement_id``,
                ``reviewer_id`` (default ``"anonymous"``), ``specific_errors``,
                ``missing_items``, ``improvement_suggestions`` (each default
                ``""``), ``would_accept_for_audit`` (default ``False``),
                ``statement_type`` (default ``None``) and ``complexity_level``
                (default ``"medium"``).

        Returns:
            The newly assigned feedback id (a UUID4 string).

        Raises:
            TypeError: if a value is not JSON serialisable. The existing
                feedback file is left untouched in that case.
        """
        feedback_id = str(uuid.uuid4())
        feedback_record = {
            "feedback_id": feedback_id,
            "statement_id": feedback.get("statement_id"),
            "timestamp": time.time(),
            "reviewer_id": feedback.get("reviewer_id", "anonymous"),
            # Qualitative feedback
            "specific_errors": feedback.get("specific_errors", ""),
            "missing_items": feedback.get("missing_items", ""),
            "improvement_suggestions": feedback.get("improvement_suggestions", ""),
            "would_accept_for_audit": feedback.get("would_accept_for_audit", False),
            # Additional context
            "statement_type": feedback.get("statement_type"),
            "complexity_level": feedback.get("complexity_level", "medium"),
        }
        all_feedback = self._load_feedback()
        all_feedback.append(feedback_record)
        self._write_json(self.feedback_db, all_feedback)
        logger.info(
            "Stored feedback %s for statement %s", feedback_id, feedback.get("statement_id")
        )
        return feedback_id

    def get_training_data(self, min_feedback_count: int = 2) -> List[Dict[str, Any]]:
        """Return one training sample per stored feedback record.

        Returns ``[]`` (and logs a warning) when fewer than
        ``min_feedback_count`` feedback records exist. No other filtering is
        applied.
        """
        feedback_data = self._load_feedback()
        if len(feedback_data) < min_feedback_count:
            logger.warning(
                "Only %s feedback samples available, need at least %s",
                len(feedback_data),
                min_feedback_count,
            )
            return []
        # Defaults mirror store_feedback so older/hand-edited records don't KeyError.
        return [
            {
                "statement_id": fb.get("statement_id"),
                "statement_type": fb.get("statement_type"),
                "binary_approval": fb.get("would_accept_for_audit", False),
                "feedback_text": {
                    "errors": fb.get("specific_errors", ""),
                    "missing": fb.get("missing_items", ""),
                    "suggestions": fb.get("improvement_suggestions", ""),
                },
            }
            for fb in feedback_data
        ]

    def get_statement_for_review(self, statement_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored statement record with ``statement_id``, or ``None``."""
        return next(
            (s for s in self._load_statements() if s.get("statement_id") == statement_id),
            None,
        )

    def get_pending_reviews(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the most recently stored unreviewed statements.

        A statement is pending when no feedback record references its id.
        Results keep storage order (oldest first among those returned).
        A non-positive ``limit`` returns ``[]``; previously ``limit=0`` sliced
        as ``pending[-0:]`` and returned every pending statement.
        """
        if limit <= 0:
            return []
        reviewed_ids = {fb.get("statement_id") for fb in self._load_feedback()}
        pending = [
            s for s in self._load_statements() if s.get("statement_id") not in reviewed_ids
        ]
        return pending[-limit:]

    def get_feedback_stats(self) -> Dict[str, Any]:
        """Summarise collected feedback.

        Returns a dict with ``total_feedback``, ``total_statements``,
        ``audit_approval_rate`` (fraction of feedback with
        ``would_accept_for_audit`` truthy; ``0`` when there is no feedback) and
        ``feedback_by_type`` mapping statement type -> ``{"count": n}``.
        Feedback without a statement type is counted under ``"unknown"``.
        """
        feedback_data = self._load_feedback()
        statements = self._load_statements()

        by_type: Dict[str, Dict[str, int]] = {}
        for fb in feedback_data:
            # store_feedback always writes the key (None when absent), so a
            # .get default would never apply; coalesce falsy values instead.
            stmt_type = fb.get("statement_type") or "unknown"
            by_type.setdefault(stmt_type, {"count": 0})["count"] += 1

        approvals = sum(fb.get("would_accept_for_audit", False) for fb in feedback_data)
        return {
            "total_feedback": len(feedback_data),
            "total_statements": len(statements),
            "audit_approval_rate": approvals / len(feedback_data) if feedback_data else 0,
            "feedback_by_type": by_type,
        }

    def _load_feedback(self) -> List[Dict[str, Any]]:
        """Load all feedback records (``[]`` if missing or unreadable)."""
        return self._read_json_list(self.feedback_db, "feedback")

    def _load_statements(self) -> List[Dict[str, Any]]:
        """Load all statement records (``[]`` if missing or unreadable)."""
        return self._read_json_list(self.statements_db, "statements")

    def _read_json_list(self, path: str, label: str) -> List[Dict[str, Any]]:
        """Return the JSON list stored at ``path``.

        A missing file yields ``[]``. A file that cannot be decoded, or whose
        top-level value is not a list, is moved to ``<path>.corrupt-<ns>``
        before returning ``[]``. The next write then starts a fresh store
        instead of overwriting the unreadable data.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return []
        except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
            reason = str(err)
        else:
            if isinstance(data, list):
                return data
            reason = f"top-level value is {type(data).__name__}, expected list"

        backup = f"{path}.corrupt-{time.time_ns()}"
        try:
            os.replace(path, backup)
        except OSError:
            logger.exception(
                "Could not load %s database (%s) and could not back it up, starting fresh",
                label,
                reason,
            )
            return []
        logger.warning(
            "Could not load %s database (%s); moved it to %s, starting fresh",
            label,
            reason,
            backup,
        )
        return []

    def _write_json(self, path: str, records: List[Dict[str, Any]]) -> None:
        """Replace ``path`` with ``records`` serialised as indented JSON.

        The data is written to a temporary file in the same directory and then
        swapped in with ``os.replace``. A serialisation error or an interrupted
        write leaves the previous file intact instead of truncated.
        """
        import tempfile  # only needed on the write path

        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(path) or ".",
            prefix=os.path.basename(path) + ".",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error
            # is what the caller needs to see, so always re-raise it.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise