qyle commited on
Commit
fc62e60
·
verified ·
1 Parent(s): 8fe7ab1

deployment

Browse files
champ/agent.py CHANGED
@@ -8,8 +8,6 @@ from langchain_community.vectorstores import FAISS as LCFAISS
8
 
9
  from opentelemetry import trace
10
 
11
- from classes.prompt_sanitizer import PromptSanitizer
12
-
13
  from .prompts import CHAMP_SYSTEM_PROMPT_V5
14
 
15
  tracer = trace.get_tracer(__name__)
@@ -33,6 +31,8 @@ def _build_retrieval_query(messages) -> str:
33
  def make_prompt_with_context(
34
  vector_store: LCFAISS, lang: Literal["en", "fr"], k: int = 4
35
  ):
 
 
36
  @dynamic_prompt
37
  def prompt_with_context(request: ModelRequest) -> str:
38
  with tracer.start_as_current_span("retrieving documents"):
@@ -58,23 +58,17 @@ def make_prompt_with_context(
58
  unique_docs.append(doc)
59
 
60
  docs_content = "\n\n".join(doc.page_content for doc in unique_docs)
61
-
62
- # No need to sanitize the docs_content as the documents are sanitized
63
- # when received at the file PUT endpoint.
64
- with tracer.start_as_current_span("PromptSanitizer"):
65
- sanitizer = PromptSanitizer()
66
- with tracer.start_as_current_span("sanitize retrieval_query"):
67
- sanitized_retrieval_query = sanitizer.sanitize(retrieval_query)
68
 
69
  language = "English" if lang == "en" else "French"
70
 
71
  return CHAMP_SYSTEM_PROMPT_V5.format(
72
- last_query=sanitized_retrieval_query,
73
  context=docs_content,
74
  language=language,
75
  )
76
 
77
- return prompt_with_context
78
 
79
 
80
  def build_champ_agent(
@@ -91,11 +85,11 @@ def build_champ_agent(
91
  # huggingfacehub_api_token=... (optional; see service.py)
92
  )
93
  model_chat = ChatHuggingFace(llm=hf_llm)
94
- prompt_middleware = make_prompt_with_context(vector_store, lang)
95
  return create_agent(
96
  model_chat,
97
  tools=[],
98
  middleware=[
99
  prompt_middleware,
100
  ],
101
- )
 
8
 
9
  from opentelemetry import trace
10
 
 
 
11
  from .prompts import CHAMP_SYSTEM_PROMPT_V5
12
 
13
  tracer = trace.get_tracer(__name__)
 
31
  def make_prompt_with_context(
32
  vector_store: LCFAISS, lang: Literal["en", "fr"], k: int = 4
33
  ):
34
+ context_store = {"last_retrieved_docs": []} # shared mutable container
35
+
36
  @dynamic_prompt
37
  def prompt_with_context(request: ModelRequest) -> str:
38
  with tracer.start_as_current_span("retrieving documents"):
 
58
  unique_docs.append(doc)
59
 
60
  docs_content = "\n\n".join(doc.page_content for doc in unique_docs)
61
+ context_store["last_retrieved_docs"] = [doc.page_content for doc in unique_docs]
 
 
 
 
 
 
62
 
63
  language = "English" if lang == "en" else "French"
64
 
65
  return CHAMP_SYSTEM_PROMPT_V5.format(
66
+ last_query=retrieval_query,
67
  context=docs_content,
68
  language=language,
69
  )
70
 
71
+ return prompt_with_context, context_store
72
 
73
 
74
  def build_champ_agent(
 
85
  # huggingfacehub_api_token=... (optional; see service.py)
86
  )
87
  model_chat = ChatHuggingFace(llm=hf_llm)
88
+ prompt_middleware, context_store = make_prompt_with_context(vector_store, lang)
89
  return create_agent(
90
  model_chat,
91
  tools=[],
92
  middleware=[
93
  prompt_middleware,
94
  ],
95
+ ), context_store
champ/service.py CHANGED
@@ -1,11 +1,10 @@
1
  # app/champ/service.py
2
 
3
- from typing import Literal, Optional, Sequence
4
 
5
  from langchain_community.vectorstores import FAISS as LCFAISS
6
  from langchain_core.messages import HumanMessage
7
 
8
-
9
  from .agent import build_champ_agent
10
  from .triage import safety_triage
11
 
@@ -14,12 +13,25 @@ class ChampService:
14
  vector_store: Optional[LCFAISS] = None
15
  agent = None
16
  lang = None
 
17
 
18
  def __init__(self, vector_store: LCFAISS, lang: Literal["en", "fr"]):
 
19
  self.vector_store = vector_store
20
- self.agent = build_champ_agent(self.vector_store, lang)
 
 
 
21
 
22
- def invoke(self, lc_messages: Sequence) -> str:
 
 
 
 
 
 
 
 
23
  if self.agent is None:
24
  raise RuntimeError("CHAMP is not initialized yet.")
25
  # --- Safety triage micro-layer (before LLM) ---
@@ -38,6 +50,16 @@ class ChampService:
38
  }
39
 
40
  result = self.agent.invoke({"messages": list(lc_messages)})
41
- return result["messages"][-1].text.strip(), {
42
- "triage_triggered": False,
43
- }
 
 
 
 
 
 
 
 
 
 
 
1
  # app/champ/service.py
2
 
3
+ from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple
4
 
5
  from langchain_community.vectorstores import FAISS as LCFAISS
6
  from langchain_core.messages import HumanMessage
7
 
 
8
  from .agent import build_champ_agent
9
  from .triage import safety_triage
10
 
 
13
  vector_store: Optional[LCFAISS] = None
14
  agent = None
15
  lang = None
16
+ context_store = None
17
 
18
  def __init__(self, vector_store: LCFAISS, lang: Literal["en", "fr"]):
19
+
20
  self.vector_store = vector_store
21
+ self.agent, self.context_store = build_champ_agent(self.vector_store, lang)
22
+
23
+ def invoke(self, lc_messages: Sequence) -> Tuple[str, Dict[str, Any], List[str]]:
24
+ """Invokes the agent.
25
 
26
+ Args:
27
+ lc_messages (Sequence): Sequence of LangChain messages
28
+
29
+ Raises:
30
+ RuntimeError: Raised when the function is called before CHAMP is initialized
31
+
32
+ Returns:
33
+ Tuple[str, Dict[str, Any], List[str]]: The reply, the triage_triggered object and the retrieved passages
34
+ """
35
  if self.agent is None:
36
  raise RuntimeError("CHAMP is not initialized yet.")
37
  # --- Safety triage micro-layer (before LLM) ---
 
50
  }
