# Demos/backend/classes/galileo_platform.py
# Last commit: 753e3c5 by nikhile-galileo — "Added G2.0 changes" (3.74 kB)
from typing import Optional
from dotenv import load_dotenv
from pydantic import BaseModel
from galileo import GalileoLogger, GalileoScorers, StageType
from galileo.protect import invoke_protect
from galileo.stages import create_protect_stage, get_protect_stage
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset
# Load variables from a local .env file into os.environ at import time
# (python-dotenv; by default it does not override variables already set).
# NOTE(review): presumably this supplies the Galileo SDK's credentials/endpoint
# (e.g. API key, console URL) — confirm which variables are required.
load_dotenv()
class GalileoPlatformConfig(BaseModel):
    """Base configuration for the Galileo platform wrapper.

    A pydantic model, so both fields are required and validated as strings
    at construction time.

    Attributes:
        protect_project_name: Name of the Galileo project the Protect stage
            is looked up in (and created in, if the lookup fails).
        protect_stage_name: Name of the Protect stage to look up or create.
    """

    # Galileo project that owns the Protect stage.
    protect_project_name: str
    # Protect stage name, used both for lookup and for creation.
    protect_stage_name: str
class GalileoPlatform:
    """Implementation of Galileo features: loggers and Protect.

    Construction calls the Galileo stage APIs: it resolves the Protect stage
    named in ``config`` (creating it when the lookup fails) and caches the
    stage id on ``self.protect_stage_id`` for every later ``run_protect`` call.
    """

    def __init__(self, config: GalileoPlatformConfig):
        self.config = config
        # Resolved once here so each run_protect call reuses the cached id.
        self.protect_stage_id = self.get_protect_stage_id()

    def get_logger(self, project_name: str, logstream_name: str):
        """Return a new ``GalileoLogger`` bound to a project and log stream.

        A fresh logger object is constructed on every call. Whether the
        project / log stream is created server-side when missing is up to the
        SDK (presumably yes — confirm against the Galileo SDK docs).
        """
        return GalileoLogger(project=project_name, log_stream=logstream_name)

    def get_protect_stage_id(self):
        """Get or create the configured Galileo Protect stage and return its id.

        The stage is looked up by ``config.protect_project_name`` and
        ``config.protect_stage_name``. If that lookup raises anything at all
        (including an ``AttributeError`` when the result has no ``.id``), a new
        stage of type ``StageType.local`` is created under the same project and
        name instead.

        Returns:
            The ``id`` of the found or newly created stage.

        Raises:
            Whatever ``create_protect_stage`` raises if creation fails. The
            original lookup error stays attached as ``__context__`` because
            creation runs inside the ``except`` clause.
        """
        project_name = self.config.protect_project_name
        stage_name = self.config.protect_stage_name
        try:
            # `.id` is read inside the try so an unexpected lookup result
            # (e.g. None) also falls through to creation.
            return get_protect_stage(
                project_name=project_name,
                stage_name=stage_name,
            ).id
        except Exception:
            # NOTE(review): this broad catch also turns auth/network failures
            # into a create attempt; narrow it once the SDK's "stage not found"
            # exception type is confirmed.
            # A local stage presumably receives its rulesets per invocation,
            # which matches run_protect passing `prioritized_rulesets` itself.
            new_stage = create_protect_stage(
                project_name=project_name,
                name=stage_name,
                stage_type=StageType.local,
                description="Deutsche Bank RFP RAG Protect Stage",
            )
            return new_stage.id

    def run_protect(self, input: str, output: str, logger: Optional[GalileoLogger] = None) -> dict:
        """Run Galileo Protect on an input/output pair.

        The pair is checked against the rulesets from ``_build_rulesets`` on
        the cached Protect stage.

        Args:
            input: Text sent as the Protect payload's ``input`` (presumably
                the user query plus any retrieved context — confirm with
                callers). The name shadows the builtin but is kept because
                callers may pass it by keyword.
            output: Text sent as the Protect payload's ``output``.
            logger: If provided, the invocation is also recorded on it via
                ``add_protect_span``.

        Returns:
            ``dict(response)``: a shallow dict of the Protect response's
            fields. Nested objects are not converted.
        """
        # One payload object for both the invocation and the logged span, so
        # the span records exactly what was sent to Protect.
        payload = Payload(input=input, output=output)
        response = invoke_protect(
            payload=payload,
            prioritized_rulesets=self._build_rulesets(),
            stage_id=self.protect_stage_id,
        )
        if logger is not None:
            logger.add_protect_span(payload=payload, response=response)
        # NOTE(review): if invoke_protect can return None (e.g. on timeout),
        # dict(None) raises TypeError here — confirm the SDK's return contract.
        return dict(response)

    @staticmethod
    def _build_rulesets() -> list:
        """Return the Protect rulesets in the order they are passed as
        ``prioritized_rulesets`` (first entry = highest priority).

        Fresh objects are built on every call, so no Ruleset instance is
        shared between invocations.
        """
        return [
            # Fires when the context-adherence (Luna) score is <= 0.01.
            # NOTE(review): context adherence presumably scores the output
            # against its input/context, so "the input is hallucinatory" may
            # be misworded — confirm before changing user-facing text.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.context_adherence_luna,
                        operator=RuleOperator.lte,
                        target_value=0.01,
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the input is hallucinatory."]
                ),
            ),
            # Fires (presumably) when any of the listed PII categories is
            # detected by the input_pii scorer.
            # NOTE(review): the metric checks the *input*, but the override
            # message says the *output* contains PII. One of them is likely
            # wrong (GalileoScorers.output_pii?) — confirm intent.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.input_pii,
                        operator=RuleOperator.any,
                        target_value=["email", "phone_number", "name"],
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
            # Scorer referenced by name rather than a GalileoScorers member
            # (presumably a custom company-PII scorer registered in Galileo);
            # fires when its score is >= 0.1.
            Ruleset(
                rules=[
                    Rule(
                        metric="deutsche_bank_company_pii_scorer_0",
                        operator=RuleOperator.gte,
                        target_value=0.1,
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
        ]