# utils.py
"""Utilities for a small RAG pipeline.

Covers: document text extraction (PDF/DOCX), character chunking with overlap,
sentence-transformer embeddings, a seq2seq generator, and a vector-store
abstraction backed by Qdrant (preferred) or an in-memory FAISS fallback.
"""

import os
import re
from io import BytesIO
from typing import List, Tuple, Dict, Optional

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import numpy as np
from pypdf import PdfReader
import docx
from tqdm.auto import tqdm

# Vector store compatibility imports
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance
import faiss
import uuid
import pickle


# -------------------------
# Document parsing
# -------------------------
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Extract text from all pages of a PDF given as raw bytes.

    Pages that fail to extract contribute an empty string instead of
    aborting the whole document (best-effort extraction).
    """
    reader = PdfReader(BytesIO(file_bytes))
    texts = []
    for page in reader.pages:
        try:
            texts.append(page.extract_text() or "")
        except Exception:
            # Keep page alignment: a broken page becomes an empty entry.
            texts.append("")
    return "\n".join(texts)


def extract_text_from_docx(file_bytes: bytes) -> str:
    """Extract paragraph text from a DOCX file given as raw bytes."""
    doc = docx.Document(BytesIO(file_bytes))
    return "\n".join(p.text for p in doc.paragraphs)


def extract_text(filename: str, bytestr: bytes) -> str:
    """Dispatch extraction based on the filename extension.

    Raises:
        ValueError: for unsupported extensions.
    """
    ext = filename.lower().split('.')[-1]
    if ext == "pdf":
        return extract_text_from_pdf(bytestr)
    elif ext in ("docx", "doc"):
        # NOTE(review): python-docx cannot read legacy binary .doc files;
        # .doc inputs will likely raise inside docx.Document — confirm callers.
        return extract_text_from_docx(bytestr)
    else:
        raise ValueError(f"Unsupported file type: {ext}")


# -------------------------
# Chunking (simple char-based chunks with overlap)
# -------------------------
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split text into character chunks of ``chunk_size`` with ``overlap``.

    Empty/whitespace-only chunks are dropped (they would otherwise be
    embedded as garbage).

    Raises:
        ValueError: if overlap >= chunk_size (the window would never
            advance, causing an infinite loop) or chunk_size <= 0.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    if not text:
        return []
    text = re.sub(r'\n\s*\n', '\n', text)  # collapse multiple blank lines
    chunks = []
    start = 0
    L = len(text)
    while start < L:
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= L:
            break  # final (possibly short) chunk consumed
        start = end - overlap
    return chunks


# -------------------------
# Embeddings (SentenceTransformer)
# -------------------------
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
_embed_model = None


def load_embedding_model():
    """Lazily load and cache the sentence-transformer embedding model."""
    global _embed_model
    if _embed_model is None:
        _embed_model = SentenceTransformer(EMBED_MODEL_NAME)
    return _embed_model


def embed_texts(texts: List[str]) -> np.ndarray:
    """Embed a list of texts; returns an (n, dim) float array."""
    model = load_embedding_model()
    return model.encode(texts, show_progress_bar=False, convert_to_numpy=True)


# -------------------------
# Generator model (RAG prompt -> generate answer)
# -------------------------
# Use a lightweight seq2seq model that runs reasonably on CPU for small questions.
GEN_MODEL_NAME = os.environ.get("GEN_MODEL", "google/flan-t5-small")
_gen_pipeline = None


def load_generator():
    """Lazily load and cache the text2text-generation pipeline (CPU)."""
    global _gen_pipeline
    if _gen_pipeline is None:
        _gen_pipeline = pipeline(
            "text2text-generation",
            model=GEN_MODEL_NAME,
            tokenizer=GEN_MODEL_NAME,
            device=-1,  # force CPU
        )
    return _gen_pipeline


def generate_answer(prompt: str, max_length: int = 256) -> str:
    """Generate a deterministic (greedy) answer for a RAG prompt."""
    gen = load_generator()
    out = gen(prompt, max_length=max_length, do_sample=False)
    return out[0]["generated_text"]


# -------------------------
# Vector store wrapper: Qdrant (preferred) or FAISS (fallback)
# -------------------------
class VectorStore:
    """Minimal vector-store interface implemented by Qdrant/FAISS backends."""

    def add(self, ids: List[str], embeddings: np.ndarray,
            metadatas: List[dict], texts: List[str]):
        raise NotImplementedError()

    def query(self, embedding: np.ndarray, top_k: int = 5) -> List[Tuple[str, float, str, dict]]:
        """Return list of (id, score, text, metadata)"""
        raise NotImplementedError()

    def persist(self, path: str):
        """Optionally save the store to disk; no-op by default."""
        pass


# Qdrant store
class QdrantStore(VectorStore):
    def __init__(self, collection_name="docs", host=None, port=None,
                 prefer_grpc=False, embed_dim: int = 384):
        """Connect to Qdrant and ensure the target collection exists.

        QDRANT_URL (env) takes precedence over ``host``. ``embed_dim``
        defaults to 384 (MiniLM); pass the dimension of your embedder.

        Raises:
            ValueError: if no Qdrant URL/host is configured.
        """
        # host expected like "http://localhost:6333" or host + port
        q_host = os.environ.get("QDRANT_URL") or host
        api_key = os.environ.get("QDRANT_API_KEY")
        if q_host:
            # if full url provided, QdrantClient accepts url param
            if q_host.startswith("http"):
                self.client = QdrantClient(url=q_host, api_key=api_key)
            else:
                # assume host & port separated
                self.client = QdrantClient(host=q_host, port=port or 6333, api_key=api_key)
        else:
            raise ValueError("Qdrant URL not provided for QdrantStore")
        self.collection_name = collection_name
        # Ensure the collection exists WITHOUT destroying existing data.
        # (recreate_collection would drop any previously indexed vectors.)
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            try:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=embed_dim, distance=Distance.COSINE),
                )
            except Exception:
                # best-effort: a concurrent creator may have won the race
                pass

    def add(self, ids: List[str], embeddings: np.ndarray,
            metadatas: List[dict], texts: List[str]):
        """Upsert points; payload carries the chunk text and its metadata."""
        points = []
        for i, uid in enumerate(ids):
            points.append({
                "id": uid,
                "vector": embeddings[i].tolist(),
                "payload": {"meta": metadatas[i], "text": texts[i]},
            })
        self.client.upsert(collection_name=self.collection_name, points=points)

    def query(self, embedding: np.ndarray, top_k: int = 5):
        """Return top_k (id, score, text, metadata) tuples by cosine similarity."""
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding.tolist(),
            limit=top_k,
        )
        results = []
        for h in hits:
            metadata = h.payload.get("meta", {})
            text = h.payload.get("text", "")
            results.append((str(h.id), float(h.score), text, metadata))
        return results


# FAISS fallback (in-memory)
class FAISSStore(VectorStore):
    def __init__(self, dim: int = 384):
        self.dim = dim
        # Inner product over L2-normalized vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(dim)
        self.texts: List[str] = []
        self.metadatas: List[dict] = []
        self.ids: List[str] = []

    def add(self, ids: List[str], embeddings: np.ndarray,
            metadatas: List[dict], texts: List[str]):
        """Add normalized embeddings and parallel id/text/metadata lists."""
        # normalize embeddings for cosine via inner product
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid division by zero for all-zero rows
        emb_norm = embeddings / norms
        self.index.add(emb_norm.astype('float32'))
        self.texts.extend(texts)
        self.metadatas.extend(metadatas)
        self.ids.extend(ids)

    def query(self, embedding: np.ndarray, top_k: int = 5):
        """Return top_k (id, score, text, metadata) tuples by cosine similarity."""
        emb = embedding.reshape(1, -1)
        norm = np.linalg.norm(emb)
        if norm == 0:
            norm = 1.0
        emb = emb / norm
        D, I = self.index.search(emb.astype('float32'), k=top_k)
        results = []
        for score, idx in zip(D[0], I[0]):
            # FAISS pads missing neighbours with -1; also guard stale indices.
            if idx < 0 or idx >= len(self.texts):
                continue
            results.append((self.ids[idx], float(score), self.texts[idx], self.metadatas[idx]))
        return results

    def persist(self, path: str):
        """Save the index and its parallel metadata lists to ``path``."""
        payload = {
            "dim": self.dim,
            "index": faiss.serialize_index(self.index),
            "texts": self.texts,
            "metadatas": self.metadatas,
            "ids": self.ids,
        }
        with open(path, "wb") as f:
            pickle.dump(payload, f)

    @classmethod
    def load(cls, path: str) -> "FAISSStore":
        """Reload a store previously written by :meth:`persist`.

        WARNING: uses pickle — only load files you wrote yourself.
        """
        with open(path, "rb") as f:
            payload = pickle.load(f)
        store = cls(dim=payload["dim"])
        store.index = faiss.deserialize_index(payload["index"])
        store.texts = payload["texts"]
        store.metadatas = payload["metadatas"]
        store.ids = payload["ids"]
        return store


# Utility to create appropriate store
def get_vector_store(prefer_qdrant=True, qdrant_collection="docs", embed_dim=384):
    """Return a QdrantStore when QDRANT_URL is set (and preferred), else FAISS."""
    qdrant_url = os.environ.get("QDRANT_URL")
    if prefer_qdrant and qdrant_url:
        try:
            # Forward embed_dim so the Qdrant collection matches the embedder.
            return QdrantStore(collection_name=qdrant_collection, embed_dim=embed_dim)
        except Exception as e:
            print("Qdrant connection failed; falling back to FAISS. Error:", e)
    # fallback
    return FAISSStore(dim=embed_dim)


# -------------------------
# Building knowledge base: takes document text, chunks, embeds, and stores; returns ids
# -------------------------
def build_doc_store(text: str, store: VectorStore, chunk_size=1000, overlap=200,
                    source_name="uploaded_doc"):
    """Chunk ``text``, embed the chunks, add them to ``store``.

    Returns a list of {"id", "text", "metadata"} dicts, one per stored chunk
    (empty list when the text produced no chunks).
    """
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=overlap)
    if not chunks:
        return []
    embeddings = embed_texts(chunks)
    ids = [str(uuid.uuid4()) for _ in chunks]
    metadatas = [{"source": source_name, "chunk_index": i} for i in range(len(chunks))]
    store.add(ids=ids, embeddings=embeddings, metadatas=metadatas, texts=chunks)
    return [{"id": _id, "text": t, "metadata": m} for _id, t, m in zip(ids, chunks, metadatas)]