LeonardoMdSA commited on
Commit
df70a57
·
1 Parent(s): 9bfdb2a

HF push again

Browse files
multi_doc_chat/config/config.yaml DELETED
@@ -1,23 +0,0 @@
1
- embedding_model:
2
- provider: "google"
3
- model_name: "models/text-embedding-004"
4
-
5
- retriever:
6
- top_k: 10
7
- search_type: "mmr" # Options: "similarity", "mmr", "similarity_score_threshold"
8
- # MMR (Maximal Marginal Relevance) parameters for diverse results
9
- fetch_k: 20 # Number of documents to fetch before MMR re-ranking (should be > top_k)
10
- lambda_mult: 0.5 # Diversity vs relevance (0=max diversity, 1=max relevance)
11
-
12
- llm:
13
- groq:
14
- provider: "groq"
15
- model_name: "openai/gpt-oss-20b"
16
- temperature: 0
17
- max_output_tokens: 2048
18
-
19
- google:
20
- provider: "google"
21
- model_name: "gemini-2.0-flash"
22
- temperature: 0
23
- max_output_tokens: 2048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_doc_chat/exception/__init__.py DELETED
File without changes
multi_doc_chat/exception/custom_exception.py DELETED
@@ -1,53 +0,0 @@
1
- import sys
2
- import traceback
3
- from typing import Optional, cast
4
-
5
-
6
- class DocumentPortalException(Exception):
7
- def __init__(self, error_message, error_details: Optional[object] = None):
8
- # Normalize message
9
- if isinstance(error_message, BaseException):
10
- norm_msg = str(error_message)
11
- else:
12
- norm_msg = str(error_message)
13
-
14
- # Resolve exc_info (supports: sys module, Exception object, or current context)
15
- exc_type = exc_value = exc_tb = None
16
- if error_details is None:
17
- exc_type, exc_value, exc_tb = sys.exc_info()
18
- else:
19
- if hasattr(error_details, "exc_info"): # e.g., sys
20
- #exc_type, exc_value, exc_tb = error_details.exc_info()
21
- exc_info_obj = cast(sys, error_details)
22
- exc_type, exc_value, exc_tb = exc_info_obj.exc_info()
23
- elif isinstance(error_details, BaseException):
24
- exc_type, exc_value, exc_tb = type(error_details), error_details, error_details.__traceback__
25
- else:
26
- exc_type, exc_value, exc_tb = sys.exc_info()
27
-
28
- # Walk to the last frame to report the most relevant location
29
- last_tb = exc_tb
30
- while last_tb and last_tb.tb_next:
31
- last_tb = last_tb.tb_next
32
-
33
- self.file_name = last_tb.tb_frame.f_code.co_filename if last_tb else "<unknown>"
34
- self.lineno = last_tb.tb_lineno if last_tb else -1
35
- self.error_message = norm_msg
36
-
37
- # Full pretty traceback (if available)
38
- if exc_type and exc_tb:
39
- self.traceback_str = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))
40
- else:
41
- self.traceback_str = ""
42
-
43
- super().__init__(self.__str__())
44
-
45
- def __str__(self):
46
- # Compact, logger-friendly message (no leading spaces)
47
- base = f"Error in [{self.file_name}] at line [{self.lineno}] | Message: {self.error_message}"
48
- if self.traceback_str:
49
- return f"{base}\nTraceback:\n{self.traceback_str}"
50
- return base
51
-
52
- def __repr__(self):
53
- return f"DocumentPortalException(file={self.file_name!r}, line={self.lineno}, message={self.error_message!r})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_doc_chat/model/__init__.py DELETED
File without changes
multi_doc_chat/model/models.py DELETED
@@ -1,29 +0,0 @@
1
- from pydantic import BaseModel, Field
2
- from typing import Annotated
3
- from enum import Enum
4
-
5
-
6
-
7
- class ChatAnswer(BaseModel):
8
- """Validate chat answer type and length."""
9
- answer: Annotated[str, Field(min_length=1, max_length=4096)]
10
-
11
-
12
- class PromptType(str, Enum):
13
- CONTEXTUALIZE_QUESTION = "contextualize_question"
14
- CONTEXT_QA = "context_qa"
15
-
16
-
17
- class UploadResponse(BaseModel):
18
- session_id: str
19
- indexed: bool
20
- message: str | None = None
21
-
22
-
23
- class ChatRequest(BaseModel):
24
- session_id: str
25
- message: str
26
-
27
-
28
- class ChatResponse(BaseModel):
29
- answer: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_doc_chat/src/__init__.py DELETED
File without changes
multi_doc_chat/src/document_chat/__init__.py DELETED
File without changes
multi_doc_chat/src/document_chat/retrieval.py DELETED
@@ -1,197 +0,0 @@
1
- import sys
2
- import os
3
- from operator import itemgetter
4
- from typing import List, Optional, Dict, Any
5
-
6
- from langchain_core.messages import BaseMessage
7
- from langchain_core.output_parsers import StrOutputParser
8
- from langchain_core.prompts import ChatPromptTemplate
9
- from langchain_community.vectorstores import FAISS
10
-
11
- from multi_doc_chat.utils.model_loader import ModelLoader
12
- from multi_doc_chat.exception.custom_exception import DocumentPortalException
13
- from multi_doc_chat.logger import GLOBAL_LOGGER as log
14
- from multi_doc_chat.prompts.prompt_library import PROMPT_REGISTRY
15
- from multi_doc_chat.model.models import PromptType, ChatAnswer
16
- from pydantic import ValidationError
17
-
18
-
19
- class ConversationalRAG:
20
- """
21
- LCEL-based Conversational RAG with lazy retriever initialization.
22
-
23
- Usage:
24
- rag = ConversationalRAG(session_id="abc")
25
- rag.load_retriever_from_faiss(index_path="faiss_index/abc", k=5, index_name="index")
26
- answer = rag.invoke("What is ...?", chat_history=[])
27
- """
28
-
29
- def __init__(self, session_id: Optional[str], retriever=None):
30
- try:
31
- self.session_id = session_id
32
-
33
- # Load LLM and prompts once
34
- self.llm = self._load_llm()
35
- self.contextualize_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
36
- PromptType.CONTEXTUALIZE_QUESTION.value
37
- ]
38
- self.qa_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
39
- PromptType.CONTEXT_QA.value
40
- ]
41
-
42
- # Lazy pieces
43
- self.retriever = retriever
44
- self.chain = None
45
- if self.retriever is not None:
46
- self._build_lcel_chain()
47
-
48
- log.info("ConversationalRAG initialized", session_id=self.session_id)
49
- except Exception as e:
50
- log.error("Failed to initialize ConversationalRAG", error=str(e))
51
- raise DocumentPortalException("Initialization error in ConversationalRAG", sys)
52
-
53
- # ---------- Public API ----------
54
-
55
- def load_retriever_from_faiss(
56
- self,
57
- index_path: str,
58
- k: int = 5,
59
- index_name: str = "index",
60
- search_type: str = "mmr",
61
- fetch_k: int = 20,
62
- lambda_mult: float = 0.5,
63
- search_kwargs: Optional[Dict[str, Any]] = None,
64
- ):
65
- """
66
- Load FAISS vectorstore from disk and build retriever + LCEL chain.
67
-
68
- Args:
69
- index_path: Path to FAISS index directory
70
- k: Number of documents to return
71
- index_name: Name of the index file
72
- search_type: Type of search ("similarity", "mmr", "similarity_score_threshold")
73
- fetch_k: Number of documents to fetch before MMR re-ranking (only for MMR)
74
- lambda_mult: Diversity parameter for MMR (0=max diversity, 1=max relevance)
75
- search_kwargs: Custom search kwargs (overrides other parameters if provided)
76
- """
77
- try:
78
- if not os.path.isdir(index_path):
79
- raise FileNotFoundError(f"FAISS index directory not found: {index_path}")
80
-
81
- embeddings = ModelLoader().load_embeddings()
82
- vectorstore = FAISS.load_local(
83
- index_path,
84
- embeddings,
85
- index_name=index_name,
86
- allow_dangerous_deserialization=True, # ok if you trust the index
87
- )
88
-
89
- if search_kwargs is None:
90
- search_kwargs = {"k": k}
91
- if search_type == "mmr":
92
- search_kwargs["fetch_k"] = fetch_k
93
- search_kwargs["lambda_mult"] = lambda_mult
94
-
95
- self.retriever = vectorstore.as_retriever(
96
- search_type=search_type, search_kwargs=search_kwargs
97
- )
98
- self._build_lcel_chain()
99
-
100
- log.info(
101
- "FAISS retriever loaded successfully",
102
- index_path=index_path,
103
- index_name=index_name,
104
- search_type=search_type,
105
- k=k,
106
- fetch_k=fetch_k if search_type == "mmr" else None,
107
- lambda_mult=lambda_mult if search_type == "mmr" else None,
108
- session_id=self.session_id,
109
- )
110
- return self.retriever
111
-
112
- except Exception as e:
113
- log.error("Failed to load retriever from FAISS", error=str(e))
114
- raise DocumentPortalException("Loading error in ConversationalRAG", sys)
115
-
116
- def invoke(self, user_input: str, chat_history: Optional[List[BaseMessage]] = None) -> str:
117
- """Invoke the LCEL pipeline."""
118
- try:
119
- if self.chain is None:
120
- raise DocumentPortalException(
121
- "RAG chain not initialized. Call load_retriever_from_faiss() before invoke().", sys
122
- )
123
- chat_history = chat_history or []
124
- payload = {"input": user_input, "chat_history": chat_history}
125
- answer = self.chain.invoke(payload)
126
- if not answer:
127
- log.warning(
128
- "No answer generated", user_input=user_input, session_id=self.session_id
129
- )
130
- return "no answer generated."
131
- # Validate answer type and length using Pydantic model
132
- try:
133
- validated = ChatAnswer(answer=str(answer))
134
- answer = validated.answer
135
- except ValidationError as ve:
136
- log.error("Invalid chat answer", error=str(ve))
137
- raise DocumentPortalException("Invalid chat answer", sys)
138
- log.info(
139
- "Chain invoked successfully",
140
- session_id=self.session_id,
141
- user_input=user_input,
142
- answer_preview=str(answer)[:150],
143
- )
144
- return answer
145
- except Exception as e:
146
- log.error("Failed to invoke ConversationalRAG", error=str(e))
147
- raise DocumentPortalException("Invocation error in ConversationalRAG", sys)
148
-
149
- # ---------- Internals ----------
150
-
151
- def _load_llm(self):
152
- try:
153
- llm = ModelLoader().load_llm()
154
- if not llm:
155
- raise ValueError("LLM could not be loaded")
156
- log.info("LLM loaded successfully", session_id=self.session_id)
157
- return llm
158
- except Exception as e:
159
- log.error("Failed to load LLM", error=str(e))
160
- raise DocumentPortalException("LLM loading error in ConversationalRAG", sys)
161
-
162
- @staticmethod
163
- def _format_docs(docs) -> str:
164
- return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)
165
-
166
- def _build_lcel_chain(self):
167
- try:
168
- if self.retriever is None:
169
- raise DocumentPortalException("No retriever set before building chain", sys)
170
-
171
- # 1) Rewrite user question with chat history context
172
- question_rewriter = (
173
- {"input": itemgetter("input"), "chat_history": itemgetter("chat_history")}
174
- | self.contextualize_prompt
175
- | self.llm
176
- | StrOutputParser()
177
- )
178
-
179
- # 2) Retrieve docs for rewritten question
180
- retrieve_docs = question_rewriter | self.retriever | self._format_docs
181
-
182
- # 3) Answer using retrieved context + original input + chat history
183
- self.chain = (
184
- {
185
- "context": retrieve_docs,
186
- "input": itemgetter("input"),
187
- "chat_history": itemgetter("chat_history"),
188
- }
189
- | self.qa_prompt
190
- | self.llm
191
- | StrOutputParser()
192
- )
193
-
194
- log.info("LCEL graph built successfully", session_id=self.session_id)
195
- except Exception as e:
196
- log.error("Failed to build LCEL chain", error=str(e), session_id=self.session_id)
197
- raise DocumentPortalException("Failed to build LCEL chain", sys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_doc_chat/utils/config_loader.py DELETED
@@ -1,27 +0,0 @@
1
- from pathlib import Path
2
- import os
3
- import yaml
4
-
5
- def _project_root() -> Path:
6
- # .../utils/config_loader.py -> parents[1] == project root
7
- return Path(__file__).resolve().parents[1]
8
-
9
- def load_config(config_path: str | None = None) -> dict:
10
- """
11
- Resolve config path reliably irrespective of CWD.
12
- Priority: explicit arg > CONFIG_PATH env > <project_root>/config/config.yaml
13
- """
14
- env_path = os.getenv("CONFIG_PATH")
15
- if config_path is None:
16
- # _project_root() already points to the package root (multi_doc_chat)
17
- config_path = env_path or str(_project_root() / "config" / "config.yaml")
18
-
19
- path = Path(config_path)
20
- if not path.is_absolute():
21
- path = _project_root() / path
22
-
23
- if not path.exists():
24
- raise FileNotFoundError(f"Config file not found: {path}")
25
-
26
- with open(path, "r", encoding="utf-8") as f:
27
- return yaml.safe_load(f) or {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_doc_chat/utils/file_io.py DELETED
@@ -1,58 +0,0 @@
1
-
2
- from __future__ import annotations
3
- import re
4
- import uuid
5
- from pathlib import Path
6
- from typing import Iterable, List
7
- from multi_doc_chat.logger.cutom_logger import CustomLogger
8
- from multi_doc_chat.exception.custom_exception import DocumentPortalException
9
-
10
- SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".txt", ".pptx", ".md", ".csv", ".xlsx", ".xls", ".db", ".sqlite", ".sqlite3"}
11
-
12
- # Local logger instance
13
- log = CustomLogger().get_logger(__name__)
14
-
15
-
16
- def save_uploaded_files(uploaded_files: Iterable, target_dir: Path) -> List[Path]:
17
- """Save uploaded files (Streamlit-like) and return local paths."""
18
- try:
19
- target_dir.mkdir(parents=True, exist_ok=True)
20
- saved: List[Path] = []
21
- for uf in uploaded_files:
22
- # Handle Starlette UploadFile (has .filename and .file) and generic objects (have .name)
23
- name = getattr(uf, "filename", getattr(uf, "name", "file"))
24
- ext = Path(name).suffix.lower()
25
- if ext not in SUPPORTED_EXTENSIONS:
26
- log.warning("Unsupported file skipped", filename=name)
27
- continue
28
- # Clean file name (only alphanum, dash, underscore)
29
- safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', Path(name).stem).lower()
30
- fname = f"{safe_name}_{uuid.uuid4().hex[:6]}{ext}"
31
- fname = f"{uuid.uuid4().hex[:8]}{ext}"
32
- out = target_dir / fname
33
- with open(out, "wb") as f:
34
- # Prefer underlying file buffer when available (e.g., Starlette UploadFile.file)
35
- if hasattr(uf, "file") and hasattr(uf.file, "read"):
36
- f.write(uf.file.read())
37
- elif hasattr(uf, "read"):
38
- data = uf.read()
39
- # If a memoryview is returned, convert to bytes; otherwise assume bytes
40
- if isinstance(data, memoryview):
41
- data = data.tobytes()
42
- f.write(data)
43
- else:
44
- # Fallback for objects exposing a getbuffer()
45
- buf = getattr(uf, "getbuffer", None)
46
- if callable(buf):
47
- data = buf()
48
- if isinstance(data, memoryview):
49
- data = data.tobytes()
50
- f.write(data)
51
- else:
52
- raise ValueError("Unsupported uploaded file object; no readable interface")
53
- saved.append(out)
54
- log.info("File saved for ingestion", uploaded=name, saved_as=str(out))
55
- return saved
56
- except Exception as e:
57
- log.error("Failed to save uploaded files", error=str(e), dir=str(target_dir))
58
- raise DocumentPortalException("Failed to save uploaded files", e) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates/index.html CHANGED
@@ -3,7 +3,7 @@
3
  <head>
4
  <meta charset="UTF-8" />
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
- <title>MultiDocChat</title>
7
  <link rel="stylesheet" href="/static/styles.css" />
8
  <style>
9
  /* Minimal inline tweaks; most styles live in styles.css */
 
3
  <head>
4
  <meta charset="UTF-8" />
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>RAG Solution</title>
7
  <link rel="stylesheet" href="/static/styles.css" />
8
  <style>
9
  /* Minimal inline tweaks; most styles live in styles.css */