nikhile-galileo committed on
Commit
753e3c5
·
1 Parent(s): eed7bad

Added G2.0 changes

Browse files
backend/api/main.py CHANGED
@@ -33,7 +33,6 @@ templates = Jinja2Templates(directory="backend/api/templates")
33
 
34
  load_dotenv()
35
 
36
-
37
  logger = initialize_logger()
38
 
39
  # get current file path using Path
@@ -54,8 +53,8 @@ embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)
54
 
55
  # Create vector db model object
56
  vector_db_config = MilvusVectorDatabaseConfig(
57
- db_path=app_config["vector_database"]["db_path"],
58
- collection_name=app_config["vector_database"]["collection_name"],
59
  vector_dimensions=app_config["vector_database"]["dimensions"],
60
  drop_if_exists=False,
61
  )
@@ -63,25 +62,28 @@ vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)
63
 
64
  # Create generative model object
65
  gemini_generative_model_config = GeminiModelConfig(
66
- model_name=app_config["gemini_generative_model"]["model_name"],
67
  api_keys=[env_variables["GOOGLE_GEMINI_API_KEY"], env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"]],
68
- temperature=app_config["gemini_generative_model"]["temperature"],
69
  )
70
  gemini_generative_model = get_generative_model(GeminiModel, gemini_generative_model_config)
71
 
72
  openai_generative_model_config = OpenAIModelConfig(
73
- model_name=app_config["openai_generative_model"]["model_name"],
74
  api_key=env_variables["OPENAI_API_KEY"],
75
- temperature=app_config["openai_generative_model"]["temperature"],
76
  )
77
  openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)
78
 
 
 
 
 
 
79
  # Create Galileo platform object
80
  galileo_platform_config = GalileoPlatformConfig(
81
- evaluate_project_name=app_config["galileo_platform"]["evaluate_project_name"],
82
- observe_project_name=app_config["galileo_platform"]["observe_project_name"],
83
- protect_project_name=app_config["galileo_platform"]["protect_project_name"],
84
- protect_stage_name=app_config["galileo_platform"]["protect_stage_name"],
85
  )
86
  galileo_platform = GalileoPlatform(galileo_platform_config)
87
 
@@ -89,8 +91,8 @@ galileo_platform = GalileoPlatform(galileo_platform_config)
89
  rag_application_config = RAGApplicationConfig(
90
  embedding_model=embedding_model,
91
  vector_db=vector_db,
92
- # gemini_generative_model=gemini_generative_model,
93
- generative_model=openai_generative_model,
94
  galileo_platform=galileo_platform,
95
  )
96
  rag_app = RAGApplication(rag_application_config)
@@ -98,27 +100,36 @@ rag_app = RAGApplication(rag_application_config)
98
 
99
  @app.get("/", response_class=HTMLResponse)
100
  async def read_root(request: Request):
101
- return templates.TemplateResponse("index.html", {"request": request})
102
-
103
- # TODO: Nikhil
104
- # @app.post("/other-metrics")
105
- # async def search(
106
-
 
107
 
108
  @app.post("/search")
109
  async def search(
110
  query: str = Form(...),
111
  top_k: int = Form(5),
 
112
  protection: bool = Form(False),
113
  hallucination_detection: bool = Form(False),
114
  induce_hallucination: bool = Form(False),
 
 
 
115
  ):
 
116
  response, redacted_response, original_response, context_adherence_score, pii_flag = rag_app.run(
117
  query,
118
  protect_enabled=protection,
119
  top_k=top_k,
120
  hallucination_detection=hallucination_detection,
121
  induce_hallucination=induce_hallucination,
 
 
 
122
  )
123
 
124
  # Simulate processing
 
33
 
34
  load_dotenv()
35
 
 
36
  logger = initialize_logger()
37
 
38
  # get current file path using Path
 
53
 
54
  # Create vector db model object
