File size: 3,738 Bytes
e68d535
753e3c5
 
 
 
 
 
 
 
 
 
 
 
e68d535
 
 
 
 
 
 
 
 
 
 
 
753e3c5
e68d535
753e3c5
 
 
 
 
e68d535
 
753e3c5
e68d535
 
753e3c5
 
 
e68d535
753e3c5
e68d535
753e3c5
 
e68d535
753e3c5
 
e68d535
753e3c5
e68d535
753e3c5
e68d535
753e3c5
 
e68d535
753e3c5
e68d535
753e3c5
 
 
e68d535
 
 
753e3c5
e68d535
 
 
753e3c5
e68d535
753e3c5
 
 
e68d535
 
 
753e3c5
e68d535
 
 
753e3c5
 
 
 
 
 
 
 
 
 
 
 
e68d535
 
 
 
753e3c5
 
 
e68d535
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Optional

from dotenv import load_dotenv
from pydantic import BaseModel

from galileo import GalileoLogger, GalileoScorers, StageType
from galileo.protect import invoke_protect
from galileo.stages import create_protect_stage, get_protect_stage
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset

# Load environment variables from a local .env file at import time.
# NOTE(review): presumably this supplies the Galileo SDK credentials/console
# settings used by the calls below — confirm which variables are expected.
load_dotenv()

class GalileoPlatformConfig(BaseModel):
    """Base configuration for Galileo platform.

    Identifies the Galileo Protect stage (by project name and stage name)
    that ``GalileoPlatform`` looks up, or creates, when it is constructed.
    """
    # Name of the Galileo project that owns the Protect stage.
    protect_project_name: str
    # Name of the Protect stage to look up (or create) inside that project.
    protect_stage_name: str

class GalileoPlatform:
    """Implementation of Galileo Features.

    Thin wrapper over the Galileo SDK that:

    * resolves the Protect stage described by ``config`` once, at
      construction time, and caches its id in ``protect_stage_id``;
    * builds ``GalileoLogger`` instances for a project / log stream;
    * runs Galileo Protect with a fixed, prioritized set of rulesets and
      optionally records the result as a protect span on a logger.
    """

    def __init__(self, config: GalileoPlatformConfig):
        """Store ``config`` and resolve the Protect stage id.

        Note: this performs Galileo API calls (a stage lookup and, if that
        fails, a stage creation) during construction.
        """
        self.config = config
        self.protect_stage_id = self.get_protect_stage_id()

    def get_logger(self, project_name: str, logstream_name: str) -> GalileoLogger:
        """Return a ``GalileoLogger`` bound to a project and log stream.

        Args:
            project_name: Galileo project to log into.
            logstream_name: Log stream within that project.
        """
        return GalileoLogger(
            project=project_name,
            log_stream=logstream_name,
        )

    def get_protect_stage_id(
        self,
        description: str = "Deutsche Bank RFP RAG Protect Stage",
    ):
        """Return the id of the configured Protect stage, creating it if needed.

        The stage is looked up by ``config.protect_project_name`` and
        ``config.protect_stage_name``. If the lookup raises, or returns no
        stage, a new ``StageType.local`` stage with that name is created.

        Args:
            description: Description attached to the stage only when it has
                to be created. Defaults to the previously hard-coded text so
                existing callers are unaffected.

        Returns:
            The ``id`` of the found or newly created stage.
        """
        try:
            protect_stage = get_protect_stage(
                project_name=self.config.protect_project_name,
                stage_name=self.config.protect_stage_name,
            )
        except Exception:
            # NOTE(review): any lookup failure (not only "stage not found",
            # but also e.g. auth/network errors) falls through to creation,
            # as in the original code. Narrow this once the SDK's not-found
            # exception type is confirmed. Creating inside the handler keeps
            # the lookup error chained onto any creation error.
            return self._create_protect_stage_id(description)

        # Defensive: treat an empty lookup result the same as a failed lookup
        # instead of relying on an AttributeError from ``None.id``.
        if protect_stage is None:
            return self._create_protect_stage_id(description)
        return protect_stage.id

    def _create_protect_stage_id(self, description: str):
        """Create the configured Protect stage as a local stage; return its id."""
        protect_stage = create_protect_stage(
            project_name=self.config.protect_project_name,
            name=self.config.protect_stage_name,
            stage_type=StageType.local,
            description=description,
        )
        return protect_stage.id

    @staticmethod
    def _build_rulesets() -> list:
        """Return the prioritized Protect rulesets, highest priority first.

        1. ``context_adherence_luna`` <= 0.01 -> hallucination override.
        2. ``input_pii`` using ``RuleOperator.any`` over email / phone
           number / name -> input-PII override.
        3. Custom scorer ``deutsche_bank_company_pii_scorer_0`` >= 0.1
           -> PII override.
        """
        return [
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.context_adherence_luna,
                        operator=RuleOperator.lte,
                        target_value=0.01,
                    ),
                ],
                # Context adherence judges the response (output) against the
                # supplied context, so the hallucination is in the output.
                action=OverrideAction(
                    choices=["Sorry, the output is hallucinatory."]
                ),
            ),
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.input_pii,
                        operator=RuleOperator.any,
                        target_value=["email", "phone_number", "name"],
                    )
                ],
                # This rule inspects the *input*, so say so in the message.
                action=OverrideAction(
                    choices=["Sorry, the input contains PII."]
                ),
            ),
            Ruleset(
                rules=[
                    Rule(
                        metric="deutsche_bank_company_pii_scorer_0",
                        operator=RuleOperator.gte,
                        target_value=0.1,
                    )
                ],
                # NOTE(review): confirm whether this custom scorer evaluates
                # input or output; the message assumes output.
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
        ]

    def run_protect(self, input: str, output: str, logger: Optional[GalileoLogger] = None) -> dict:
        """Run Galileo Protect on an input/output pair.

        Args:
            input: The request text. (The name shadows the builtin but is
                kept because callers may pass it by keyword.)
            output: The generated response text.
            logger: If given, the payload and Protect response are recorded
                on it via ``add_protect_span``.

        Returns:
            ``dict(response)`` of the Protect response, i.e. a shallow dict
            of its top-level fields.
        """
        # Build the payload once and reuse it for the Protect call and the span.
        payload = Payload(input=input, output=output)
        response = invoke_protect(
            payload=payload,
            prioritized_rulesets=self._build_rulesets(),
            stage_id=self.protect_stage_id,
        )

        if logger is not None:
            logger.add_protect_span(
                payload=payload,
                response=response,
            )

        # NOTE(review): if invoke_protect can return None, dict(None) raises
        # TypeError here — confirm the SDK's return contract.
        return dict(response)