| import os |
| import re |
| import logging |
| from typing import List, Tuple, Optional |
| import faiss |
| from sentence_transformers import SentenceTransformer |
| from PyPDF2 import PdfReader |
| from docx import Document |
| import pytesseract |
| from PIL import Image |
| import io |
| import openpyxl |
| import pandas as pd |
| from duckduckgo_search import DDGS |
| from fastapi import UploadFile |
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Model id handed to SentenceTransformer when the embedder is first loaded.
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Cached embedder instance; remains None until _get_embedder() is first called.
_embedder: Optional[SentenceTransformer] = None
|
|
def _get_embedder() -> SentenceTransformer:
    """Return the module-wide SentenceTransformer, loading it on the first call.

    The loaded model is cached in the module global ``_embedder`` so later
    calls reuse it instead of reloading ``_EMBED_MODEL_NAME``.
    """
    global _embedder
    if _embedder is not None:
        return _embedder
    logger.info(f"Loading embedding model: {_EMBED_MODEL_NAME}")
    _embedder = SentenceTransformer(_EMBED_MODEL_NAME)
    return _embedder
|
|
| |
def _pdf_to_text(data: io.BytesIO) -> str:
    """Join the text PyPDF2 extracts from every page, separated by newlines."""
    reader = PdfReader(data)
    # extract_text() may return None for pages without a text layer.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def _docx_to_text(data: io.BytesIO) -> str:
    """Join the non-empty paragraph texts of a .docx document with newlines."""
    document = Document(data)
    return "\n".join(p.text for p in document.paragraphs if p.text)


def _xlsx_to_text(data: io.BytesIO) -> str:
    """Flatten every worksheet of an .xlsx workbook to one line per non-empty row.

    Non-empty cells are rendered with str() and joined by single spaces;
    formula cells contribute their cached values (data_only=True).
    """
    workbook = openpyxl.load_workbook(data, read_only=True, data_only=True)
    try:
        lines = []
        for sheet in workbook:
            for row in sheet.iter_rows(values_only=True):
                line = " ".join(str(cell) for cell in row if cell is not None)
                if line:
                    lines.append(line)
        return "\n".join(lines)
    finally:
        # openpyxl's read-only mode keeps the source open until close() is called.
        workbook.close()


def _xls_to_text(data: io.BytesIO) -> str:
    """Flatten every sheet of a legacy .xls workbook to one line per non-empty row.

    openpyxl cannot read the pre-2007 binary .xls format, so this goes through
    pandas, which needs its optional ``xlrd`` engine installed for real .xls
    content (a failure to import it surfaces as an exception to the caller).
    """
    sheets = pd.read_excel(data, sheet_name=None, header=None)
    lines = []
    for frame in sheets.values():
        for row in frame.itertuples(index=False, name=None):
            line = " ".join(str(cell) for cell in row if pd.notna(cell))
            if line:
                lines.append(line)
    return "\n".join(lines)


def _csv_to_text(data: io.BytesIO) -> str:
    """Render a CSV file as pandas' plain-text table (header row included)."""
    return pd.read_csv(data).to_string()


def _image_to_text(data: io.BytesIO) -> str:
    """OCR an image with Tesseract via pytesseract."""
    return pytesseract.image_to_string(Image.open(data))


# Lower-cased extension -> (log message prefix used on failure, extractor).
_EXTRACTORS = {
    ".pdf": ("PDF extract failed", _pdf_to_text),
    ".docx": ("DOCX extract failed", _docx_to_text),
    ".xlsx": ("Excel extract failed", _xlsx_to_text),
    ".xls": ("Excel extract failed", _xls_to_text),
    ".csv": ("CSV extract failed", _csv_to_text),
    ".jpg": ("Image OCR failed", _image_to_text),
    ".jpeg": ("Image OCR failed", _image_to_text),
    ".png": ("Image OCR failed", _image_to_text),
    ".gif": ("Image OCR failed", _image_to_text),
}


def extract_text(file: UploadFile) -> str:
    """Extract plain text from an uploaded file, choosing a parser by extension.

    Supported formats are .pdf, .docx, .xlsx, .xls, .csv and the image types
    .jpg/.jpeg/.png/.gif (OCR). Any other extension, or no filename at all, is
    decoded as UTF-8 with undecodable bytes dropped.

    The upload's stream is read from its current position to EOF and is not
    rewound here.

    Returns:
        The extracted text, or "" if the format-specific parser raised (the
        exception is logged with its traceback rather than propagated).
    """
    # UploadFile.filename may be None; treat that as "no extension".
    ext = os.path.splitext(file.filename or "")[1].lower()
    content = file.file.read()

    handler = _EXTRACTORS.get(ext)
    if handler is None:
        # errors="ignore" drops undecodable bytes, so this decode cannot raise.
        return content.decode("utf-8", errors="ignore")

    failure_message, extractor = handler
    try:
        return extractor(io.BytesIO(content))
    except Exception as e:
        # Best-effort extraction: one bad upload must not abort the request.
        logger.exception("%s: %s", failure_message, e)
        return ""