55
  vector_db_config = MilvusVectorDatabaseConfig(
56
+ db_path=app_config["vector_database"]["db_path"] + env_variables["MILVUS_DB"] + "_milvus.db",
57
+ collection_name=env_variables["MILVUS_DB"],
58
  vector_dimensions=app_config["vector_database"]["dimensions"],
59
  drop_if_exists=False,
60
  )
 
62
 
63
  # Create generative model object
64
  gemini_generative_model_config = GeminiModelConfig(
65
+ model_name=env_variables["GOOGLE_GEMINI_MODEL"],
66
  api_keys=[env_variables["GOOGLE_GEMINI_API_KEY"], env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"]],
67
+ temperature=int(env_variables["MODEL_TEMPERATURE"]),
68
  )
69
  gemini_generative_model = get_generative_model(GeminiModel, gemini_generative_model_config)
70
 
71
  openai_generative_model_config = OpenAIModelConfig(
72
+ model_name=env_variables["OPENAI_MODEL"],
73
  api_key=env_variables["OPENAI_API_KEY"],
74
+ temperature=int(env_variables["MODEL_TEMPERATURE"]),
75
  )
76
  openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)
77
 
78
+ default_project_name = env_variables["GALILEO_PROJECT_NAME"]
79
+ default_logstream_name = env_variables["GALILEO_LOGSTREAM_NAME"]
80
+ default_protect_stage_name = env_variables["GALILEO_PROTECT_STAGE_NAME"]
81
+ default_dataset_name = env_variables["GALILEO_DATASET_NAME"]
82
+
83
  # Create Galileo platform object
84
  galileo_platform_config = GalileoPlatformConfig(
85
+ protect_project_name=env_variables["GALILEO_PROJECT_NAME"],
86
+ protect_stage_name=default_protect_stage_name,
 
 
87
  )
88
  galileo_platform = GalileoPlatform(galileo_platform_config)
89
 
 
91
  rag_application_config = RAGApplicationConfig(
92
  embedding_model=embedding_model,
93
  vector_db=vector_db,
94
+ generative_model=gemini_generative_model,
95
+ # generative_model=openai_generative_model,
96
  galileo_platform=galileo_platform,
97
  )
98
  rag_app = RAGApplication(rag_application_config)
 
100
 
101
  @app.get("/", response_class=HTMLResponse)
102
  async def read_root(request: Request):
103
+ # Get default project name from environment variables
104
+ return templates.TemplateResponse("index.html", {
105
+ "request": request,
106
+ "default_project_name": default_project_name,
107
+ "default_logstream_name": default_logstream_name,
108
+ "default_dataset_name": default_dataset_name
109
+ })
110
 
111
  @app.post("/search")
112
  async def search(
113
  query: str = Form(...),
114
  top_k: int = Form(5),
115
+ add_to_dataset: bool = Form(False),
116
  protection: bool = Form(False),
117
  hallucination_detection: bool = Form(False),
118
  induce_hallucination: bool = Form(False),
119
+ project_name: str = default_project_name,
120
+ logstream_name: str = default_logstream_name,
121
+ dataset_name: str = default_dataset_name,
122
  ):
123
+
124
  response, redacted_response, original_response, context_adherence_score, pii_flag = rag_app.run(
125
  query,
126
  protect_enabled=protection,
127
  top_k=top_k,
128
  hallucination_detection=hallucination_detection,
129
  induce_hallucination=induce_hallucination,
130
+ project_name=project_name,
131
+ logstream_name=logstream_name,
132
+ dataset_name=dataset_name if add_to_dataset else None,
133
  )
134
 
135
  # Simulate processing
backend/classes/galileo_platform.py CHANGED
@@ -1,116 +1,102 @@
1
- from galileo_observe import ObserveWorkflows
2
- import galileo_protect as gp
3
- from pydantic import BaseModel
4
- import promptquality as pq
5
- from promptquality import CustomizedScorerName, Models
6
- from dotenv import load_dotenv
7
- import os
8
- from datetime import datetime
9
  from typing import Optional
 
 
 
 
 
 
 
 
 
 
 
 
10
  load_dotenv()
11
 
12
  class GalileoPlatformConfig(BaseModel):
