MansoorSarookh committed on
Commit
211efc1
·
verified ·
1 Parent(s): 87ecdf6

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +211 -0
utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ import os
3
+ import re
4
+ from io import BytesIO
5
+ from typing import List, Tuple, Dict, Optional
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
+ import numpy as np
9
+ from pypdf import PdfReader
10
+ import docx
11
+ from tqdm.auto import tqdm
12
+
13
+ # Vector store compatibility imports
14
+ from qdrant_client import QdrantClient
15
+ from qdrant_client.http.models import VectorParams, Distance
16
+ import faiss
17
+ import uuid
18
+ import pickle
19
+
20
+ # -------------------------
21
+ # Document parsing
22
+ # -------------------------
23
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Return the text of every page of a PDF, joined with newlines.

    A page whose extraction raises, or yields nothing, contributes an empty
    string, so the number of joined segments equals the number of pages.
    """
    def _page_text(page) -> str:
        try:
            return page.extract_text() or ""
        except Exception:
            # Best effort: one unreadable page must not abort the document.
            return ""

    pdf = PdfReader(BytesIO(file_bytes))
    return "\n".join(_page_text(page) for page in pdf.pages)
32
+
33
def extract_text_from_docx(file_bytes: bytes) -> str:
    """Return the paragraph texts of a Word document, one per line."""
    document = docx.Document(BytesIO(file_bytes))
    return "\n".join(paragraph.text for paragraph in document.paragraphs)
38
+
39
def extract_text(filename: str, bytestr: bytes) -> str:
    """Extract plain text from an uploaded file, dispatching on its extension.

    Args:
        filename: Name of the file. Only the part after the last dot is
            inspected, case-insensitively.
        bytestr: Raw content of the file.

    Returns:
        The text produced by the PDF or DOCX extractor.

    Raises:
        ValueError: If the name has no extension, or the extension is not
            one of ``pdf``, ``docx`` or ``doc``.
    """
    # rpartition reports whether a dot was present at all, so a file that is
    # literally named "pdf" or "docx" is not mistaken for that file type.
    _, dot, suffix = filename.rpartition(".")
    ext = suffix.lower() if dot else ""

    if ext == "pdf":
        return extract_text_from_pdf(bytestr)
    # NOTE(review): "doc" is still routed to the DOCX parser; a legacy binary
    # .doc file is presumably not readable by it -- confirm before relying on it.
    if ext in ("docx", "doc"):
        return extract_text_from_docx(bytestr)
    # Without an extension, report the lower-cased name (same text as before).
    raise ValueError(f"Unsupported file type: {ext or filename.lower()}")
47
+
48
+ # -------------------------
49
+ # Chunking (simple char-based chunks with overlap)
50
+ # -------------------------
51
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split text into character windows that overlap their neighbours.

    Runs of blank lines are collapsed to a single newline first. Each window
    is whitespace-stripped, and windows that end up empty are dropped.

    Args:
        text: Text to split. A falsy value yields ``[]``.
        chunk_size: Maximum number of characters per window; must be > 0.
        overlap: Number of characters shared by consecutive windows; must be
            smaller than ``chunk_size`` so the window always moves forward.

    Returns:
        The chunks in document order.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is not
            smaller than ``chunk_size`` (the loop could never advance).
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    text = re.sub(r'\n\s*\n', '\n', text)  # collapse multiple blank lines
    step = chunk_size - overlap
    total = len(text)
    chunks = []
    start = 0
    while start < total:
        end = start + chunk_size
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= total:
            # This window reached the end of the text; another one would only
            # repeat the tail that is already inside this chunk.
            break
        start += step
    return chunks
66
+
67
+ # -------------------------
68
+ # Embeddings (SentenceTransformer)
69
+ # -------------------------
70
# Checkpoint name for the embedding model; the EMBED_MODEL environment
# variable overrides the default.
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Module-level cache, filled by the first load_embedding_model() call.
_embed_model = None


def load_embedding_model():
    """Return the shared embedding model, constructing it on first use."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    _embed_model = SentenceTransformer(EMBED_MODEL_NAME)
    return _embed_model
78
+
79
def embed_texts(texts: List[str]) -> np.ndarray:
    """Encode the given strings with the shared embedding model.

    The model is asked for NumPy output with its progress bar disabled; the
    result of ``encode`` is returned unchanged.
    """
    return load_embedding_model().encode(
        texts,
        show_progress_bar=False,
        convert_to_numpy=True,
    )