|
|
def clean_text(text: str) -> str:
    """Normalise whitespace in extracted text.

    Runs of spaces/tabs collapse to a single space, three or more consecutive
    newlines collapse to a blank line ("\\n\\n"), and leading/trailing
    whitespace is stripped.
    """
    substitutions = (
        (r"[ \t]+", " "),
        (r"\n{3,}", "\n\n"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
|
|
def chunk_text(text: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
    """Split text into overlapping windows of whitespace-separated words.

    The text is first normalised with clean_text(). "Tokens" here are
    whitespace-delimited words, not model tokens.

    Args:
        text: Raw text to split.
        max_tokens: Maximum number of words per chunk (must be >= 1).
        overlap: Number of trailing words of one chunk repeated at the start of
            the next (must satisfy 0 <= overlap < max_tokens).

    Returns:
        The chunks in order; the last one may be shorter. [] for blank text.

    Raises:
        ValueError: if max_tokens < 1 or overlap is outside [0, max_tokens).
            Otherwise the window could never advance (infinite loop), or a
            negative overlap would silently skip words.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if not 0 <= overlap < max_tokens:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_tokens, "
            f"got overlap={overlap}, max_tokens={max_tokens}"
        )

    words = clean_text(text).split()
    if not words:
        return []

    # Each new window starts `step` words after the previous one; step >= 1
    # is guaranteed by the validation above.
    step = max_tokens - overlap
    chunks = []
    start = 0
    while True:
        end = min(len(words), start + max_tokens)
        chunks.append(" ".join(words[start:end]))
        if end == len(words):
            break
        start += step
    return chunks
|
|
class RagIndex:
    """A FAISS inner-product index bundled with the text chunks it indexes.

    When built by build_faiss_index, vector i in ``index`` is the embedding of
    ``chunks[i]``, so search hits map back to text by position.
    """

    def __init__(self, index: faiss.IndexFlatIP, dim: int, chunks: List[str]):
        # index: the FAISS index; dim: embedding width; chunks: source texts.
        self.index, self.dim, self.chunks = index, dim, chunks
|
|
def build_faiss_index(chunks: List[str]) -> RagIndex:
    """Embed text chunks and store them in an exact inner-product FAISS index.

    Embeddings are L2-normalised, so inner-product scores are cosine
    similarities. Vector i in the index is the embedding of ``chunks[i]``.

    Args:
        chunks: Texts to index. May be empty, in which case an empty (but
            searchable) index sized to the model's embedding width is returned.

    Returns:
        A RagIndex wrapping the FAISS index, its dimension and ``chunks``.
    """
    emb = _get_embedder()
    if not chunks:
        # An empty batch yields no embedding axis to read the width from
        # (vectors.shape[1] would fail), so ask the model for it instead.
        dim = emb.get_sentence_embedding_dimension()
        return RagIndex(index=faiss.IndexFlatIP(dim), dim=dim, chunks=chunks)

    vectors = emb.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return RagIndex(index=index, dim=dim, chunks=chunks)
|
|
def search(index: RagIndex, query: str, top_k: int = 6) -> List[Tuple[str, float]]:
    """Return up to ``top_k`` (chunk, score) pairs most similar to ``query``.

    The query is embedded with the shared model (L2-normalised), so for an
    index built by build_faiss_index the score is cosine similarity. Pairs
    are returned in FAISS's result order.

    Returns:
        A list of (chunk_text, score) tuples; [] when the index is empty or
        ``top_k`` is not positive.
    """
    # Never ask FAISS for more neighbours than exist, and never for k <= 0,
    # which FAISS rejects.
    k = min(top_k, index.index.ntotal)
    if k <= 0:
        return []

    emb = _get_embedder()
    q = emb.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    scores, ids = index.index.search(q, k)
    # FAISS pads missing results with id -1; skip those slots.
    return [
        (index.chunks[i], float(score))
        for score, i in zip(scores[0], ids[0])
        if i != -1
    ]
|
|
def build_context_from_files(files: List[UploadFile], prompt: str, top_k: int = 6) -> str:
    """Build a retrieval context for ``prompt`` from the text of uploaded files.

    Each file's text is extracted and chunked on its own, so no chunk mixes
    content from two different documents. All chunks are embedded into one
    FAISS index, and the ``top_k`` chunks most similar to ``prompt`` are
    rendered as sections of the form::

        [DOC#<rank> score=<similarity>]
        <chunk text>

    The sections are separated by blank lines.

    Each upload's stream is rewound before it is read and again afterwards
    (even if extraction raises), so callers can re-read the files.

    Returns:
        The joined sections, or "" if no text could be extracted.
    """
    chunks: List[str] = []
    for file in files:
        file.file.seek(0)
        try:
            text = extract_text(file)
        finally:
            file.file.seek(0)
        chunks.extend(chunk_text(text, max_tokens=450, overlap=80))

    if not chunks:
        return ""

    idx = build_faiss_index(chunks)
    hits = search(idx, prompt, top_k=top_k)
    return "\n\n".join(
        f"[DOC#{rank} score={score:.3f}]\n{chunk}"
        for rank, (chunk, score) in enumerate(hits, 1)
    )
|
|
| |
def web_search(query: str, max_results: int = 5) -> str:
    """Run a DuckDuckGo text search and format the hits as labelled sections.

    Each hit is rendered as::

        [WEB#<n>] Title: <title>
        Snippet: <body>
        URL: <href>

    The sections are separated by blank lines. A field missing from a hit is
    rendered as an empty string instead of failing the whole search.

    Args:
        query: Search query text.
        max_results: Maximum number of hits to request (defaults to 5).

    Returns:
        The formatted sections, "No results found." when there are no hits, or
        "Web search error." if the search call itself raised (the exception is
        logged with its traceback).
    """
    try:
        with DDGS() as ddgs:
            # list() also materialises older duckduckgo_search versions that yield lazily.
            results = list(ddgs.text(query, max_results=max_results) or [])
    except Exception as e:
        logger.exception("Web search failed: %s", e)
        return "Web search error."

    sections = [
        f"[WEB#{i}] Title: {r.get('title', '')}\n"
        f"Snippet: {r.get('body', '')}\n"
        f"URL: {r.get('href', '')}"
        for i, r in enumerate(results, 1)
    ]
    return "\n\n".join(sections) if sections else "No results found."