Jayashree Sridhar commited on
Commit
961a175
·
1 Parent(s): 4b2ca0c

modified validation tools

Browse files
Files changed (1) hide show
  1. agents/tools/validation_tools.py +9 -3
agents/tools/validation_tools.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass
8
  import json
9
  from transformers import pipeline
10
  import torch
 
11
 
12
  # @dataclass
13
  class ValidationResult:
@@ -399,8 +400,12 @@ class ValidationResult:
399
  from .base_tool import BaseTool
400
 
401
  class ValidateResponseTool(BaseTool):
402
- def __init__(self, config=None):
403
- super().__init__(config)
 
 
 
 
404
  # ... any required initialization ...
405
  def __call__(self, response: str, context: dict = None):
406
  # Place your actual validation logic here, include dummy for illustration
@@ -415,8 +420,9 @@ class ValidateResponseTool(BaseTool):
415
  return {"is_valid", "issues", "warnings", "suggestions","confidence","refined_text"}
416
 
417
  class ValidationTools:
 
418
  def __init__(self, config=None):
419
- self.validate_response = ValidateResponseTool(config)
420
  # Add more tools as needed (check_safety, refine_response, etc.)
421
  # # Initialize sentiment analyzer for tone checking
422
  self.sentiment_analyzer = pipeline(
 
8
  import json
9
  from transformers import pipeline
10
  import torch
11
+ from pydantic import BaseModel, PrivateAttr
12
 
13
  # @dataclass
14
  class ValidationResult:
 
400
  from .base_tool import BaseTool
401
 
402
  class ValidateResponseTool(BaseTool):
403
+ name: str = "validate_response"
404
+ description: str = "Validates safety and helpfulness."
405
+ _config: object = PrivateAttr()
406
+ def __init__(self, config=None, **data):
407
+ super().__init__(**data)
408
+ self._config = config
409
  # ... any required initialization ...
410
  def __call__(self, response: str, context: dict = None):
411
  # Place your actual validation logic here, include dummy for illustration
 
420
  return {"is_valid", "issues", "warnings", "suggestions","confidence","refined_text"}
421
 
422
  class ValidationTools:
423
+ _model: ValidateResponseTool = PrivateAttr()
424
  def __init__(self, config=None):
425
+ self._validate_response = ValidateResponseTool(config)
426
  # Add more tools as needed (check_safety, refine_response, etc.)
427
  # # Initialize sentiment analyzer for tone checking
428
  self.sentiment_analyzer = pipeline(