Commit ·
df70a57
1
Parent(s): 9bfdb2a
HF push again
Browse files- multi_doc_chat/config/config.yaml +0 -23
- multi_doc_chat/exception/__init__.py +0 -0
- multi_doc_chat/exception/custom_exception.py +0 -53
- multi_doc_chat/model/__init__.py +0 -0
- multi_doc_chat/model/models.py +0 -29
- multi_doc_chat/src/__init__.py +0 -0
- multi_doc_chat/src/document_chat/__init__.py +0 -0
- multi_doc_chat/src/document_chat/retrieval.py +0 -197
- multi_doc_chat/utils/config_loader.py +0 -27
- multi_doc_chat/utils/file_io.py +0 -58
- templates/index.html +1 -1
multi_doc_chat/config/config.yaml
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
embedding_model:
|
| 2 |
-
provider: "google"
|
| 3 |
-
model_name: "models/text-embedding-004"
|
| 4 |
-
|
| 5 |
-
retriever:
|
| 6 |
-
top_k: 10
|
| 7 |
-
search_type: "mmr" # Options: "similarity", "mmr", "similarity_score_threshold"
|
| 8 |
-
# MMR (Maximal Marginal Relevance) parameters for diverse results
|
| 9 |
-
fetch_k: 20 # Number of documents to fetch before MMR re-ranking (should be > top_k)
|
| 10 |
-
lambda_mult: 0.5 # Diversity vs relevance (0=max diversity, 1=max relevance)
|
| 11 |
-
|
| 12 |
-
llm:
|
| 13 |
-
groq:
|
| 14 |
-
provider: "groq"
|
| 15 |
-
model_name: "openai/gpt-oss-20b"
|
| 16 |
-
temperature: 0
|
| 17 |
-
max_output_tokens: 2048
|
| 18 |
-
|
| 19 |
-
google:
|
| 20 |
-
provider: "google"
|
| 21 |
-
model_name: "gemini-2.0-flash"
|
| 22 |
-
temperature: 0
|
| 23 |
-
max_output_tokens: 2048
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_doc_chat/exception/__init__.py
DELETED
|
File without changes
|
multi_doc_chat/exception/custom_exception.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import traceback
|
| 3 |
-
from typing import Optional, cast
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class DocumentPortalException(Exception):
    """Exception that records the message, source location and traceback of an error.

    Attributes set by the constructor:
        file_name: File of the innermost traceback frame, or ``"<unknown>"``
            when no traceback is available.
        lineno: Line number of that frame, or ``-1`` when unavailable.
        error_message: ``str()`` of the ``error_message`` argument.
        traceback_str: Formatted traceback, or ``""`` when unavailable.
    """

    def __init__(self, error_message, error_details: Optional[object] = None):
        """Build the exception from a message and an optional error source.

        Args:
            error_message: A message or an exception instance; it is stored
                as ``str(error_message)``.
            error_details: Where to take the exception info from:
                ``None`` (use the exception currently being handled), an
                object exposing a callable ``exc_info`` (e.g. the ``sys``
                module), or an exception instance. Anything else falls back
                to the exception currently being handled.
        """
        # str() covers both plain messages and exception instances.
        norm_msg = str(error_message)

        # Resolve (type, value, traceback) from the supplied source.
        exc_info_getter = getattr(error_details, "exc_info", None)
        if error_details is None:
            exc_type, exc_value, exc_tb = sys.exc_info()
        elif callable(exc_info_getter):  # e.g. the sys module
            exc_type, exc_value, exc_tb = exc_info_getter()
        elif isinstance(error_details, BaseException):
            exc_type = type(error_details)
            exc_value = error_details
            exc_tb = error_details.__traceback__
        else:
            exc_type, exc_value, exc_tb = sys.exc_info()

        # Walk to the last frame to report the most relevant location.
        last_tb = exc_tb
        while last_tb and last_tb.tb_next:
            last_tb = last_tb.tb_next

        self.file_name = last_tb.tb_frame.f_code.co_filename if last_tb else "<unknown>"
        self.lineno = last_tb.tb_lineno if last_tb else -1
        self.error_message = norm_msg

        # Full formatted traceback, only when both a type and a traceback exist.
        if exc_type and exc_tb:
            self.traceback_str = "".join(
                traceback.format_exception(exc_type, exc_value, exc_tb)
            )
        else:
            self.traceback_str = ""

        super().__init__(self.__str__())

    def __str__(self):
        """Return a compact, logger-friendly message, plus the traceback if any."""
        base = f"Error in [{self.file_name}] at line [{self.lineno}] | Message: {self.error_message}"
        if self.traceback_str:
            return f"{base}\nTraceback:\n{self.traceback_str}"
        return base

    def __repr__(self):
        """Return a short representation without the traceback."""
        return f"DocumentPortalException(file={self.file_name!r}, line={self.lineno}, message={self.error_message!r})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_doc_chat/model/__init__.py
