# backend/app/rag/rag_processor.py
# precison9's picture
# deploy FastAPI backend
# 62a1756
import os
import re
import logging
from typing import List, Tuple, Optional
import faiss
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from docx import Document
import pytesseract
from PIL import Image
import io
import openpyxl
import pandas as pd
from duckduckgo_search import DDGS
from fastapi import UploadFile
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Model name handed to SentenceTransformer when the embedder is first created.
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Cached embedder instance; stays None until _get_embedder() loads the model.
_embedder: Optional[SentenceTransformer] = None
def _get_embedder() -> SentenceTransformer:
    """Return the shared embedding model, loading it on the first call.

    The instance is cached in the module-level ``_embedder`` so that later
    calls reuse it instead of constructing the model again.
    """
    global _embedder
    if _embedder is not None:
        return _embedder
    logger.info(f"Loading embedding model: {_EMBED_MODEL_NAME}")
    _embedder = SentenceTransformer(_EMBED_MODEL_NAME)
    return _embedder
# Enhanced File Extraction
def extract_text(file: UploadFile) -> str:
    """Extract plain text from an uploaded file, picking a parser by extension.

    Supported extensions: ``.pdf`` (PyPDF2), ``.docx`` (python-docx),
    ``.xlsx`` (openpyxl), ``.xls`` (pandas), ``.csv`` (pandas) and
    ``.jpg``/``.jpeg``/``.png``/``.gif`` (OCR via pytesseract). Anything else
    is decoded as UTF-8 text with undecodable bytes dropped.

    Args:
        file: The upload to read. Its underlying stream is consumed from the
            current position; the caller is responsible for rewinding it.

    Returns:
        The extracted text, or ``""`` when parsing fails. Failures are logged
        and never raised, so one bad file does not abort a batch.
    """
    # ``filename`` may be missing on an upload; treat that as "no extension"
    # instead of letting os.path.splitext raise TypeError on None.
    ext = os.path.splitext(file.filename or "")[1].lower()
    content = file.file.read()
    file_bytes = io.BytesIO(content)

    if ext == ".pdf":
        try:
            reader = PdfReader(file_bytes)
            # extract_text() can yield nothing for a page; substitute "".
            return "\n".join(page.extract_text() or "" for page in reader.pages)
        except Exception as e:
            logger.error("PDF extract failed: %s", e)
            return ""

    if ext == ".docx":
        try:
            doc = Document(file_bytes)
            return "\n".join(p.text for p in doc.paragraphs if p.text)
        except Exception as e:
            logger.error("DOCX extract failed: %s", e)
            return ""

    if ext == ".xlsx":
        try:
            wb = openpyxl.load_workbook(file_bytes, read_only=True, data_only=True)
            try:
                lines = []
                for sheet in wb:
                    for row in sheet.iter_rows(values_only=True):
                        lines.append(
                            " ".join(str(cell) for cell in row if cell is not None)
                        )
            finally:
                # Read-only workbooks keep their source open until closed.
                wb.close()
            return "\n".join(lines)
        except Exception as e:
            logger.error("Excel extract failed: %s", e)
            return ""

    if ext == ".xls":
        # openpyxl cannot read the legacy binary .xls format, so this goes
        # through pandas (which needs an .xls engine such as xlrd installed;
        # if it is missing the error is logged like any other failure).
        try:
            sheets = pd.read_excel(file_bytes, sheet_name=None, header=None)
            lines = []
            for frame in sheets.values():
                for row in frame.itertuples(index=False, name=None):
                    lines.append(
                        " ".join(str(cell) for cell in row if not pd.isna(cell))
                    )
            return "\n".join(lines)
        except Exception as e:
            logger.error("Excel extract failed: %s", e)
            return ""

    if ext == ".csv":
        try:
            df = pd.read_csv(file_bytes)
            return df.to_string()
        except Exception as e:
            logger.error("CSV extract failed: %s", e)
            return ""

    if ext in (".jpg", ".jpeg", ".png", ".gif"):  # OCR for images
        try:
            with Image.open(file_bytes) as img:
                return pytesseract.image_to_string(img)
        except Exception as e:
            logger.error("Image OCR failed: %s", e)
            return ""

    # Fallback: treat the payload as UTF-8 text.
    try:
        return content.decode("utf-8", errors="ignore")
    except Exception as e:
        logger.error("Text extract failed: %s", e)
        return ""
def clean_text(text: str) -> str:
    """Normalise whitespace in *text*.

    Runs of spaces/tabs become a single space, runs of three or more
    newlines become exactly two, and leading/trailing whitespace is removed.
    """
    collapsed = re.sub(r"[ \t]+", " ", text)
    collapsed = re.sub(r"\n{3,}", "\n\n", collapsed)
    return collapsed.strip()


