File size: 4,691 Bytes
7f7cb09
46b11f4
 
753e3c5
 
46b11f4
 
753e3c5
 
 
 
 
 
 
e68d535
 
 
 
 
 
 
 
 
 
 
 
753e3c5
e68d535
46b11f4
753e3c5
46b11f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e68d535
46b11f4
e68d535
1d9f42e
 
 
7f7cb09
 
 
 
 
 
9f28d12
753e3c5
e68d535
7f7cb09
e68d535
9f28d12
 
 
 
7f7cb09
e68d535
46b11f4
e68d535
753e3c5
 
e68d535
753e3c5
e68d535
753e3c5
 
 
e68d535
 
 
753e3c5
e68d535
 
 
753e3c5
e68d535
753e3c5
 
 
e68d535
 
 
753e3c5
e68d535
 
 
753e3c5
 
 
 
 
 
 
 
 
 
 
 
e68d535
 
 
 
753e3c5
 
 
e68d535
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from pydantic import BaseModel, UUID4
from dotenv import load_dotenv
from typing import Any

from galileo import GalileoLogger, GalileoScorers, StageType
from galileo.log_streams import create_log_stream, get_log_stream
from galileo.projects import create_project, get_project
from galileo.protect import invoke_protect
from galileo.stages import create_protect_stage, get_protect_stage
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset

# Runs at import time: load variables from a .env file into the process
# environment (presumably the Galileo credentials the SDK calls below rely
# on — TODO confirm which variables are required).
load_dotenv()

class GalileoPlatformConfig(BaseModel):
    """Configuration consumed by ``GalileoPlatform``.

    Both fields are required strings; they identify the Galileo Protect
    stage that ``GalileoPlatform.get_protect_stage_id`` looks up or creates.
    """
    # Name of the Galileo project that holds the Protect stage; the project
    # is created if ``get_project`` does not find it.
    protect_project_name: str
    # Name of the Protect stage inside that project; the stage is created if
    # ``get_protect_stage`` does not find it.
    protect_stage_name: str

class GalileoPlatform:
    """Thin wrapper around the Galileo SDK for logging and Protect.

    On construction the configured Protect stage is looked up (and created
    when missing); its id is cached on ``protect_stage_id`` and reused by
    every ``run_protect`` call.
    """

    def __init__(self, config: GalileoPlatformConfig):
        """Store *config* and resolve the Protect stage id.

        Note that this calls the Galileo project/stage APIs through
        ``get_protect_stage_id``, so constructing an instance performs I/O.

        Args:
            config: Names of the Protect project and stage to use.
        """
        self.config = config
        self.protect_stage_id = self.get_protect_stage_id()

    @staticmethod
    def _ensure_project(project_name: str) -> None:
        """Create the project *project_name* if ``get_project`` finds none."""
        if not get_project(name=project_name):
            _ = create_project(name=project_name)
            print(f"Project {project_name} created")

    @staticmethod
    def _protect_rulesets() -> list[Ruleset]:
        """Build the prioritized rulesets evaluated by ``run_protect``."""
        # NOTE(review): the override messages look swapped relative to the
        # metrics (context adherence is an output-side check but the message
        # says "input"; ``input_pii`` reports "output"). The strings are kept
        # as-is because callers may match on them -- confirm the wording.
        return [
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.context_adherence_luna,
                        operator=RuleOperator.lte,
                        target_value=0.01,
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the input is hallucinatory."]
                ),
            ),
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.input_pii,
                        operator=RuleOperator.any,
                        target_value=["email", "phone_number", "name"],
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
            Ruleset(
                rules=[
                    Rule(
                        # Metric referenced by name rather than through
                        # GalileoScorers (presumably a custom scorer
                        # registered in the Galileo project -- TODO confirm).
                        metric="deutsche_bank_company_pii_scorer_0",
                        operator=RuleOperator.gte,
                        target_value=0.1,
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
        ]

    def get_logger(self, project_name: str, logstream_name: str) -> GalileoLogger | None:
        """Get or create a Galileo Logger.

        The project and the log stream are created first when they do not
        exist yet.

        Args:
            project_name: Galileo project the logger writes to.
            logstream_name: Log stream inside that project.

        Returns:
            A ``GalileoLogger`` bound to the project/log stream, or ``None``
            when constructing the logger raised (the error is printed).
        """
        self._ensure_project(project_name)

        if not get_log_stream(name=logstream_name, project_name=project_name):
            _ = create_log_stream(name=logstream_name, project_name=project_name)
            print(f"Logstream {logstream_name} created in project {project_name}")

        # Best-effort by design: a logger failure must not take the caller
        # down, so the error is reported and None is returned instead.
        try:
            return GalileoLogger(
                project=project_name,
                log_stream=logstream_name,
            )
        except Exception as e:
            print(f"Failed to create logger: {e}")
            return None

    def get_protect_stage_id(self) -> str | UUID4 | None:
        """Get or create a Galileo Protect stage and return its id.

        The project named in the config is created first when missing. If
        the stage does not exist it is created as a local stage and looked
        up again.

        Returns:
            The stage id, or ``None`` when the stage still cannot be resolved
            after creation (previously this case raised ``AttributeError``
            on ``None.id``).
        """
        project_name = self.config.protect_project_name
        stage_name = self.config.protect_stage_name

        self._ensure_project(project_name)

        protect_stage = get_protect_stage(
            project_name=project_name,
            stage_name=stage_name,
        )
        if not protect_stage:
            created_stage = create_protect_stage(
                project_name=project_name,
                name=stage_name,
                stage_type=StageType.local
            )
            # Prefer the freshly fetched stage; fall back to whatever the
            # create call returned if the lookup still comes back empty.
            protect_stage = get_protect_stage(
                project_name=project_name,
                stage_name=stage_name,
            ) or created_stage

        stage_id = getattr(protect_stage, "id", None)
        if stage_id is None:
            print(
                f"Failed to resolve protect stage {stage_name} "
                f"in project {project_name}"
            )
        return stage_id

    def run_protect(self, input: str, output: str, logger: GalileoLogger | None = None) -> dict[Any, Any]:
        """Run Galileo Protect on input and output.

        Args:
            input: Input text placed in the Protect payload.
            output: Output text placed in the Protect payload.
            logger: Optional logger; when given, the payload and the Protect
                response are recorded through ``add_protect_span``.

        Returns:
            The Protect response converted with ``dict(response)``.
        """
        # Build the payload once; the same object is sent to Protect and,
        # when a logger is supplied, recorded on the span.
        payload = Payload(input=input, output=output)

        response = invoke_protect(
            payload=payload,
            prioritized_rulesets=self._protect_rulesets(),
            stage_id=self.protect_stage_id,
        )

        if logger:
            logger.add_protect_span(
                payload=payload,
                response=response,
            )

        return dict(response)