DELETED
|
File without changes
|
multi_doc_chat/model/models.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
from pydantic import BaseModel, Field
|
| 2 |
-
from typing import Annotated
|
| 3 |
-
from enum import Enum
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class ChatAnswer(BaseModel):
    """Validate chat answer type and length."""
    # The answer must be a string of 1 to 4096 characters (Field constraints).
    answer: Annotated[str, Field(min_length=1, max_length=4096)]
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class PromptType(str, Enum):
    """String-valued identifiers for the available prompt kinds."""
    # presumably the prompt that rewrites a question using chat history — confirm against the prompt library
    CONTEXTUALIZE_QUESTION = "contextualize_question"
    # presumably the prompt that answers from retrieved context — confirm against the prompt library
    CONTEXT_QA = "context_qa"
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class UploadResponse(BaseModel):
    """Response model with a session id, an indexed flag and an optional message."""
    # NOTE(review): field meanings are inferred from names only; the producing endpoint is not visible here.
    session_id: str
    indexed: bool
    # Optional; defaults to None when omitted.
    message: str | None = None
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class ChatRequest(BaseModel):
    """Request model carrying a session id and a message, both required strings."""
    session_id: str
    message: str
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class ChatResponse(BaseModel):
    """Response model carrying a single required string field, ``answer``."""
    answer: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_doc_chat/src/__init__.py