13
  """Base configuration for Galileo platform."""
14
- evaluate_project_name: str
15
- observe_project_name: str
16
  protect_project_name: str
17
  protect_stage_name: str
18
 
19
-
20
  class GalileoPlatform:
21
  """Implementation of Galileo Features"""
22
 
23
  def __init__(self, config: GalileoPlatformConfig):
24
  self.config = config
25
- pq.login(api_key=os.getenv("GALILEO_API_KEY"))
26
- self.evaluate_run = self.create_evaluate_run()
27
- self.observe_logger = ObserveWorkflows(project_name=config.observe_project_name)
28
- self.protect_stage_id = self.get_protect_stage()
29
 
30
- def create_evaluate_run(self):
31
- """Create a Galileo Evaluate run."""
32
- scorers = [
33
- pq.Scorers.context_adherence_luna,
34
- pq.Scorers.chunk_attribution_utilization_luna,
35
- pq.Scorers.completeness_luna
36
- ]
37
- evaluate_run = pq.EvaluateRun(
38
- project_name=self.config.evaluate_project_name,
39
- scorers=scorers,
40
  )
41
- return evaluate_run
42
 
43
- def get_protect_stage(self):
44
  """Get or create a Galileo Protect stage."""
45
  try:
46
- protect_project = gp.get_project(
47
- project_name=self.config.protect_project_name
 
48
  )
 
49
  except Exception as _:
50
- protect_project = gp.create_project(name=self.config.protect_project_name)
51
-
52
- protect_project_id = protect_project.id
53
-
54
- try:
55
- protect_stage = gp.get_stage(
56
- project_id=protect_project_id, stage_name=self.config.protect_stage_name
57
- )
58
- except Exception as _:
59
- protect_stage = gp.create_stage(
60
- project_id=protect_project_id,
61
  name=self.config.protect_stage_name,
 
 
62
  )
 
63
 
64
- return protect_stage.id
65
-
66
- def run_protect(self, prompt: str, output: str, workflow: Optional[ObserveWorkflows] = None) -> dict:
67
  """Run Galileo Protect on input and output."""
68
- response = gp.invoke(
69
- payload=gp.Payload(input=prompt, output=output),
70
  prioritized_rulesets=[
71
- gp.Ruleset(
72
  rules=[
73
- gp.Rule(
74
- metric=gp.RuleMetrics.context_adherence_luna,
75
- operator=gp.RuleOperator.lte,
76
  target_value=0.01,
77
  ),
78
  ],
79
- action=gp.OverrideAction(
80
  choices=["Sorry, the input is hallucinatory."]
81
  ),
82
  ),
83
- gp.Ruleset(
84
  rules=[
85
- gp.Rule(
86
- metric=gp.RuleMetrics.pii,
87
- operator=gp.RuleOperator.any,
88
  target_value=["email", "phone_number", "name"],
89
  )
90
  ],
91
- action=gp.OverrideAction(
92
  choices=["Sorry, the output contains PII."]
93
  ),
94
  ),
95
- # gp.Ruleset(
96
- # rules=[
97
- # gp.Rule(
98
- # metric="deutsche_bank_company_pii_0",
99
- # operator=gp.RuleOperator.gte,
100
- # target_value=0.1,
101
- # )
102
- # ],
103
- # action=gp.OverrideAction(
104
- # choices=["Sorry, the output contains PII."]
105
- # ),
106
- # )
107
  ],
108
  stage_id=self.protect_stage_id,
109
  )
110
 
111
- if workflow:
112
- workflow.add_protect(
113
- payload=gp.Payload(input=prompt, output=output),
114
  response=response,
115
  )
116
 
 
 
 
 
 
 
 
 
 
1
  from typing import Optional
