from typing import Any

from dotenv import load_dotenv
from galileo import GalileoLogger, GalileoScorers, StageType
from galileo.log_streams import create_log_stream, get_log_stream
from galileo.projects import create_project, get_project
from galileo.protect import invoke_protect
from galileo.stages import create_protect_stage, get_protect_stage
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset
from pydantic import UUID4, BaseModel

# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably this is where the Galileo credentials come from — confirm.
load_dotenv()


class GalileoPlatformConfig(BaseModel):
    """Base configuration for Galileo platform."""

    # Name of the Galileo project in which the Protect stage is looked up/created.
    protect_project_name: str
    # Name of the Protect stage looked up (or created) inside that project.
    protect_stage_name: str


class GalileoPlatform:
    """Implementation of Galileo Features.

    On construction the Protect stage named in the config is resolved (and
    created when missing); its id is stored in ``protect_stage_id`` and later
    passed to ``invoke_protect`` by :meth:`run_protect`.
    """

    def __init__(self, config: GalileoPlatformConfig):
        """Store ``config`` and resolve the Protect stage id.

        Raises:
            RuntimeError: If the Protect stage cannot be found even after an
                attempt to create it (see :meth:`get_protect_stage_id`).
        """
        self.config = config
        self.protect_stage_id = self.get_protect_stage_id()

    @staticmethod
    def _ensure_project(project_name: str) -> None:
        """Create the project ``project_name`` when ``get_project`` finds none."""
        if not get_project(name=project_name):
            create_project(name=project_name)
            print(f"Project {project_name} created")

    @staticmethod
    def _build_rulesets() -> list[Ruleset]:
        """Return the prioritized rulesets passed to ``invoke_protect``.

        A fresh list is built on every call; the order of the list is the
        order handed to ``prioritized_rulesets``.
        """
        return [
            # Rule: context_adherence_luna <= 0.01.
            # NOTE(review): the message says "input" although context adherence
            # is usually scored on the output — confirm the intended wording.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.context_adherence_luna,
                        operator=RuleOperator.lte,
                        target_value=0.01,
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the input is hallucinatory."],
                ),
            ),
            # Rule: input_pii matches any of the listed PII categories.
            # NOTE(review): the message says "output" while the metric is
            # input_pii — confirm the intended wording.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.input_pii,
                        operator=RuleOperator.any,
                        target_value=["email", "phone_number", "name"],
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."],
                ),
            ),
            # Rule: metric referenced by name, >= 0.1.
            # NOTE(review): presumably a custom scorer registered elsewhere — verify it exists.
            Ruleset(
                rules=[
                    Rule(
                        metric="deutsche_bank_company_pii_scorer_0",
                        operator=RuleOperator.gte,
                        target_value=0.1,
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."],
                ),
            ),
        ]

    def get_logger(self, project_name: str, logstream_name: str) -> GalileoLogger | None:
        """Get or create a Galileo Logger.

        The project and the log stream are created first when their lookups
        return nothing. Errors from those lookups/creations propagate to the
        caller; only the ``GalileoLogger`` construction is guarded.

        Args:
            project_name: Project the logger writes to.
            logstream_name: Log stream inside ``project_name``.

        Returns:
            A ``GalileoLogger`` bound to the project and log stream, or
            ``None`` when constructing the logger raised (the error is
            printed).
        """
        self._ensure_project(project_name)

        if not get_log_stream(name=logstream_name, project_name=project_name):
            create_log_stream(name=logstream_name, project_name=project_name)
            print(f"Logstream {logstream_name} created in project {project_name}")

        try:
            return GalileoLogger(
                project=project_name,
                log_stream=logstream_name,
            )
        except Exception as e:
            # Deliberate best effort: a missing logger is reported, not fatal.
            print(f"Failed to create logger: {e}")
            return None

    def get_protect_stage_id(self) -> str | UUID4 | None:
        """Get or create a Galileo Protect stage and return its id.

        The project from the config is created when missing. The stage is
        looked up by name; when the lookup returns nothing it is created with
        ``StageType.local`` and looked up again.

        Returns:
            The ``id`` attribute of the resolved stage.

        Raises:
            RuntimeError: If the stage still cannot be found after the
                creation attempt.
        """
        project_name = self.config.protect_project_name
        stage_name = self.config.protect_stage_name

        self._ensure_project(project_name)

        protect_stage = get_protect_stage(
            project_name=project_name,
            stage_name=stage_name,
        )
        if not protect_stage:
            create_protect_stage(
                project_name=project_name,
                name=stage_name,
                stage_type=StageType.local,
            )
            protect_stage = get_protect_stage(
                project_name=project_name,
                stage_name=stage_name,
            )

        # Fail with an explicit message instead of an AttributeError on None
        # when the second lookup comes back empty as well.
        if protect_stage is None:
            raise RuntimeError(
                f"Protect stage {stage_name!r} could not be found or created "
                f"in project {project_name!r}"
            )
        return protect_stage.id

    def run_protect(self, input: str, output: str, logger: GalileoLogger | None = None) -> dict[Any, Any]:
        """Run Galileo Protect on input and output.

        Args:
            input: Text placed in the payload's ``input`` field.
            output: Text placed in the payload's ``output`` field.
            logger: Optional logger; when given, a protect span holding the
                payload and the response is added to it.

        Returns:
            ``dict(response)`` for the response returned by ``invoke_protect``.
        """
        # Build the payload once; the same object is sent and, if requested, logged.
        payload = Payload(input=input, output=output)

        response = invoke_protect(
            payload=payload,
            prioritized_rulesets=self._build_rulesets(),
            stage_id=self.protect_stage_id,
        )

        if logger:
            logger.add_protect_span(
                payload=payload,
                response=response,
            )

        return dict(response)