def chunk_text(text: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated words.

    Args:
        text: Raw text; it is cleaned with ``clean_text`` before splitting.
        max_tokens: Maximum number of words per chunk. Must be at least 1.
        overlap: Number of trailing words of one chunk repeated at the start
            of the next. Values outside ``[0, max_tokens - 1]`` are clamped
            into that range so the window always moves forward.

    Returns:
        The chunks in document order; ``[]`` when *text* has no words.

    Raises:
        ValueError: If ``max_tokens`` is less than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")
    # An overlap >= max_tokens would leave the window start unchanged (an
    # endless loop); a negative overlap would silently skip words.
    overlap = max(0, min(overlap, max_tokens - 1))

    text = clean_text(text)
    if not text:
        return []

    words = text.split()
    total = len(words)
    step = max_tokens - overlap  # always >= 1 after clamping
    chunks = []
    for start in range(0, total, step):
        end = min(total, start + max_tokens)
        chunks.append(" ".join(words[start:end]))
        if end == total:
            # The last window already reaches the final word; stopping here
            # avoids emitting a trailing chunk made only of overlap words.
            break
    return chunks
class RagIndex:
    """Bundle a FAISS index with the text chunks it was built from."""

    def __init__(self, index: faiss.IndexFlatIP, dim: int, chunks: List[str]):
        # Vector index that searches are run against.
        self.index: faiss.IndexFlatIP = index
        # Dimensionality of the vectors stored in ``index``.
        self.dim: int = dim
        # Chunk texts; search results index into this list by position.
        self.chunks: List[str] = chunks
def build_faiss_index(chunks: List[str]) -> RagIndex:
    """Embed *chunks* and store them in an inner-product FAISS index.

    Embeddings are L2-normalised at encode time, so the inner-product scores
    returned by the index correspond to cosine similarity.

    Args:
        chunks: Non-empty list of text chunks to index.

    Returns:
        A ``RagIndex`` holding the populated index, the embedding dimension
        and the same ``chunks`` list (not copied).

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if not chunks:
        # Without at least one embedding there is no vector dimension to
        # read, so fail early with a clear message before loading the model.
        raise ValueError("Cannot build a FAISS index from an empty list of chunks.")

    emb = _get_embedder()
    vectors = emb.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return RagIndex(index=index, dim=dim, chunks=chunks)
def search(index: RagIndex, query: str, top_k: int = 6) -> List[Tuple[str, float]]:
    """Return the chunks of *index* most similar to *query*.

    Args:
        index: The ``RagIndex`` to search.
        query: Free-text query; it is embedded with the same model and
            normalisation used when encoding chunks.
        top_k: Maximum number of hits to return.

    Returns:
        Up to ``top_k`` ``(chunk_text, score)`` pairs in the order FAISS
        reports them; ``[]`` when ``top_k`` is not positive or the index is
        empty.
    """
    # Never ask FAISS for more neighbours than it holds, and skip the call
    # (and the query embedding) entirely when nothing can be returned.
    k = min(top_k, index.index.ntotal)
    if k <= 0:
        return []

    emb = _get_embedder()
    q = emb.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    scores, ids = index.index.search(q, k)

    hits = []
    for score, idx in zip(scores[0], ids[0]):
        if idx == -1:
            # FAISS uses -1 to mark an unfilled result slot.
            continue
        hits.append((index.chunks[int(idx)], float(score)))
    return hits
def build_context_from_files(files: List[UploadFile], prompt: str, top_k: int = 6) -> str:
    """Build a retrieval context for *prompt* from the uploaded *files*.

    Text is extracted from every file, joined, chunked (450 words with an
    80-word overlap), indexed, and searched with *prompt*. The best ``top_k``
    chunks are rendered as ``[DOC#n score=...]`` sections separated by blank
    lines. Returns ``""`` when no text could be chunked.
    """
    extracted = []
    for upload in files:
        text = extract_text(upload)
        if text:
            extracted.append(text)
        # Rewind the stream so the upload can be read again afterwards.
        upload.file.seek(0)

    corpus = "\n\n".join(extracted)
    pieces = chunk_text(corpus, max_tokens=450, overlap=80)
    if not pieces:
        return ""

    rag_index = build_faiss_index(pieces)
    context_sections = []
    for i, (chunk, score) in enumerate(search(rag_index, prompt, top_k=top_k), 1):
        context_sections.append(f"[DOC#{i} score={score:.3f}]\n{chunk}")
    return "\n\n".join(context_sections)
# Web search tool
def web_search(query: str) -> str:
    """Run a DuckDuckGo text search and format up to five results.

    Each result becomes a ``[WEB#n]`` section with its title, snippet and
    URL; sections are separated by blank lines. Returns
    ``"No results found."`` when the search yields nothing and
    ``"Web search error."`` (after logging) when anything raises.
    """
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=5))
            sections = []
            for i, r in enumerate(results, 1):
                sections.append(
                    f"[WEB#{i}] Title: {r['title']}\nSnippet: {r['body']}\nURL: {r['href']}"
                )
        if not sections:
            return "No results found."
        return "\n\n".join(sections)
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return "Web search error."