2
+
3
+ from dotenv import load_dotenv
4
+ from pydantic import BaseModel
5
+
6
+ from galileo import GalileoLogger, GalileoScorers, StageType
7
+ from galileo.protect import invoke_protect
8
+ from galileo.stages import create_protect_stage, get_protect_stage
9
+ from galileo_core.schemas.protect.action import OverrideAction
10
+ from galileo_core.schemas.protect.payload import Payload
11
+ from galileo_core.schemas.protect.rule import Rule, RuleOperator
12
+ from galileo_core.schemas.protect.ruleset import Ruleset
13
+
14
  load_dotenv()
15
 
16
  class GalileoPlatformConfig(BaseModel):
17
  """Base configuration for Galileo platform."""
 
 
18
  protect_project_name: str
19
  protect_stage_name: str
20
 
 
21
  class GalileoPlatform:
22
  """Implementation of Galileo Features"""
23
 
24
  def __init__(self, config: GalileoPlatformConfig):
25
  self.config = config
26
+ self.protect_stage_id = self.get_protect_stage_id()
 
 
 
27
 
28
+ def get_logger(self, project_name: str, logstream_name: str):
29
+ """Get or create a Galileo Logger."""
30
+ return GalileoLogger(
31
+ project=project_name,
32
+ log_stream=logstream_name,
 
 
 
 
 
33
  )
 
34
 
35
+ def get_protect_stage_id(self):
36
  """Get or create a Galileo Protect stage."""
37
  try:
38
+ protect_stage = get_protect_stage(
39
+ project_name=self.config.protect_project_name,
40
+ stage_name=self.config.protect_stage_name,
41
  )
42
+ return protect_stage.id
43
  except Exception as _:
44
+ protect_stage = create_protect_stage(
45
+ project_name=self.config.protect_project_name,
 
 
 
 
 
 
 
 
 
46
  name=self.config.protect_stage_name,
47
+ stage_type=StageType.local,
48
+ description="Deutsche Bank RFP RAG Protect Stage"
49
  )
50
+ return protect_stage.id
51
 
52
+ def run_protect(self, input: str, output: str, logger: Optional[GalileoLogger] = None) -> dict:
 
 
53
  """Run Galileo Protect on input and output."""
54
+ response = invoke_protect(
55
+ payload=Payload(input=input, output=output),
56
  prioritized_rulesets=[
57
+ Ruleset(
58
  rules=[
59
+ Rule(
60
+ metric=GalileoScorers.context_adherence_luna,
61
+ operator=RuleOperator.lte,
62
  target_value=0.01,
63
  ),
64
  ],
65
+ action=OverrideAction(
66
  choices=["Sorry, the input is hallucinatory."]
67
  ),
68
  ),
69
+ Ruleset(
70
  rules=[
71
+ Rule(
72
+ metric=GalileoScorers.input_pii,
73
+ operator=RuleOperator.any,
74
  target_value=["email", "phone_number", "name"],
75
  )
76
  ],
77
+ action=OverrideAction(
78
  choices=["Sorry, the output contains PII."]
79
  ),
80
  ),
81
+ Ruleset(
82
+ rules=[
83
+ Rule(
84
+ metric="deutsche_bank_company_pii_scorer_0",
85
+ operator=RuleOperator.gte,
86
+ target_value=0.1,
87
+ )
88
+ ],
89
+ action=OverrideAction(
90
+ choices=["Sorry, the output contains PII."]
91
+ ),
92
+ )
93
  ],
94
  stage_id=self.protect_stage_id,
95
  )
96
 
97
+ if logger:
98
+ logger.add_protect_span(
99
+ payload=Payload(input=input, output=output),
100
  response=response,
101
  )
102
 
backend/classes/rag_application.py CHANGED
@@ -1,13 +1,13 @@
1
- from pydantic import BaseModel
2
- import json
3
  import time
4
- import re
5
- from promptquality import Models
 
 
 
6
  from backend.classes.embedding_model import EmbeddingModel
7
- from backend.classes.vector_database.milvus_vector_database import MilvusVectorDatabase
8
  from backend.classes.galileo_platform import GalileoPlatform
9
  from backend.classes.generative_model import GeminiModel, OpenAIModel
10
- from typing import Union
11
 
12
  def strike(text):
