Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from pydantic import BaseModel | |
| from galileo import GalileoLogger, GalileoScorers, StageType | |
| from galileo.protect import invoke_protect | |
| from galileo.stages import create_protect_stage, get_protect_stage | |
| from galileo_core.schemas.protect.action import OverrideAction | |
| from galileo_core.schemas.protect.payload import Payload | |
| from galileo_core.schemas.protect.rule import Rule, RuleOperator | |
| from galileo_core.schemas.protect.ruleset import Ruleset | |
| load_dotenv() | |
class GalileoPlatformConfig(BaseModel):
    """Settings that identify which Galileo Protect stage to use.

    Both fields are required strings. ``GalileoPlatform`` passes them to the
    Galileo stage helpers when it looks up (or creates) the Protect stage.
    """

    # Name of the Galileo project that owns the Protect stage.
    protect_project_name: str

    # Name of the Protect stage inside that project.
    protect_stage_name: str
class GalileoPlatform:
    """Wrapper around the Galileo logger and Galileo Protect helpers.

    On construction the Protect stage named in ``config`` is resolved (looked
    up, or created when the lookup yields nothing) and its id is stored on
    ``self.protect_stage_id`` for later ``run_protect`` calls.
    """

    def __init__(self, config: GalileoPlatformConfig):
        """Store ``config`` and resolve the Protect stage id.

        Note that ``get_protect_stage_id`` is called right here, so building
        an instance already invokes the Galileo stage helpers.
        """
        self.config = config
        self.protect_stage_id = self.get_protect_stage_id()

    def get_logger(self, project_name: str, logstream_name: str) -> GalileoLogger:
        """Build a ``GalileoLogger`` for the given project and log stream."""
        return GalileoLogger(project=project_name, log_stream=logstream_name)

    def get_protect_stage_id(self):
        """Return the id of the configured Protect stage, creating it if needed.

        The stage is first looked up by project name and stage name. If the
        lookup raises or returns ``None``, a new local stage is created with
        the same project/stage names and its id is returned instead.

        Raises:
            Whatever ``create_protect_stage`` raises when the fallback
            creation itself fails.
        """
        import logging  # local import: keeps this block self-contained

        log = logging.getLogger(__name__)
        project_name = self.config.protect_project_name
        stage_name = self.config.protect_stage_name

        # Keep the try body to the lookup only, so a failure in the fallback
        # path is never mistaken for "stage not found".
        try:
            stage = get_protect_stage(
                project_name=project_name,
                stage_name=stage_name,
            )
        except Exception:
            # Best-effort by design: any lookup failure falls through to
            # creation, but the cause is recorded instead of being discarded.
            log.warning(
                "Lookup of Protect stage %r in project %r failed; "
                "attempting to create it.",
                stage_name,
                project_name,
                exc_info=True,
            )
            stage = None

        if stage is None:
            stage = create_protect_stage(
                project_name=project_name,
                name=stage_name,
                stage_type=StageType.local,
                description="Deutsche Bank RFP RAG Protect Stage",
            )
        return stage.id

    def run_protect(self, input: str, output: str, logger: Optional[GalileoLogger] = None) -> dict:
        """Run Galileo Protect on an input/output pair.

        Three rulesets are submitted, each with an ``OverrideAction`` holding
        a single canned message:

        1. ``context_adherence_luna`` less than or equal to ``0.01``.
        2. ``input_pii`` matching any of ``email``, ``phone_number``, ``name``.
        3. The custom metric ``deutsche_bank_company_pii_scorer_0`` greater
           than or equal to ``0.1``.

        Args:
            input: Text passed as the payload's ``input``.
            output: Text passed as the payload's ``output``.
            logger: Optional Galileo logger; when given, the payload and the
                Protect response are recorded on it as a protect span.

        Returns:
            ``dict(response)`` of the object returned by ``invoke_protect``.
        """
        # Build the payload once; it is used for both the call and the span.
        payload = Payload(input=input, output=output)

        # NOTE(review): the override message talks about the *input*, while
        # context adherence presumably scores the *output* -- confirm wording.
        adherence_ruleset = Ruleset(
            rules=[
                Rule(
                    metric=GalileoScorers.context_adherence_luna,
                    operator=RuleOperator.lte,
                    target_value=0.01,
                ),
            ],
            action=OverrideAction(
                choices=["Sorry, the input is hallucinatory."]
            ),
        )

        # NOTE(review): the metric is ``input_pii`` but the message says the
        # *output* contains PII -- confirm wording.
        input_pii_ruleset = Ruleset(
            rules=[
                Rule(
                    metric=GalileoScorers.input_pii,
                    operator=RuleOperator.any,
                    target_value=["email", "phone_number", "name"],
                ),
            ],
            action=OverrideAction(
                choices=["Sorry, the output contains PII."]
            ),
        )

        # Custom metric referenced by name rather than via ``GalileoScorers``.
        company_pii_ruleset = Ruleset(
            rules=[
                Rule(
                    metric="deutsche_bank_company_pii_scorer_0",
                    operator=RuleOperator.gte,
                    target_value=0.1,
                ),
            ],
            action=OverrideAction(
                choices=["Sorry, the output contains PII."]
            ),
        )

        response = invoke_protect(
            payload=payload,
            # Order preserved from the original list.
            prioritized_rulesets=[
                adherence_ruleset,
                input_pii_ruleset,
                company_pii_ruleset,
            ],
            stage_id=self.protect_stage_id,
        )

        # Explicit None check: do not depend on the logger's truthiness.
        if logger is not None:
            logger.add_protect_span(payload=payload, response=response)

        return dict(response)