Spaces:
Sleeping
Sleeping
Update agents/generator_validator.py
#1
by
sahil-1-garg
- opened
agents/generator_validator.py
CHANGED
|
@@ -433,9 +433,10 @@ class BaseValidator(ABC):
|
|
| 433 |
class LLMNotesGenerator(BaseGenerator):
|
| 434 |
"""Generator for AI-powered financial notes"""
|
| 435 |
|
| 436 |
-
def __init__(self, max_attempts: int = 3, use_rlhf: bool = False):
|
| 437 |
super().__init__(max_attempts)
|
| 438 |
self.use_rlhf = use_rlhf
|
|
|
|
| 439 |
|
| 440 |
def generate(self, file_path: str, **kwargs) -> GenerationResult:
|
| 441 |
"""Generate notes using AI/LLM approach with feedback integration"""
|
|
@@ -451,10 +452,10 @@ class LLMNotesGenerator(BaseGenerator):
|
|
| 451 |
# Choose workflow based on RLHF preference
|
| 452 |
if self.use_rlhf:
|
| 453 |
from agents.rlhf_workflows import run_rlhf_workflow
|
| 454 |
-
result = run_rlhf_workflow(file_path, "notes-llm")
|
| 455 |
else:
|
| 456 |
from agents.langgraph import run_workflow
|
| 457 |
-
result = run_workflow(file_path, "notes-llm", feedback_context=feedback_context)
|
| 458 |
|
| 459 |
if result["status"] == "success":
|
| 460 |
# UDFs are now applied in generate_llm_notes function before Excel conversion
|
|
@@ -544,84 +545,6 @@ class LLMNotesGenerator(BaseGenerator):
|
|
| 544 |
# 2. Retry with different parameters
|
| 545 |
# 3. Use fallback models
|
| 546 |
|
| 547 |
-
if not self.use_rlhf and "quality" in str(feedback).lower():
|
| 548 |
-
# If quality issues and not using RLHF, try RLHF
|
| 549 |
-
logger.info("Switching to RLHF for better quality")
|
| 550 |
-
original_rlhf = self.use_rlhf
|
| 551 |
-
self.use_rlhf = True
|
| 552 |
-
result = self.generate(previous_result.data.get("file_path") if previous_result.data else None)
|
| 553 |
-
self.use_rlhf = original_rlhf # Reset for future calls
|
| 554 |
-
return result
|
| 555 |
-
else:
|
| 556 |
-
# Otherwise, just retry
|
| 557 |
-
return self.generate(previous_result.data.get("file_path") if previous_result.data else None)
|
| 558 |
-
"""Generator for AI-powered financial notes"""
|
| 559 |
-
|
| 560 |
-
def __init__(self, max_attempts: int = 3, use_rlhf: bool = False):
|
| 561 |
-
super().__init__(max_attempts)
|
| 562 |
-
self.use_rlhf = use_rlhf
|
| 563 |
-
|
| 564 |
-
def generate(self, file_path: str, **kwargs) -> GenerationResult:
|
| 565 |
-
"""Generate notes using AI/LLM approach"""
|
| 566 |
-
try:
|
| 567 |
-
self.attempts_made += 1
|
| 568 |
-
execution_id = f"notes_llm_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.attempts_made}"
|
| 569 |
-
|
| 570 |
-
# Choose workflow based on RLHF preference
|
| 571 |
-
if self.use_rlhf:
|
| 572 |
-
from agents.rlhf_workflows import run_rlhf_workflow
|
| 573 |
-
result = run_rlhf_workflow(file_path, "notes-llm")
|
| 574 |
-
else:
|
| 575 |
-
from agents.langgraph import run_workflow
|
| 576 |
-
result = run_workflow(file_path, "notes-llm")
|
| 577 |
-
|
| 578 |
-
if result["status"] == "success":
|
| 579 |
-
return GenerationResult(
|
| 580 |
-
success=True,
|
| 581 |
-
output_path=result["result"]["output_xlsx_path"],
|
| 582 |
-
data=result["result"],
|
| 583 |
-
error=None,
|
| 584 |
-
metadata={
|
| 585 |
-
"execution_id": execution_id,
|
| 586 |
-
"generation_method": "llm",
|
| 587 |
-
"use_rlhf": self.use_rlhf,
|
| 588 |
-
"attempt": self.attempts_made,
|
| 589 |
-
"rlhf_metadata": result["result"].get("rlhf_metadata", {})
|
| 590 |
-
}
|
| 591 |
-
)
|
| 592 |
-
else:
|
| 593 |
-
return GenerationResult(
|
| 594 |
-
success=False,
|
| 595 |
-
output_path=None,
|
| 596 |
-
data=None,
|
| 597 |
-
error=result.get("error", "Unknown error"),
|
| 598 |
-
metadata={
|
| 599 |
-
"execution_id": execution_id,
|
| 600 |
-
"generation_method": "llm",
|
| 601 |
-
"use_rlhf": self.use_rlhf,
|
| 602 |
-
"attempt": self.attempts_made
|
| 603 |
-
}
|
| 604 |
-
)
|
| 605 |
-
|
| 606 |
-
except Exception as e:
|
| 607 |
-
logger.error(f"LLM Notes generation failed: {e}")
|
| 608 |
-
return GenerationResult(
|
| 609 |
-
success=False,
|
| 610 |
-
output_path=None,
|
| 611 |
-
data=None,
|
| 612 |
-
error=str(e),
|
| 613 |
-
metadata={
|
| 614 |
-
"execution_id": f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
| 615 |
-
"generation_method": "llm",
|
| 616 |
-
"use_rlhf": self.use_rlhf,
|
| 617 |
-
"attempt": self.attempts_made
|
| 618 |
-
}
|
| 619 |
-
)
|
| 620 |
-
|
| 621 |
-
def refine(self, previous_result: GenerationResult, feedback: List[str]) -> GenerationResult:
|
| 622 |
-
"""Refine LLM notes generation based on feedback"""
|
| 623 |
-
logger.info(f"Refining LLM notes generation with feedback: {feedback}")
|
| 624 |
-
|
| 625 |
if not self.use_rlhf and "quality" in str(feedback).lower():
|
| 626 |
# If quality issues and not using RLHF, try RLHF
|
| 627 |
logger.info("Switching to RLHF for better quality")
|
|
@@ -752,9 +675,9 @@ class GeneratorValidatorPipeline:
|
|
| 752 |
"validation_criteria": self.validator.get_validation_criteria()
|
| 753 |
}
|
| 754 |
|
| 755 |
-
def create_notes_pipeline(use_rlhf: bool = False) -> GeneratorValidatorPipeline:
|
| 756 |
"""Factory function to create LLM-based pipeline for notes generation"""
|
| 757 |
-
generator = LLMNotesGenerator(use_rlhf=use_rlhf)
|
| 758 |
validator = NotesValidator()
|
| 759 |
|
| 760 |
return GeneratorValidatorPipeline(generator, validator)
|
|
|
|
| 433 |
class LLMNotesGenerator(BaseGenerator):
|
| 434 |
"""Generator for AI-powered financial notes"""
|
| 435 |
|
| 436 |
+
def __init__(self, max_attempts: int = 3, use_rlhf: bool = False, user_api_key: Optional[str] = None):
|
| 437 |
super().__init__(max_attempts)
|
| 438 |
self.use_rlhf = use_rlhf
|
| 439 |
+
self.user_api_key = user_api_key
|
| 440 |
|
| 441 |
def generate(self, file_path: str, **kwargs) -> GenerationResult:
|
| 442 |
"""Generate notes using AI/LLM approach with feedback integration"""
|
|
|
|
| 452 |
# Choose workflow based on RLHF preference
|
| 453 |
if self.use_rlhf:
|
| 454 |
from agents.rlhf_workflows import run_rlhf_workflow
|
| 455 |
+
result = run_rlhf_workflow(file_path, "notes-llm", user_api_key=self.user_api_key)
|
| 456 |
else:
|
| 457 |
from agents.langgraph import run_workflow
|
| 458 |
+
result = run_workflow(file_path, "notes-llm", feedback_context=feedback_context, user_api_key=self.user_api_key)
|
| 459 |
|
| 460 |
if result["status"] == "success":
|
| 461 |
# UDFs are now applied in generate_llm_notes function before Excel conversion
|
|
|
|
| 545 |
# 2. Retry with different parameters
|
| 546 |
# 3. Use fallback models
|
| 547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
if not self.use_rlhf and "quality" in str(feedback).lower():
|
| 549 |
# If quality issues and not using RLHF, try RLHF
|
| 550 |
logger.info("Switching to RLHF for better quality")
|
|
|
|
| 675 |
"validation_criteria": self.validator.get_validation_criteria()
|
| 676 |
}
|
| 677 |
|
| 678 |
+
def create_notes_pipeline(use_rlhf: bool = False, user_api_key: Optional[str] = None) -> GeneratorValidatorPipeline:
|
| 679 |
"""Factory function to create LLM-based pipeline for notes generation"""
|
| 680 |
+
generator = LLMNotesGenerator(use_rlhf=use_rlhf, user_api_key=user_api_key)
|
| 681 |
validator = NotesValidator()
|
| 682 |
|
| 683 |
return GeneratorValidatorPipeline(generator, validator)
|