13
  return ''.join([char + '\u0336' for char in text])
@@ -43,6 +43,7 @@ The following are the categories that need to be redacted:
43
  - Phone numbers
44
  - Email addresses
45
  - Names
 
46
  For every PII that needs to be redacted, wrap it in <pii></pii> tags.
47
 
48
  Categories: {pii_flag}
@@ -50,7 +51,7 @@ Response: {response}
50
  Modified Response: """
51
 
52
 
53
- hallucinatory_chunks: list[str] = [
54
  "Fairfield CDC is issuing this RFP to select a banking partner for its ambitious new program to fund the city's first dragon-powered public transportation system.",
55
  "Merchant services must include psychic energy transfer gateways for multi-reality donation collection.",
56
  "Technological capabilities must include temporal online banking for pre-cognitive transaction approvals.",
@@ -75,14 +76,15 @@ class RAGApplication:
75
  top_k: int = 5,
76
  hallucination_detection: bool = False,
77
  induce_hallucination: bool = False,
 
 
 
78
  ) -> str:
79
- # Create a workflow to track this query
80
- observe_workflow = self.config.galileo_platform.observe_logger.add_workflow(
81
- name="RAG Workflow", input={"query": query}
82
- )
83
 
84
- evaluate_workflow = self.config.galileo_platform.evaluate_run.add_workflow(
85
- name="RAG Workflow", input={"query": query}
86
  )
87
 
88
  context_adherence_score = 1
@@ -97,10 +99,8 @@ class RAGApplication:
97
  try:
98
  start_time = time.time()
99
 
100
- # Get query embedding
101
  query_embedding = self.config.embedding_model.encode([query])
102
 
103
- # Get top-k similar texts
104
  retrieved_documents = [
105
  str(text["text"])
106
  for text in self.config.vector_db.search_similar_texts(
@@ -108,42 +108,27 @@ class RAGApplication:
108
  )
109
  ]
110
 
111
- # Log retriever step to Galileo Observe
112
- observe_workflow.add_retriever(
113
  name="Milvus Retrieval",
114
  input=query,
115
- documents=retrieved_documents,
116
  duration_ns=int((time.time() - start_time) * 1e9),
117
  )
118
-
119
- evaluate_workflow.add_retriever(
120
- name="Milvus Retrieval",
121
- input=query,
122
- documents=retrieved_documents,
123
- # documents=[
124
- # Document(content=doc, metadata={"length": len(doc)}) for doc in retrieved_documents],
125
- duration_ns=int((time.time() - start_time) * 1e9),
126
- )
127
-
128
  start_time = time.time()
129
 
130
  if not retrieved_documents:
131
  return "There is nothing to return", redacted_result, context_adherence_score, pii_flag
132
 
133
- # Create context by combining the retrieved documents
134
  context = "\n\n".join(retrieved_documents)
135
 
136
- # Set prompt template
137
  prompt = (
138
  self.config.prompt_template
139
  if not prompt_template
140
  else prompt_template
141
  )
142
 
143
- # Construct prompt
144
  formatted_prompt = f"{prompt}\n\nQUESTION: {query}\n\nCONTEXT: {context}"
145
 
146
- # Generate response
147
  result = self.config.generative_model.generate_response(
148
  formatted_prompt
149
  )
@@ -156,59 +141,63 @@ class RAGApplication:
156
  temperature=1.0,
157
  )
158
 
159
- # Log LLM call to Galileo Observe
160
- observe_workflow.add_llm(
 
 
 
 
161
  name="Answer Generation",
162
- input=retrieved_documents,
163
  output=result,
164
  model=self.config.generative_model.config.model_name,
165
  duration_ns=int((time.time() - start_time) * 1e9),
 
 
 
 
166
  )
167
 
168
- evaluate_workflow.add_llm(
169
- # input=Message(content=prompt, role=MessageRole.user),
170
- # output=Message(content=result, role=MessageRole.assistant),
171
- name="Answer Generation",
172
- input=prompt,
173
- output=result,
174
- model=Models.gpt_4o,
175
- duration_ns=int((time.time() - start_time) * 1e9),
176
- )
 
 
177
 
178
  start_time = time.time()
179
 
180
  protect_response = self.config.galileo_platform.run_protect(
181
- context, result, observe_workflow
182
  )
183
 
 
 
184
  if protect_enabled and protect_response["text"] != result:
185
  pii_flag["phone_number"] = "phone_number" in protect_response["metric_results"]["pii"]["value"]
186
  pii_flag["email"] = "email" in protect_response["metric_results"]["pii"]["value"]
187
  pii_flag["name"] = "name" in protect_response["metric_results"]["pii"]["value"]
188
- # pii_flag["company"] = protect_response["metric_results"]["deutsche_bank_company_pii_0"]["value"]>0.1
189
  redacted_result = self.get_redacted_result(result, pii_flag)
 
190
  result = redacted_result.replace("<pii>", "<tag>").replace("</pii>", "</tag>")
191
- redacted_result = re.sub(r'<pii>(.*?)</pii>', r'<pii>REDACTED</pii>', redacted_result)
192
 
193
  if hallucination_detection:
194
  context_adherence_score = protect_response["metric_results"]["context_adherence_luna"]["value"]
195
- # print(context_adherence_score)
196
-
197
- # Conclude the workflow with the final result and set output
198
- observe_workflow.conclude(output=result)
199
- evaluate_workflow.output = result
200
- self.config.galileo_platform.observe_logger.upload_workflows()
201
 
202
- # Start evaluation in separate thread
203
- self.config.galileo_platform.evaluate_run.finish(wait=True, silent=True)
204
- # print(self.config.galileo_platform.evaluate_run)
205
 
206
  return result, redacted_result, original_result, context_adherence_score, pii_flag
207
 
208
  except Exception as e:
209
- # Log errors to Galileo Observe
210
- observe_workflow.conclude(output={"error": str(e)})
211
- self.config.galileo_platform.observe_logger.upload_workflows()
212
  raise e
213
 
214
  def get_redacted_result(self, result, pii_flag):
 
 
 
1
  import time
2
+ from typing import List, Union
3
+
4
+ from galileo.datasets import create_dataset, get_dataset
5
+ from pydantic import BaseModel
6
+
7
  from backend.classes.embedding_model import EmbeddingModel
 
8
  from backend.classes.galileo_platform import GalileoPlatform
9
  from backend.classes.generative_model import GeminiModel, OpenAIModel
10
+ from backend.classes.vector_database.milvus_vector_database import MilvusVectorDatabase
11
 
12
  def strike(text):
13
  return ''.join([char + '\u0336' for char in text])
 
43
  - Phone numbers
44
  - Email addresses
45
  - Names
46
+ - Company names (Fairfield or Fairfield CDC or other variations)
47
  For every PII that needs to be redacted, wrap it in <pii></pii> tags.
48
 
49
  Categories: {pii_flag}
 
51
  Modified Response: """