DELETED
|
File without changes
|
multi_doc_chat/src/document_chat/__init__.py
DELETED
|
File without changes
|
multi_doc_chat/src/document_chat/retrieval.py
DELETED
|
@@ -1,197 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import os
|
| 3 |
-
from operator import itemgetter
|
| 4 |
-
from typing import List, Optional, Dict, Any
|
| 5 |
-
|
| 6 |
-
from langchain_core.messages import BaseMessage
|
| 7 |
-
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
-
from langchain_community.vectorstores import FAISS
|
| 10 |
-
|
| 11 |
-
from multi_doc_chat.utils.model_loader import ModelLoader
|
| 12 |
-
from multi_doc_chat.exception.custom_exception import DocumentPortalException
|
| 13 |
-
from multi_doc_chat.logger import GLOBAL_LOGGER as log
|
| 14 |
-
from multi_doc_chat.prompts.prompt_library import PROMPT_REGISTRY
|
| 15 |
-
from multi_doc_chat.model.models import PromptType, ChatAnswer
|
| 16 |
-
from pydantic import ValidationError
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class ConversationalRAG:
    """
    LCEL-based Conversational RAG with lazy retriever initialization.

    The LLM and both prompts are loaded in ``__init__``. The retriever and the
    chain are built either in ``__init__`` (when a retriever is passed in) or
    later by ``load_retriever_from_faiss``.

    Errors are reported as ``DocumentPortalException``. An exception of that
    type raised by an inner step is propagated unchanged, so its specific
    message is not hidden behind a generic wrapper.

    Usage:
        rag = ConversationalRAG(session_id="abc")
        rag.load_retriever_from_faiss(index_path="faiss_index/abc", k=5, index_name="index")
        answer = rag.invoke("What is ...?", chat_history=[])
    """

    def __init__(self, session_id: Optional[str], retriever=None):
        """Load the LLM and prompts; build the chain if a retriever is given.

        Args:
            session_id: Identifier attached to log records.
            retriever: Optional ready-made retriever. When ``None``, the chain
                stays unset until ``load_retriever_from_faiss`` is called.

        Raises:
            DocumentPortalException: If any initialization step fails.
        """
        try:
            self.session_id = session_id

            # Load LLM and prompts once
            self.llm = self._load_llm()
            self.contextualize_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
                PromptType.CONTEXTUALIZE_QUESTION.value
            ]
            self.qa_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
                PromptType.CONTEXT_QA.value
            ]

            # Lazy pieces
            self.retriever = retriever
            self.chain = None
            if self.retriever is not None:
                self._build_lcel_chain()

            log.info("ConversationalRAG initialized", session_id=self.session_id)
        except DocumentPortalException:
            # Already logged and wrapped by the failing step.
            raise
        except Exception as e:
            log.error("Failed to initialize ConversationalRAG", error=str(e))
            raise DocumentPortalException("Initialization error in ConversationalRAG", sys) from e

    # ---------- Public API ----------

    def load_retriever_from_faiss(
        self,
        index_path: str,
        k: int = 5,
        index_name: str = "index",
        search_type: str = "mmr",
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        search_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Load FAISS vectorstore from disk and build retriever + LCEL chain.

        Args:
            index_path: Path to FAISS index directory
            k: Number of documents to return
            index_name: Name of the index file
            search_type: Type of search ("similarity", "mmr", "similarity_score_threshold")
            fetch_k: Number of documents to fetch before MMR re-ranking (only for MMR)
            lambda_mult: Diversity parameter for MMR (0=max diversity, 1=max relevance)
            search_kwargs: Custom search kwargs (overrides other parameters if provided)

        Returns:
            The retriever that was created and stored on ``self.retriever``.

        Raises:
            DocumentPortalException: If the index directory is missing or any
                loading step fails.
        """
        try:
            if not os.path.isdir(index_path):
                raise FileNotFoundError(f"FAISS index directory not found: {index_path}")

            embeddings = ModelLoader().load_embeddings()
            # SECURITY: allow_dangerous_deserialization=True is only acceptable
            # for index files this application wrote itself. Never point
            # index_path at untrusted data.
            vectorstore = FAISS.load_local(
                index_path,
                embeddings,
                index_name=index_name,
                allow_dangerous_deserialization=True,  # ok if you trust the index
            )

            # Build default kwargs only when the caller supplied none.
            if search_kwargs is None:
                search_kwargs = {"k": k}
                if search_type == "mmr":
                    search_kwargs["fetch_k"] = fetch_k
                    search_kwargs["lambda_mult"] = lambda_mult

            self.retriever = vectorstore.as_retriever(
                search_type=search_type, search_kwargs=search_kwargs
            )
            self._build_lcel_chain()

            log.info(
                "FAISS retriever loaded successfully",
                index_path=index_path,
                index_name=index_name,
                search_type=search_type,
                k=k,
                fetch_k=fetch_k if search_type == "mmr" else None,
                lambda_mult=lambda_mult if search_type == "mmr" else None,
                session_id=self.session_id,
            )
            return self.retriever

        except DocumentPortalException:
            # Raised by _build_lcel_chain; keep its specific message.
            raise
        except Exception as e:
            log.error("Failed to load retriever from FAISS", error=str(e))
            raise DocumentPortalException("Loading error in ConversationalRAG", sys) from e

    def invoke(self, user_input: str, chat_history: Optional[List[BaseMessage]] = None) -> str:
        """Invoke the LCEL pipeline.

        Args:
            user_input: The user's question.
            chat_history: Prior messages; ``None`` is treated as an empty list.

        Returns:
            The validated answer string, or ``"no answer generated."`` when
            the chain returns a falsy result.

        Raises:
            DocumentPortalException: If the chain has not been built, the
                answer fails ``ChatAnswer`` validation, or invocation fails.
        """
        # Checked outside the try so this message reaches the caller as-is.
        if self.chain is None:
            raise DocumentPortalException(
                "RAG chain not initialized. Call load_retriever_from_faiss() before invoke().", sys
            )
        try:
            chat_history = chat_history or []
            payload = {"input": user_input, "chat_history": chat_history}
            answer = self.chain.invoke(payload)
            if not answer:
                log.warning(
                    "No answer generated", user_input=user_input, session_id=self.session_id
                )
                return "no answer generated."
            # Validate answer type and length using Pydantic model
            try:
                validated = ChatAnswer(answer=str(answer))
                answer = validated.answer
            except ValidationError as ve:
                log.error("Invalid chat answer", error=str(ve))
                raise DocumentPortalException("Invalid chat answer", sys) from ve
            log.info(
                "Chain invoked successfully",
                session_id=self.session_id,
                user_input=user_input,
                answer_preview=str(answer)[:150],
            )
            return answer
        except DocumentPortalException:
            # Keep the specific "Invalid chat answer" error instead of re-wrapping it.
            raise
        except Exception as e:
            log.error("Failed to invoke ConversationalRAG", error=str(e))
            raise DocumentPortalException("Invocation error in ConversationalRAG", sys) from e

    # ---------- Internals ----------

    def _load_llm(self):
        """Load the LLM through ``ModelLoader``.

        Raises:
            DocumentPortalException: If loading fails or yields a falsy value.
        """
        try:
            llm = ModelLoader().load_llm()
            if not llm:
                raise ValueError("LLM could not be loaded")
            log.info("LLM loaded successfully", session_id=self.session_id)
            return llm
        except Exception as e:
            log.error("Failed to load LLM", error=str(e))
            raise DocumentPortalException("LLM loading error in ConversationalRAG", sys) from e

    @staticmethod
    def _format_docs(docs) -> str:
        """Join each item's ``page_content`` (or ``str(item)``) with blank lines."""
        return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)

    def _build_lcel_chain(self):
        """Compose the rewrite -> retrieve -> answer chain and store it on ``self.chain``.

        Raises:
            DocumentPortalException: If no retriever is set or composition fails.
        """
        # Checked outside the try so this message reaches the caller as-is.
        if self.retriever is None:
            raise DocumentPortalException("No retriever set before building chain", sys)
        try:
            # 1) Rewrite user question with chat history context
            question_rewriter = (
                {"input": itemgetter("input"), "chat_history": itemgetter("chat_history")}
                | self.contextualize_prompt
                | self.llm
                | StrOutputParser()
            )

            # 2) Retrieve docs for rewritten question
            retrieve_docs = question_rewriter | self.retriever | self._format_docs

            # 3) Answer using retrieved context + original input + chat history
            self.chain = (
                {
                    "context": retrieve_docs,
                    "input": itemgetter("input"),
                    "chat_history": itemgetter("chat_history"),
                }
                | self.qa_prompt
                | self.llm
                | StrOutputParser()
            )

            log.info("LCEL graph built successfully", session_id=self.session_id)
        except Exception as e:
            log.error("Failed to build LCEL chain", error=str(e), session_id=self.session_id)
            raise DocumentPortalException("Failed to build LCEL chain", sys) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_doc_chat/utils/config_loader.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
