QA_Retrieval_Bot / utils.py
MansoorSarookh's picture
Create utils.py
211efc1 verified
# utils.py
import os
import re
from io import BytesIO
from typing import List, Tuple, Dict, Optional
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import numpy as np
from pypdf import PdfReader
import docx
from tqdm.auto import tqdm
# Vector store compatibility imports
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance
import faiss
import uuid
import pickle
# -------------------------
# Document parsing
# -------------------------
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Extract the text layer of every page of a PDF supplied as raw bytes.

    Each page contributes one newline-separated segment. A page whose
    ``extract_text()`` raises, or returns nothing, contributes an empty string,
    so one bad page never aborts the whole document.
    """
    def _page_text(page) -> str:
        # Best-effort per page: swallow extraction errors and use "" instead.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    pdf = PdfReader(BytesIO(file_bytes))
    return "\n".join(_page_text(page) for page in pdf.pages)
def extract_text_from_docx(file_bytes: bytes) -> str:
    """Return the paragraph text of a .docx document supplied as raw bytes.

    Only ``document.paragraphs`` is read; the paragraphs are joined with
    newlines in document order.
    """
    document = docx.Document(BytesIO(file_bytes))
    return "\n".join(paragraph.text for paragraph in document.paragraphs)
def extract_text(filename: str, bytestr: bytes) -> str:
    """Dispatch raw document bytes to the matching text extractor by extension.

    Args:
        filename: Original file name; only its (case-insensitive) extension is used.
        bytestr: The file's raw contents.

    Returns:
        The extracted plain text.

    Raises:
        ValueError: If the extension is not ``pdf`` or ``docx``. Legacy ``.doc``
            files are rejected with an explicit message because python-docx
            (used by ``extract_text_from_docx``) only reads ``.docx``.
    """
    # splitext only inspects the final path component, so "README" or
    # "dir.v2/file" correctly have no extension (split('.') gave "readme"
    # and "v2/file").
    ext = os.path.splitext(filename)[1][1:].lower()
    if ext == "pdf":
        return extract_text_from_pdf(bytestr)
    if ext == "docx":
        return extract_text_from_docx(bytestr)
    if ext == "doc":
        raise ValueError(
            "Unsupported file type: doc (legacy binary Word format; "
            "please convert it to .docx)"
        )
    raise ValueError(f"Unsupported file type: {ext or '(no extension)'}")
# -------------------------
# Chunking (simple char-based chunks with overlap)
# -------------------------
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split ``text`` into overlapping character-based chunks.

    Runs of blank lines are first collapsed to a single newline. Windows of
    ``chunk_size`` characters are then taken every ``chunk_size - overlap``
    characters until the end of the text is covered. Each chunk is stripped of
    surrounding whitespace, and whitespace-only chunks are dropped.

    Args:
        text: The document text. Empty text yields ``[]``.
        chunk_size: Maximum characters per chunk (before stripping).
        overlap: Characters shared by consecutive windows.

    Returns:
        The non-empty chunks in document order.

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap >= chunk_size``. Either
            would stop the window from advancing (previously an infinite loop).
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    text = re.sub(r'\n\s*\n', '\n', text)  # collapse multiple blank lines
    step = chunk_size - overlap
    length = len(text)
    chunks = []
    for start in range(0, length, step):
        chunk = text[start:start + chunk_size].strip()
        if chunk:
            chunks.append(chunk)
        # Once a window reaches the end of the text, any further window would
        # only repeat the tail this chunk already covers.
        if start + chunk_size >= length:
            break
    return chunks
# -------------------------
# Embeddings (SentenceTransformer)
# -------------------------
# Embedding model id; override it with the EMBED_MODEL environment variable.
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
_embed_model = None  # cached SentenceTransformer, filled in lazily by load_embedding_model()
def load_embedding_model():
    """Return the module-wide SentenceTransformer, constructing it on first call.

    The instance is cached in ``_embed_model``, so later calls reuse it.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    _embed_model = SentenceTransformer(EMBED_MODEL_NAME)
    return _embed_model
def embed_texts(texts: List[str]) -> np.ndarray:
    """Encode ``texts`` with the shared embedding model as a NumPy array.

    The progress bar is disabled. NOTE(review): for a list input the result is
    presumably shaped (len(texts), embedding_dim); confirm for the configured model.
    """
    return load_embedding_model().encode(
        texts,
        convert_to_numpy=True,
        show_progress_bar=False,
    )