52
 
53
 
54
+ hallucinatory_chunks: List[str] = [
55
  "Fairfield CDC is issuing this RFP to select a banking partner for its ambitious new program to fund the city's first dragon-powered public transportation system.",
56
  "Merchant services must include psychic energy transfer gateways for multi-reality donation collection.",
57
  "Technological capabilities must include temporal online banking for pre-cognitive transaction approvals.",
 
76
  top_k: int = 5,
77
  hallucination_detection: bool = False,
78
  induce_hallucination: bool = False,
79
+ project_name: str = None,
80
+ logstream_name: str = None,
81
+ dataset_name: str = None,
82
  ) -> str:
83
+
84
+ galileo_logger = self.config.galileo_platform.get_logger(project_name, logstream_name)
 
 
85
 
86
+ _ = galileo_logger.start_trace(
87
+ name="RAG Workflow", input=query
88
  )
89
 
90
  context_adherence_score = 1
 
99
  try:
100
  start_time = time.time()
101
 
 
102
  query_embedding = self.config.embedding_model.encode([query])
103
 
 
104
  retrieved_documents = [
105
  str(text["text"])
106
  for text in self.config.vector_db.search_similar_texts(
 
108
  )
109
  ]
110
 
111
+ galileo_logger.add_retriever_span(
 
112
  name="Milvus Retrieval",
113
  input=query,
114
+ output=retrieved_documents,
115
  duration_ns=int((time.time() - start_time) * 1e9),
116
  )
 
 
 
 
 
 
 
 
 
 