51
 
52
  result = self.agent.invoke({"messages": list(lc_messages)})
53
+
54
+ retrieved_passages = (
55
+ self.context_store["last_retrieved_docs"]
56
+ if self.context_store is not None
57
+ else []
58
+ )
59
+ return (
60
+ result["messages"][-1].text.strip(),
61
+ {
62
+ "triage_triggered": False,
63
+ },
64
+ retrieved_passages,
65
+ )
classes/pii_filter.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from presidio_analyzer import AnalyzerEngine, Pattern, PatternRecognizer
3
+ from presidio_analyzer.nlp_engine import NlpEngineProvider
4
+ from presidio_anonymizer import AnonymizerEngine
5
+ from presidio_anonymizer.entities import OperatorConfig
6
+
7
+ from lingua import Language, LanguageDetector
8
+
9
+
10
def create_ssn_pattern_recognizer():
    """Build a Presidio recognizer for 9-digit SINs/SSNs.

    Accepts the three common layouts: 111-111-111, 111 111 111 and 111111111.
    """
    sin_pattern = Pattern(
        name="ssn_pattern",
        # Three groups of three digits, optionally separated by a dash or space.
        regex=r"\b\d{3}[- ]?\d{3}[- ]?\d{3}\b",
        score=0.8,
    )
    return PatternRecognizer(supported_entity="SSN", patterns=[sin_pattern])
16
+
17
+
18
def create_zip_code_pattern_recognizer():
    """Build a Presidio recognizer for Canadian postal codes.

    Matches the letter/digit alternation in both "A1A 1A1" and "A1A1A1" forms.
    """
    postal_code_pattern = Pattern(
        name="zip_code_pattern",
        regex=r"\b[A-Z]\d[A-Z]\s?\d[A-Z]\d\b",  # Matches A1A 1A1 and A1A1A1
        score=0.8,
    )
    return PatternRecognizer(
        supported_entity="ZIP_CODE", patterns=[postal_code_pattern]
    )
25
+
26
+
27
def create_street_pattern_recognizer():
    """Build a Presidio recognizer for English and French street addresses.

    Covers both word orders:
    - French style:  "<number> <street-type> <name...>" (e.g. "12 rue Principale")
    - English style: "[number] <name...> <street-type>" (e.g. "10 Main street")
    Accented characters are allowed in street names.
    """
    bilingual_street_regex = (
        r"\d+\s+(?:rue|boul|boulevard|av|avenue|place|square|st|street|rd|road|ave|blvd|lane|dr|drive)"
        r"\s+[A-ZÁÀÂÄÇÉÈÊËÍÎÏÓÔÖÚÛÜa-z]+"
        r"(?:\s+[A-ZÁÀÂÄÇÉÈÊËÍÎÏÓÔÖÚÛÜa-z]+)*"
        r"|(?:\d+\s+)?[A-ZÁÀÂÄÇÉÈÊËÍÎÏÓÔÖÚÛÜa-z]+(?:\s+[A-ZÁÀÂÄÇÉÈÊËÍÎÏÓÔÖÚÛÜa-z]+)*"
        r"\s+(?:rue|boul|boulevard|av|avenue|place|square|st|street|rd|road|ave|blvd|lane|dr|drive)\b"
    )

    address_pattern = Pattern(
        name="street_pattern",
        regex=bilingual_street_regex,
        score=0.8,
    )
    return PatternRecognizer(
        supported_entity="STREET_ADDRESS", patterns=[address_pattern]
    )
42
+
43
+
44
class PIIFilter:
    """Singleton wrapper around Microsoft Presidio for bilingual PII redaction.

    The first ``PIIFilter()`` call builds the spaCy NLP engine, the analyzer
    (with the custom SIN / postal-code / street recognizers registered) and
    the anonymizer; every later call returns the same initialized instance.
    """

    _instance: Optional["PIIFilter"] = None
    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    operators: dict
    target_entities: List[str]

    def __new__(cls):
        if cls._instance is not None:
            return cls._instance

        print("Initializing Presidio Engines (this should happen only once)...")
        cls._instance = super().__new__(cls)
        instance = cls._instance

        # One spaCy model per supported language (English + French).
        configuration = {
            "nlp_engine_name": "spacy",
            "models": [
                {"lang_code": "en", "model_name": "en_core_web_lg"},
                {"lang_code": "fr", "model_name": "fr_core_news_lg"},
            ],
        }
        nlp_engine = NlpEngineProvider(
            nlp_configuration=configuration
        ).create_engine()

        instance.analyzer = AnalyzerEngine(nlp_engine=nlp_engine)

        # Register the custom regex recognizers on top of Presidio's defaults.
        for recognizer in (
            create_ssn_pattern_recognizer(),
            create_zip_code_pattern_recognizer(),
            create_street_pattern_recognizer(),
        ):
            instance.analyzer.registry.add_recognizer(recognizer)

        instance.anonymizer = AnonymizerEngine()

        # Masking rules: each detected entity type is replaced by a generic tag.
        instance.operators = {
            "PERSON": OperatorConfig("replace", {"new_value": "[NAME]"}),
            "EMAIL_ADDRESS": OperatorConfig("replace", {"new_value": "[EMAIL]"}),
            "PHONE_NUMBER": OperatorConfig("replace", {"new_value": "[PHONE]"}),
            "SSN": OperatorConfig("replace", {"new_value": "[SSN]"}),
            "CREDIT_CARD": OperatorConfig(
                "replace", {"new_value": "[CREDIT_CARD]"}
            ),
            "LOCATION": OperatorConfig("replace", {"new_value": "[LOCATION]"}),
            "STREET_ADDRESS": OperatorConfig(
                "replace", {"new_value": "[LOCATION]"}
            ),
            "ZIP_CODE": OperatorConfig("replace", {"new_value": "[LOCATION]"}),
        }
        instance.target_entities = list(instance.operators.keys())

        return cls._instance

    def sanitize(self, text: str, language_detector: LanguageDetector) -> str:
        """Analyzes and redacts PII from the given text.

        The text is scanned twice — first with the English model, then the
        French model runs over the English-redacted result — instead of
        detecting a single language up front. ``language_detector`` is
        currently unused: per-language detection was dropped because running
        both passes proved more effective and faster.
        """
        if not text:
            return text

        # Pass 1: detect and redact English-language PII.
        results_en = self.analyzer.analyze(
            text=text,
            entities=self.target_entities,
            language="en",
        )
        redacted_en = self.anonymizer.anonymize(
            text=text,
            analyzer_results=results_en,  # pyright: ignore[reportArgumentType]
            operators=self.operators,
        )

        # Pass 2: detect and redact French-language PII in the redacted text.
        results_fr = self.analyzer.analyze(
            text=redacted_en.text,
            entities=self.target_entities,
            language="fr",
        )
        redacted_fr = self.anonymizer.anonymize(
            text=redacted_en.text,
            analyzer_results=results_fr,  # pyright: ignore[reportArgumentType]
            operators=self.operators,
        )

        return redacted_fr.text