import os
|
| 3 |
-
import yaml
|
| 4 |
-
|
| 5 |
-
def _project_root() -> Path:
|
| 6 |
-
# .../utils/config_loader.py -> parents[1] == project root
|
| 7 |
-
return Path(__file__).resolve().parents[1]
|
| 8 |
-
|
| 9 |
-
def load_config(config_path: str | None = None) -> dict:
    """
    Load the YAML configuration, resolving its path independently of the CWD.

    Path priority: explicit ``config_path`` argument, then the ``CONFIG_PATH``
    environment variable, then ``<project_root>/config/config.yaml``.
    A relative path is resolved against ``_project_root()``.

    Returns:
        The parsed YAML content, or ``{}`` when the file parses to a falsy value.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    if config_path is None:
        default_location = _project_root() / "config" / "config.yaml"
        config_path = os.getenv("CONFIG_PATH") or str(default_location)

    resolved = Path(config_path)
    if not resolved.is_absolute():
        resolved = _project_root() / resolved

    if not resolved.exists():
        raise FileNotFoundError(f"Config file not found: {resolved}")

    with open(resolved, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed or {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_doc_chat/utils/file_io.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
import re
|
| 4 |
-
import uuid
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Iterable, List
|
| 7 |
-
from multi_doc_chat.logger.cutom_logger import CustomLogger
|
| 8 |
-
from multi_doc_chat.exception.custom_exception import DocumentPortalException
|
| 9 |
-
|
| 10 |
-
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".txt", ".pptx", ".md", ".csv", ".xlsx", ".xls", ".db", ".sqlite", ".sqlite3"}
|
| 11 |
-
|
| 12 |
-
# Local logger instance
|
| 13 |
-
log = CustomLogger().get_logger(__name__)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def save_uploaded_files(uploaded_files: Iterable, target_dir: Path) -> List[Path]:
    """Save uploaded files (Streamlit-like) and return local paths.

    Files whose extension is not in ``SUPPORTED_EXTENSIONS`` are skipped with
    a warning. Each saved file gets a random 8-hex-character name plus the
    original (lower-cased) extension, so no user-supplied text reaches the
    filesystem path.

    Args:
        uploaded_files: Upload objects exposing ``.filename`` or ``.name`` and
            one of ``.file.read()``, ``.read()`` or ``.getbuffer()``.
        target_dir: Destination directory; created if missing.

    Returns:
        Paths of the files written, in input order.

    Raises:
        DocumentPortalException: If the directory cannot be created, an
            upload object is unreadable, or writing fails.
    """
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        saved: List[Path] = []
        for uf in uploaded_files:
            # Handle Starlette UploadFile (.filename) and generic objects (.name).
            # A present-but-None/empty filename falls through to the next option.
            name = getattr(uf, "filename", None) or getattr(uf, "name", None) or "file"
            ext = Path(name).suffix.lower()
            if ext not in SUPPORTED_EXTENSIONS:
                log.warning("Unsupported file skipped", filename=name)
                continue

            # Read before opening the destination, so an unreadable upload
            # does not leave an empty file behind.
            data = _read_upload_bytes(uf)

            out = target_dir / f"{uuid.uuid4().hex[:8]}{ext}"
            with open(out, "wb") as f:
                f.write(data)
            saved.append(out)
            log.info("File saved for ingestion", uploaded=name, saved_as=str(out))
        return saved
    except Exception as e:
        log.error("Failed to save uploaded files", error=str(e), dir=str(target_dir))
        raise DocumentPortalException("Failed to save uploaded files", e) from e


def _read_upload_bytes(uf):
    """Return the content of an upload object via the first readable interface it offers."""
    inner = getattr(uf, "file", None)
    if inner is not None and hasattr(inner, "read"):
        # Prefer the underlying file buffer (e.g., Starlette UploadFile.file).
        data = inner.read()
    elif hasattr(uf, "read"):
        data = uf.read()
    else:
        # Fallback for objects exposing a getbuffer()
        getbuffer = getattr(uf, "getbuffer", None)
        if not callable(getbuffer):
            raise ValueError("Unsupported uploaded file object; no readable interface")
        data = getbuffer()
    # If a memoryview is returned, convert to bytes; otherwise assume bytes.
    if isinstance(data, memoryview):
        data = data.tobytes()
    return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
templates/index.html
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
<head>
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 6 |
-
<title>
|
| 7 |
<link rel="stylesheet" href="/static/styles.css" />
|
| 8 |
<style>
|
| 9 |
/* Minimal inline tweaks; most styles live in styles.css */
|
|
|
|
| 3 |
<head>
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 6 |
+
<title>RAG Solution</title>
|
| 7 |
<link rel="stylesheet" href="/static/styles.css" />
|
| 8 |
<style>
|
| 9 |
/* Minimal inline tweaks; most styles live in styles.css */
|