83
+
84
+ # -------------------------
85
+ # Generator model (RAG prompt -> generate answer)
86
+ # -------------------------
87
+ # Use a lightweight seq2seq model that runs reasonably on CPU for small questions.
88
# Checkpoint name for the generator; the GEN_MODEL environment variable
# overrides the default.
GEN_MODEL_NAME = os.environ.get("GEN_MODEL", "google/flan-t5-small")
# Module-level cache, filled by the first load_generator() call.
_gen_pipeline = None


def load_generator():
    """Return the shared text2text-generation pipeline, building it once."""
    global _gen_pipeline
    if _gen_pipeline is not None:
        return _gen_pipeline
    # The same checkpoint name supplies both the model and the tokenizer.
    _gen_pipeline = pipeline(
        "text2text-generation",
        model=GEN_MODEL_NAME,
        tokenizer=GEN_MODEL_NAME,
        device=-1,
    )
    return _gen_pipeline
97
+
98
def generate_answer(prompt: str, max_length: int = 256) -> str:
    """Run the prompt through the generator and return its first output text.

    Sampling is disabled (``do_sample=False``) and ``max_length`` is passed
    straight through to the pipeline call.
    """
    outputs = load_generator()(prompt, max_length=max_length, do_sample=False)
    first = outputs[0]
    return first["generated_text"]
102
+
103
+ # -------------------------
104
+ # Vector store wrapper: Qdrant (preferred) or FAISS (fallback)
105
+ # -------------------------
106
class VectorStore:
    """Interface shared by the vector-store backends in this module."""

    def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
        """Store vectors together with their ids, metadata and source texts."""
        raise NotImplementedError()

    def query(self, embedding: np.ndarray, top_k: int = 5) -> List[Tuple[str, float, str, dict]]:
        """Return list of (id, score, text, metadata)"""
        raise NotImplementedError()

    def persist(self, path: str):
        """Optional hook; the base implementation does nothing."""
        return None
114
+
115
+ # Qdrant store
116
class QdrantStore(VectorStore):
    """Vector store backed by a Qdrant collection.

    The server address is taken from the ``QDRANT_URL`` environment variable
    when it is set, otherwise from ``host``. ``QDRANT_API_KEY`` supplies the
    API key when present.
    """

    def __init__(self, collection_name="docs", host=None, port=None, prefer_grpc=False, dim: int = 384):
        """Connect to Qdrant and make sure the collection exists.

        Args:
            collection_name: Collection that holds the document chunks.
            host: Full URL (``http...``) or bare host name; used only when
                ``QDRANT_URL`` is not set.
            port: Port for a bare host name; defaults to 6333.
            prefer_grpc: Forwarded to ``QdrantClient``.
            dim: Vector size used if the collection has to be created
                (384 for MiniLM; adjust if using a different embed dim).

        Raises:
            ValueError: If neither ``QDRANT_URL`` nor ``host`` is provided.
            Exception: Whatever the client raises when the collection can
                neither be found nor created (for example, server down).
        """
        q_host = os.environ.get("QDRANT_URL") or host
        if not q_host:
            raise ValueError("Qdrant URL not provided for QdrantStore")
        api_key = os.environ.get("QDRANT_API_KEY")
        if q_host.startswith("http"):
            # A full URL goes through the client's ``url`` parameter.
            self.client = QdrantClient(url=q_host, api_key=api_key, prefer_grpc=prefer_grpc)
        else:
            # Otherwise treat it as a host name with a separate port.
            self.client = QdrantClient(
                host=q_host, port=port or 6333, api_key=api_key, prefer_grpc=prefer_grpc
            )
        self.collection_name = collection_name
        self.dim = dim
        self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection only when it cannot be fetched.

        The previous ``recreate_collection`` call dropped the stored vectors
        on every construction. Errors from ``create_collection`` are left to
        propagate so callers can fall back to another store.
        """
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            # Missing collection (or unreachable server): creating it either
            # succeeds or raises the real error.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
            )

    def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
        """Upsert one point per id, carrying its vector, metadata and text.

        Raises:
            ValueError: If the four inputs do not have the same length.
        """
        # Local import: the module header only imports VectorParams/Distance.
        from qdrant_client.http.models import PointStruct

        if not (len(ids) == len(embeddings) == len(metadatas) == len(texts)):
            raise ValueError("ids, embeddings, metadatas and texts must have the same length")
        points = [
            PointStruct(
                id=uid,
                vector=np.asarray(vector).tolist(),
                payload={"meta": meta, "text": text},
            )
            for uid, vector, meta, text in zip(ids, embeddings, metadatas, texts)
        ]
        if not points:
            return
        self.client.upsert(collection_name=self.collection_name, points=points)

    def query(self, embedding: np.ndarray, top_k: int = 5):
        """Return list of (id, score, text, metadata) for the nearest points.

        ``embedding`` may be a flat vector or a ``(1, dim)`` array; it is
        flattened before being sent to the server.
        """
        vector = np.asarray(embedding).reshape(-1).tolist()
        hits = self.client.search(
            collection_name=self.collection_name, query_vector=vector, limit=top_k
        )
        results = []
        for hit in hits:
            # A point stored without payload comes back with payload=None.
            payload = hit.payload or {}
            results.append(
                (str(hit.id), float(hit.score), payload.get("text", ""), payload.get("meta", {}))
            )
        return results
155
+
156
+ # FAISS fallback (in-memory)
157
+ class FAISSStore(VectorStore):
158
+ def __init__(self, dim: int = 384):
159
+ self.dim = dim
160
+ self.index = faiss.IndexFlatIP(dim) # inner product (we will normalize)
161
+ self.texts = []
162
+ self.metadatas = []
163
+ self.ids = []
164
+
165
+ def add(self, ids: List[str], embeddings: np.ndarray, metadatas: List[dict], texts: List[str]):
166
+ # normalize embeddings for cosine via inner product
167
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
168
+ norms[norms==0] = 1.0
169
+ emb_norm = embeddings / norms
170
+ self.index.add(emb_norm.astype('float32'))
171
+ self.texts.extend(texts)
172
+ self.metadatas.extend(metadatas)
173
+ self.ids.extend(ids)
174
+
175
+ def query(self, embedding: np.ndarray, top_k: int = 5):
176
+ emb = embedding.reshape(1, -1)
177
+ norm = np.linalg.norm(emb)
178
+ if norm == 0:
179
+ norm = 1.0
180
+ emb = emb / norm
181
+ D, I = self.index.search(emb.astype('float32'), k=top_k)
182
+ results = []
183
+ for score, idx in zip(D[0], I[0]):
184
+ if idx < 0 or idx >= len(self.texts):
185
+ continue
186
+ results.append((self.ids[idx], float(score), self.texts[idx], self.metadatas[idx]))
187
+ return results
188
+
189
+ # Utility to create appropriate store
190
def get_vector_store(prefer_qdrant=True, qdrant_collection="docs", embed_dim=384):
    """Pick a backend: Qdrant when requested and configured, otherwise FAISS.

    Qdrant is attempted only if ``prefer_qdrant`` is truthy and the
    ``QDRANT_URL`` environment variable is set. If constructing it raises,
    the error is printed and an in-memory FAISS store is returned instead.
    """
    if not (prefer_qdrant and os.environ.get("QDRANT_URL")):
        return FAISSStore(dim=embed_dim)

    # NOTE(review): embed_dim is only applied to the FAISS store; the Qdrant
    # store is built with its own default vector size.
    try:
        return QdrantStore(collection_name=qdrant_collection)
    except Exception as e:
        print("Qdrant connection failed; falling back to FAISS. Error:", e)
    return FAISSStore(dim=embed_dim)
199
+
200
+ # -------------------------
201
+ # Building knowledge base: takes document text, chunks, embeds, and stores; returns ids
202
+ # -------------------------
203
def build_doc_store(text: str, store: VectorStore, chunk_size=1000, overlap=200, source_name="uploaded_doc"):
    """Chunk a document, embed the chunks and add them to ``store``.

    Every chunk gets a fresh UUID4 string id and metadata holding
    ``source_name`` and its position in the chunk list.

    Returns:
        One ``{"id", "text", "metadata"}`` dict per chunk, in chunk order;
        ``[]`` when chunking yields nothing (the store is then not touched).
    """
    pieces = chunk_text(text, chunk_size=chunk_size, overlap=overlap)
    if not pieces:
        return []

    vectors = embed_texts(pieces)
    chunk_ids = []
    chunk_meta = []
    for position in range(len(pieces)):
        chunk_ids.append(str(uuid.uuid4()))
        chunk_meta.append({"source": source_name, "chunk_index": position})

    store.add(ids=chunk_ids, embeddings=vectors, metadatas=chunk_meta, texts=pieces)

    records = []
    for chunk_id, piece, meta in zip(chunk_ids, pieces, chunk_meta):
        records.append({"id": chunk_id, "text": piece, "metadata": meta})
    return records