classes/prompt_injection_filter.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ # Taken from https://cheatsheetseries.owasp.org/cheatsheets/LLM_Prompt_Injection_Prevention_Cheat_Sheet.html#primary-defenses
5
+ # Has to work with French and English
6
class PromptInjectionFilter:
    """Heuristic detector/sanitizer for prompt-injection attempts.

    Combines regex matching of known injection phrasings with a
    scrambled-middle ("typoglycemia") check on a small keyword list, so
    trivially misspelled attack words are still caught. Patterns are kept
    language-agnostic enough to work on English and French input.
    """

    def __init__(self):
        # Regexes for well-known injection phrasings (case-insensitive).
        self.dangerous_patterns = [
            r"ignore\s+(all\s+)?previous\s+instructions?",
            r"you\s+are\s+now\s+(in\s+)?developer\s+mode",
            r"system\s+override",
            r"reveal\s+prompt",
        ]

        # Keywords matched fuzzily to defend against typoglycemia attacks.
        self.fuzzy_patterns = [
            "ignore",
            "bypass",
            "override",
            "reveal",
            "delete",
            "system",
        ]

    def detect_injection(self, text: str) -> bool:
        """Return True when *text* looks like a prompt-injection attempt."""
        # Exact (regex) pattern matching first.
        for pattern in self.dangerous_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True

        # Then look for scrambled variants of the fuzzy keywords.
        tokens = re.findall(r"\b\w+\b", text.lower())
        return any(
            self._is_similar_word(token, keyword)
            for token in tokens
            for keyword in self.fuzzy_patterns
        )

    def _is_similar_word(self, word: str, target: str) -> bool:
        """Check if word is a typoglycemia variant of target."""
        # Must be the same length and long enough to have a "middle".
        if len(word) < 3 or len(word) != len(target):
            return False
        # Same first and last letter, with an anagram of the middle letters.
        ends_match = word[0] == target[0] and word[-1] == target[-1]
        return ends_match and sorted(word[1:-1]) == sorted(target[1:-1])

    def sanitize_input(self, text: str) -> str:
        """Normalize obfuscations, then mask dangerous phrasings."""
        text = re.sub(r"\s+", " ", text)  # Collapse whitespace
        text = re.sub(r"(.)\1{3,}", r"\1", text)  # Remove char repetition

        # Replace each known injection phrasing with a neutral marker.
        for pattern in self.dangerous_patterns:
            text = re.sub(pattern, "[FILTERED]", text, flags=re.IGNORECASE)
        return text
classes/session_conversation_store.py CHANGED
@@ -2,6 +2,13 @@ from typing import Dict, List, Literal
2
 
3
  from classes.base_models import ChatMessage
4
 
 
 
 
 
 
 
 
5
 
6
  class SessionConversationStore:
7
  def __init__(self) -> None:
 
2
 
3
  from classes.base_models import ChatMessage
4
 
5
+ """
6
+ This class should be removed after the demo and all call sites
7
+ migrated to the LangGraph checkpointer. We should use a persistent
8
+ checkpointer (e.g. PostgresSaver or RedisSaver) once the demo is completed.
9
+ For more details: https://docs.langchain.com/oss/python/langchain/short-term-memory
10
+ """
11
+
12
 
13
  class SessionConversationStore:
14
  def __init__(self) -> None:
main.py CHANGED
@@ -34,7 +34,8 @@ from classes.base_models import (
34
  )
35
 
36
  # from classes.guardrail_manager import GuardrailManager
37
- from classes.prompt_sanitizer import PromptSanitizer
 
38
  from classes.session_conversation_store import SessionConversationStore
39
  from classes.session_tracker import SessionTracker
40
  from constants import (
@@ -62,6 +63,8 @@ from google import genai
62
 
63
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
64
 
 
 
65
  from champ.prompts import (
66
  DEFAULT_SYSTEM_PROMPT_V2,
67
  DEFAULT_SYSTEM_PROMPT_WITH_CONTEXT_V2,
@@ -173,33 +176,20 @@ def convert_and_sanitize_messages(
173
  )
174
  )
175
 
176
- sanitizer = PromptSanitizer()
177
-
178
  out = [{"role": "system", "content": system_prompt}]
179
  for m in messages:
180
  if m.role == "system":
181
  continue
182
- out.append(
183
- {
184
- "role": m.role,
185
- # We only sanitize human messages.
186
- "content": m.content
187
- if m.role == "assistant"
188
- else sanitizer.sanitize(m.content),
189
- }
190
- )
191
  return out
192
 
193
 
194
  def convert_and_sanitize_messages_langchain(messages: List[ChatMessage]):
195
  list_chatmessages = []
196
- sanitizer = PromptSanitizer()
197
 
198
  for m in messages[-MAX_HISTORY:]:
199
  if m.role == "user":
200
- list_chatmessages.append(
201
- HumanMessage(content=sanitizer.sanitize(m.content))
202
- )
203
  elif m.role == "assistant":
204
  list_chatmessages.append(AIMessage(content=m.content))
205
  elif m.role == "system":
@@ -241,7 +231,7 @@ def call_llm(
241
  model_type: str,
242
  lang: Literal["en", "fr"],
243
  conversation: List[ChatMessage],
244
- ) -> AsyncGenerator[str, None] | Tuple[str, Dict[str, Any]]:
245
  tracer = trace.get_tracer(__name__)
246
 
247
  if model_type == "champ":
@@ -262,9 +252,9 @@ def call_llm(
262
  msgs = convert_and_sanitize_messages_langchain(conversation)
263
 
264
  with tracer.start_as_current_span("invoke"):
265
- reply, triage_meta = champ.invoke(msgs)
266
 
267
- return reply, triage_meta
268
 
269
  if model_type not in MODEL_MAP:
270
  raise ValueError(f"Unknown model_type: {model_type}")
@@ -279,10 +269,10 @@ def call_llm(
279
  return _call_openai(model_id, msgs)
280
 
281
  if model_type == "google-conservative":
282
- return _call_gemini(model_id, msgs, temperature=0.2), {}
283
 
284
  if model_type == "google-creative":
285
- return _call_gemini(model_id, msgs, temperature=1.0), {}
286
 
287
  # If you later add HF models via hf_client, handle here.
288
  raise ValueError(f"Unhandled model_type: {model_type}")
@@ -297,9 +287,14 @@ async def lifespan(app: FastAPI):
297
  # We are loading the OCR Reader in advance, because loading the model takes time.
298
  app.state.ocr_reader = easyocr.Reader(["en", "fr"], gpu=torch.cuda.is_available())
299
 
 
 
 
 
 
300
  # Idem for the prompt sanitizer. No need to store it in the state since this
301
  # class follows the Singleton design pattern.
302
- PromptSanitizer()
303
 
304
  bg_task = asyncio.create_task(cleanup_loop())
305
  yield
@@ -350,8 +345,19 @@ async def chat_endpoint(
350
 
351
  session_tracker.update_session(session_id)
352
 
 
 
 
 
 
 
 
 
 
 
 
353
  session_conversation_store.add_human_message(
354
- session_id, payload.conversation_id, payload.human_message
355
  )
356
  conversation = session_conversation_store.get_conversation(
357
  session_id, conversation_id
@@ -359,6 +365,7 @@ async def chat_endpoint(
359
 
360
  reply = ""
361
  triage_meta = {}
 
362
 
363
  try:
364
  loop = asyncio.get_running_loop()
@@ -405,7 +412,7 @@ async def chat_endpoint(
405
 
406
  return StreamingResponse(logging_wrapper(), media_type="text/event-stream")
407
 
408
- reply, triage_meta = result
409
 
410
  except Exception as e:
411
  background_tasks.add_task(
@@ -426,6 +433,7 @@ async def chat_endpoint(
426
  },
427
  )
428
 
 
429
  background_tasks.add_task(
430
  log_event,
431
  user_id=payload.user_id,
@@ -435,6 +443,7 @@ async def chat_endpoint(
435
  "consent": payload.consent,
436
  "human_message": payload.human_message,
437
  "reply": reply,
 
438
  "age_group": payload.age_group,
439
  "gender": payload.gender,
440
  "roles": payload.roles,
@@ -554,11 +563,17 @@ async def upload_file(
554
  if file_text is None:
555
  return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)
556
 
557
- sanitizer = PromptSanitizer()
558
- sanitized_file_text = sanitizer.sanitize(file_text)
 
 
 
 
 
 
559
 
560
  if session_document_store.create_document(
561
- session_id, sanitized_file_text, file_name, file_size
562
  ):
563
  session_tracker.update_session(session_id)
564
  else:
 
34
  )
35
 
36
  # from classes.guardrail_manager import GuardrailManager
37
+ from classes.pii_filter import PIIFilter
38
+ from classes.prompt_injection_filter import PromptInjectionFilter
39
  from classes.session_conversation_store import SessionConversationStore
40
  from classes.session_tracker import SessionTracker
41
  from constants import (
 
63
 
64
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
65
 
66
+ from lingua import Language, LanguageDetectorBuilder
67
+
68
  from champ.prompts import (
69
  DEFAULT_SYSTEM_PROMPT_V2,
70
  DEFAULT_SYSTEM_PROMPT_WITH_CONTEXT_V2,
 
176
  )
177
  )
178
 
 
 
179
  out = [{"role": "system", "content": system_prompt}]
180
  for m in messages:
181
  if m.role == "system":
182
  continue
183
+ out.append({"role": m.role, "content": m.content})
 
 
 
 
 
 
 
 
184
  return out
185
 
186
 
187
  def convert_and_sanitize_messages_langchain(messages: List[ChatMessage]):
188
  list_chatmessages = []
 
189
 
190
  for m in messages[-MAX_HISTORY:]:
191
  if m.role == "user":
192
+ list_chatmessages.append(HumanMessage(content=m.content))
 
 
193
  elif m.role == "assistant":
194
  list_chatmessages.append(AIMessage(content=m.content))
195
  elif m.role == "system":
 
231
  model_type: str,
232
  lang: Literal["en", "fr"],
233
  conversation: List[ChatMessage],
234
+ ) -> AsyncGenerator[str, None] | Tuple[str, Dict[str, Any], List[str]]:
235
  tracer = trace.get_tracer(__name__)
236
 
237
  if model_type == "champ":
 
252
  msgs = convert_and_sanitize_messages_langchain(conversation)
253
 
254
  with tracer.start_as_current_span("invoke"):
255
+ reply, triage_meta, context = champ.invoke(msgs)
256
 
257
+ return reply, triage_meta, context
258
 
259
  if model_type not in MODEL_MAP:
260
  raise ValueError(f"Unknown model_type: {model_type}")
 
269
  return _call_openai(model_id, msgs)
270
 
271
  if model_type == "google-conservative":
272
+ return _call_gemini(model_id, msgs, temperature=0.2), {}, []
273
 
274
  if model_type == "google-creative":
275
+ return _call_gemini(model_id, msgs, temperature=1.0), {}, []
276
 
277
  # If you later add HF models via hf_client, handle here.
278
  raise ValueError(f"Unhandled model_type: {model_type}")
 
287
  # We are loading the OCR Reader in advance, because loading the model takes time.
288
  app.state.ocr_reader = easyocr.Reader(["en", "fr"], gpu=torch.cuda.is_available())
289
 
290
+ languages = [Language.ENGLISH, Language.FRENCH]
291
+ app.state.language_detector = LanguageDetectorBuilder.from_languages(
292
+ *languages
293
+ ).build()
294
+
295
  # Idem for the prompt sanitizer. No need to store it in the state since this
296
  # class follows the Singleton design pattern.
297
+ PIIFilter()
298
 
299
  bg_task = asyncio.create_task(cleanup_loop())
300
  yield
 
345
 
346
  session_tracker.update_session(session_id)
347
 
348
+ prompt_injection_filter = PromptInjectionFilter()
349
+ injection_filtered_msg = prompt_injection_filter.sanitize_input(
350
+ payload.human_message
351
+ )
352
+
353
+ pii_filter = PIIFilter()
354
+ with tracer.start_as_current_span("sanitize_document"):
355
+ pii_filtered_msg = pii_filter.sanitize(
356
+ injection_filtered_msg, app.state.language_detector
357
+ )
358
+
359
  session_conversation_store.add_human_message(
360
+ session_id, payload.conversation_id, pii_filtered_msg
361
  )
362
  conversation = session_conversation_store.get_conversation(
363
  session_id, conversation_id
 
365
 
366
  reply = ""
367
  triage_meta = {}
368
+ context = []
369
 
370
  try:
371
  loop = asyncio.get_running_loop()
 
412
 
413
  return StreamingResponse(logging_wrapper(), media_type="text/event-stream")
414
 
415
+ reply, triage_meta, context = result
416
 
417
  except Exception as e:
418
  background_tasks.add_task(
 
433
  },
434
  )
435
 
436
+ # Add the retrieved passages to the logged event
437
  background_tasks.add_task(
438
  log_event,
439
  user_id=payload.user_id,
 
443
  "consent": payload.consent,
444
  "human_message": payload.human_message,
445
  "reply": reply,
446
+ "context": context,
447
  "age_group": payload.age_group,
448
  "gender": payload.gender,
449
  "roles": payload.roles,
 
563
  if file_text is None:
564
  return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)
565
 
566
+ prompt_injection_filter = PromptInjectionFilter()
567
+ injection_filtered_file_text = prompt_injection_filter.sanitize_input(file_text)
568
+
569
+ pii_filter = PIIFilter()
570
+ with tracer.start_as_current_span("sanitize_document"):
571
+ pii_filtered_file_text = pii_filter.sanitize(
572
+ injection_filtered_file_text, app.state.language_detector
573
+ )
574
 
575
  if session_document_store.create_document(
576
+ session_id, pii_filtered_file_text, file_name, file_size
577
  ):
578
  session_tracker.update_session(session_id)
579
  else:
requirements.txt CHANGED
@@ -141,4 +141,5 @@ opentelemetry-sdk==1.39.1
141
  opentelemetry-instrumentation-fastapi==0.60b1
142
  opentelemetry-instrumentation-httpx==0.60b1
143
  slowapi==0.1.9
144
- psutil==7.2.2
 
 
141
  opentelemetry-instrumentation-fastapi==0.60b1
142
  opentelemetry-instrumentation-httpx==0.60b1
143
  slowapi==0.1.9
144
+ psutil==7.2.2
145
+ # lingua-language-detector==2.1.1
telemetry.py CHANGED
@@ -18,6 +18,7 @@ class FilteredConsoleExporter(SpanExporter):
18
  "PromptSanitizer",
19
  "sanitize docs_content",
20
  "sanitize retrieval_query",
 
21
  }
22
 
23
  def export(self, spans):
 
18
  "PromptSanitizer",
19
  "sanitize docs_content",
20
  "sanitize retrieval_query",
21
+ "sanitize_document",
22
  }
23
 
24
  def export(self, spans):
templates/index.html CHANGED
@@ -56,7 +56,7 @@
56
  <div class="modal-content slide language-modal">
57
  <div class="content-top">
58
  <h2 data-i18n="choose_language_title"></h2>
59
- <p data-i18n="change_language_instructions"></p>
60
  </div>
61
 
62
  <div class="form-group">
@@ -196,7 +196,7 @@
196
  </div>
197
  <h3 data-i18n="file_add_title"></h3>
198
  <div id="file-drop-zone" class="file-drop-area">
199
- <p><span data-i18n="file_add_instructions_prefix"></span><a href="#" data-i18n="click"></a><span data-i18n="file_add_instructions_suffix"></span>
200
  <input
201
  type="file"
202
  id="file-input"
 
56
  <div class="modal-content slide language-modal">
57
  <div class="content-top">
58
  <h2 data-i18n="choose_language_title"></h2>
59
+ <p style="text-align: justify;" data-i18n="change_language_instructions"></p>
60
  </div>
61
 
62
  <div class="form-group">
 
196
  </div>
197
  <h3 data-i18n="file_add_title"></h3>
198
  <div id="file-drop-zone" class="file-drop-area">
199
+ <p><span data-i18n="file_add_instructions_prefix"></span><a href="#" data-i18n="click"></a><span data-i18n="file_add_instructions_suffix"></span></p>
200
  <input
201
  type="file"
202
  id="file-input"