# -------------------------
# Generator model (RAG prompt -> generate answer)
# -------------------------
# Use a lightweight seq2seq model that runs reasonably on CPU for small questions.
# Generator model id; override it with the GEN_MODEL environment variable.
GEN_MODEL_NAME = os.getenv("GEN_MODEL", "google/flan-t5-small")
_gen_pipeline = None  # cached transformers pipeline, built lazily by load_generator()
def load_generator():
    """Return the shared text2text-generation pipeline, building it on first call.

    The model and tokenizer are both loaded from ``GEN_MODEL_NAME``, and
    ``device=-1`` keeps inference on the CPU. The pipeline is cached in
    ``_gen_pipeline``.
    """
    global _gen_pipeline
    if _gen_pipeline is not None:
        return _gen_pipeline
    # "text2text-generation" is the pipeline task for encoder-decoder (seq2seq) models.
    _gen_pipeline = pipeline(
        "text2text-generation",
        model=GEN_MODEL_NAME,
        tokenizer=GEN_MODEL_NAME,
        device=-1,
    )
    return _gen_pipeline
def generate_answer(prompt: str, max_length: int = 256) -> str:
    """Generate text for ``prompt`` using greedy (non-sampled) decoding.

    Args:
        prompt: The assembled RAG prompt to feed the generator.
        max_length: Forwarded unchanged to the pipeline call.

    Returns:
        The ``generated_text`` of the first pipeline output.
    """
    outputs = load_generator()(prompt, max_length=max_length, do_sample=False)
    first = outputs[0]
    return first["generated_text"]
# -------------------------
# Vector store wrapper: Qdrant (preferred) or FAISS (fallback)
# -------------------------
class VectorStore:
    """Minimal interface shared by the Qdrant and FAISS backends below.

    Subclasses must implement ``add`` and ``query``. ``persist`` is an optional
    hook that does nothing by default.
    """

    def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
        """Store one entry per id: its vector, metadata dict and source text."""
        raise NotImplementedError

    def query(self, embedding: np.ndarray, top_k: int = 5) -> List[Tuple[str, float, str, dict]]:
        """Return up to ``top_k`` hits as (id, score, text, metadata) tuples."""
        raise NotImplementedError

    def persist(self, path: str):
        """Save the store to ``path``; the base implementation is a no-op."""
        return None
# Qdrant store
class QdrantStore(VectorStore):
    """VectorStore backed by a Qdrant collection.

    Each point's payload is ``{"meta": <metadata dict>, "text": <chunk text>}``.
    """

    def __init__(self, collection_name="docs", host=None, port=None, prefer_grpc=False, dim: int = 384):
        """Connect to Qdrant and (re)create ``collection_name`` as an empty collection.

        Args:
            collection_name: Collection to use. Any existing collection with this
                name is dropped and recreated, so every instance starts empty.
            host: Used when ``QDRANT_URL`` is unset. Either a full URL such as
                ``"http://localhost:6333"`` or a bare host name.
            port: Port used with a bare host name (default 6333).
            prefer_grpc: Forwarded to ``QdrantClient`` (previously ignored).
            dim: Vector size of the collection. It must match the embedding
                model: 384 for the default MiniLM model.

        Raises:
            ValueError: If neither ``QDRANT_URL`` nor ``host`` is provided.
            Exception: Whatever the client raises when the server is unreachable
                or rejects the request. It is deliberately propagated (it used
                to be swallowed) so ``get_vector_store`` can fall back to FAISS
                instead of returning a store that fails on first use.
        """
        q_host = os.environ.get("QDRANT_URL") or host
        api_key = os.environ.get("QDRANT_API_KEY")
        if not q_host:
            raise ValueError("Qdrant URL not provided for QdrantStore")
        if q_host.startswith("http"):
            # A full URL (scheme, host, optional port) goes to the `url` parameter.
            self.client = QdrantClient(url=q_host, api_key=api_key, prefer_grpc=prefer_grpc)
        else:
            # Bare host name: host and port are passed separately.
            self.client = QdrantClient(host=q_host, port=port or 6333, api_key=api_key, prefer_grpc=prefer_grpc)
        self.collection_name = collection_name
        self.dim = dim
        # recreate_collection deletes any existing collection of this name and
        # creates an empty one with the requested size/metric.
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
        )

    def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
        """Upsert one point per id with its vector and ``{"meta", "text"}`` payload.

        Raises:
            ValueError: If the four inputs differ in length.
        """
        if not (len(ids) == len(embeddings) == len(metadatas) == len(texts)):
            raise ValueError("ids, embeddings, metadatas and texts must have the same length")
        if not ids:
            return
        # PointStruct is the client's documented point type; raw dicts are not
        # accepted on every transport (e.g. gRPC).
        from qdrant_client.http.models import PointStruct
        points = [
            PointStruct(id=uid, vector=np.asarray(vec).tolist(), payload={"meta": meta, "text": text})
            for uid, vec, meta, text in zip(ids, embeddings, metadatas, texts)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def query(self, embedding: np.ndarray, top_k: int = 5):
        """Return up to ``top_k`` (id, score, text, metadata) hits for ``embedding``."""
        # Flatten so both a (dim,) vector and a (1, dim) row are accepted,
        # consistent with FAISSStore.query's reshape(1, -1).
        vector = np.asarray(embedding, dtype=float).reshape(-1).tolist()
        # NOTE(review): newer qdrant-client releases deprecate `search` in favour
        # of `query_points`; confirm against the pinned client version.
        hits = self.client.search(collection_name=self.collection_name, query_vector=vector, limit=top_k)
        results = []
        for hit in hits:
            payload = hit.payload or {}
            results.append((str(hit.id), float(hit.score), payload.get("text", ""), payload.get("meta", {})))
        return results
# FAISS fallback (in-memory)
class FAISSStore(VectorStore):
    """In-memory VectorStore built on a flat FAISS inner-product index.

    Vectors are scaled to unit L2 norm on insert and on query, so the
    inner-product scores are cosine similarities. ``ids``, ``texts`` and
    ``metadatas`` are Python lists whose positions line up with the FAISS row
    numbers.
    """

    def __init__(self, dim: int = 384):
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)  # inner product (inputs are normalized)
        self.texts = []
        self.metadatas = []
        self.ids = []

    @staticmethod
    def _normalize(vectors) -> np.ndarray:
        """Scale each row of a 2-D array to unit L2 norm (zero rows untouched); return float32."""
        vectors = np.asarray(vectors)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return (vectors / norms).astype('float32')

    def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
        """Append 2-D ``embeddings`` with their parallel ids, metadata and texts.

        Raises:
            ValueError: If the four inputs differ in length, or the vectors do not
                have ``self.dim`` columns. Either mismatch would otherwise
                desynchronise the Python lists from the FAISS row numbers.
        """
        if not (len(ids) == len(embeddings) == len(metadatas) == len(texts)):
            raise ValueError("ids, embeddings, metadatas and texts must have the same length")
        if not ids:
            return
        vectors = self._normalize(embeddings)
        if vectors.ndim != 2 or vectors.shape[1] != self.dim:
            raise ValueError(f"expected embeddings of shape (n, {self.dim}), got {vectors.shape}")
        self.index.add(vectors)
        self.texts.extend(texts)
        self.metadatas.extend(metadatas)
        self.ids.extend(ids)

    def query(self, embedding: np.ndarray, top_k: int = 5):
        """Return up to ``top_k`` (id, cosine score, text, metadata) hits for ``embedding``."""
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        query_vec = self._normalize(np.asarray(embedding).reshape(1, -1))
        scores, rows = self.index.search(query_vec, k=top_k)
        results = []
        for score, row in zip(scores[0], rows[0]):
            # FAISS reports -1 for result slots it could not fill.
            if row < 0 or row >= len(self.texts):
                continue
            results.append((self.ids[row], float(score), self.texts[row], self.metadatas[row]))
        return results
# Utility to create appropriate store
def get_vector_store(prefer_qdrant=True, qdrant_collection="docs", embed_dim=384):
    """Return a QdrantStore when possible, otherwise an in-memory FAISSStore.

    Qdrant is attempted only when ``prefer_qdrant`` is true and ``QDRANT_URL``
    is set. Any exception raised while constructing it is printed, and FAISS is
    used instead. ``embed_dim`` is only passed to the FAISS fallback; QdrantStore
    is constructed with just the collection name, so it uses its own default
    vector size.
    """
    if not (prefer_qdrant and os.environ.get("QDRANT_URL")):
        return FAISSStore(dim=embed_dim)
    try:
        return QdrantStore(collection_name=qdrant_collection)
    except Exception as err:
        print("Qdrant connection failed; falling back to FAISS. Error:", err)
    return FAISSStore(dim=embed_dim)
# -------------------------
# Building the knowledge base: chunk the document text, embed the chunks, and add them to the store; returns one {id, text, metadata} record per chunk
# -------------------------
def build_doc_store(text: str, store: VectorStore, chunk_size=1000, overlap=200, source_name="uploaded_doc"):
    """Chunk ``text``, embed the chunks, and add them to ``store``.

    Args:
        text: Full document text.
        store: Any VectorStore; its ``add`` receives ids, embeddings, metadata and texts.
        chunk_size: Passed through to ``chunk_text``.
        overlap: Passed through to ``chunk_text``.
        source_name: Stored as ``"source"`` in every chunk's metadata.

    Returns:
        One ``{"id", "text", "metadata"}`` dict per chunk, in chunk order. Ids are
        random UUID4 strings and metadata is ``{"source", "chunk_index"}``.
        Returns ``[]`` without touching ``store`` when no chunks are produced.
    """
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=overlap)
    if not chunks:
        return []
    vectors = embed_texts(chunks)
    records = []
    for position, chunk in enumerate(chunks):
        records.append({
            "id": str(uuid.uuid4()),
            "text": chunk,
            "metadata": {"source": source_name, "chunk_index": position},
        })
    store.add(
        ids=[record["id"] for record in records],
        embeddings=vectors,
        metadatas=[record["metadata"] for record in records],
        texts=chunks,
    )
    return records