117
  start_time = time.time()
118
 
119
  if not retrieved_documents:
120
  return "There is nothing to return", redacted_result, context_adherence_score, pii_flag
121
 
 
122
  context = "\n\n".join(retrieved_documents)
123
 
 
124
  prompt = (
125
  self.config.prompt_template
126
  if not prompt_template
127
  else prompt_template
128
  )
129
 
 
130
  formatted_prompt = f"{prompt}\n\nQUESTION: {query}\n\nCONTEXT: {context}"
131
 
 
132
  result = self.config.generative_model.generate_response(
133
  formatted_prompt
134
  )
 
141
  temperature=1.0,
142
  )
143
 
144
+ input_data = {
145
+ "question": query,
146
+ "context": context,
147
+ }
148
+
149
+ galileo_logger.add_llm_span(
150
  name="Answer Generation",
151
+ input=input_data,
152
  output=result,
153
  model=self.config.generative_model.config.model_name,
154
  duration_ns=int((time.time() - start_time) * 1e9),
155
+ metadata={
156
+ "question": query,
157
+ "context": context,
158
+ }
159
  )
160
 
161
+ try:
162
+ row = {
163
+ "input": input_data,
164
+ "output": result,
165
+ }
166
+ if dataset_name:
167
+ dataset = get_dataset(name=dataset_name)
168
+ dataset.add_rows([row])
169
+ except Exception as e:
170
+ print(e)
171
+ dataset = create_dataset(name=dataset_name, content=[row])
172
 
173
  start_time = time.time()
174
 
175
  protect_response = self.config.galileo_platform.run_protect(
176
+ context, result, galileo_logger
177
  )
178
 
179
+ print(protect_response)
180
+
181
  if protect_enabled and protect_response["text"] != result:
182
  pii_flag["phone_number"] = "phone_number" in protect_response["metric_results"]["pii"]["value"]
183
  pii_flag["email"] = "email" in protect_response["metric_results"]["pii"]["value"]
184
  pii_flag["name"] = "name" in protect_response["metric_results"]["pii"]["value"]
185
+ pii_flag["company"] = protect_response["metric_results"]["deutsche_bank_company_pii_0"]["value"]>0.1
186
  redacted_result = self.get_redacted_result(result, pii_flag)
187
+ # redacted_result = re.sub(r'<pii>(.*?)</pii>', r'<pii>REDACTED</pii>', redacted_result)
188
  result = redacted_result.replace("<pii>", "<tag>").replace("</pii>", "</tag>")
 
189
 
190
  if hallucination_detection:
191
  context_adherence_score = protect_response["metric_results"]["context_adherence_luna"]["value"]
 
 
 
 
 
 
192
 
193
+ galileo_logger.conclude(output=result)
194
+ galileo_logger.flush()
 
195
 
196
  return result, redacted_result, original_result, context_adherence_score, pii_flag
197
 
198
  except Exception as e:
199
+ galileo_logger.conclude(output={"error": str(e)})
200
+ galileo_logger.flush()
 
201
  raise e
202
 
203
  def get_redacted_result(self, result, pii_flag):
backend/classes/vector_database/milvus_vector_database.py CHANGED
@@ -3,7 +3,7 @@ import shutil
3
  from typing import List
4
 
5
  import pandas as pd
6
- from pymilvus import MilvusClient, connections, FieldSchema, CollectionSchema, DataType, Collection
7
  import logging
8
 
9
  from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase
 
3
  from typing import List
4
 
5
  import pandas as pd
6
+ from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
7
  import logging
8
 
9
  from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase