sofzcc committed on
Commit
28c97dd
·
verified ·
1 Parent(s): 3f19cda

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import yaml
4
+ from typing import List, Tuple
5
+
6
+ import faiss
7
+ import numpy as np
8
+ import gradio as gr
9
+ from sentence_transformers import SentenceTransformer
10
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
11
+ from PyPDF2 import PdfReader
12
+ import docx
13
+
14
+
15
# -----------------------------
# CONFIG
# -----------------------------

# All tunables live in config.yaml next to this script. Missing keys raise
# KeyError at startup, which is intentional: fail fast on a bad config.
with open("config.yaml", "r", encoding="utf-8") as f:
    CONFIG = yaml.safe_load(f)

KB_DIR = CONFIG["kb"]["directory"]  # folder holding the source documents
INDEX_DIR = CONFIG["kb"]["index_directory"]  # where the FAISS index is cached
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]  # sentence-transformers model id
QA_MODEL_NAME = CONFIG["models"]["qa"]  # extractive QA model id
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]  # characters per chunk
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]  # characters shared between neighboring chunks
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]  # min cosine score for a retrieval hit
WELCOME_MSG = CONFIG["messages"]["welcome"]  # shown as the chat description
NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]  # fallback when nothing matches
31
+
32
+
33
+ # -----------------------------
34
+ # UTILITIES
35
+ # -----------------------------
36
+
37
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into chunks of at most *chunk_size* characters.

    Consecutive chunks share roughly *overlap* characters of context.
    Whitespace-only chunks are dropped; empty input yields [].

    Args:
        text: The raw document text to split.
        chunk_size: Maximum characters per chunk.
        overlap: Characters of overlap between consecutive chunks.

    Returns:
        List of non-empty, stripped text chunks.
    """
    if not text:
        return []
    # BUG FIX: the original advanced `start` by `chunk_size - overlap`, which
    # is <= 0 whenever overlap >= chunk_size — an infinite loop. Clamp the
    # step to at least 1; behavior is unchanged for sane configurations.
    step = max(1, chunk_size - overlap)
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        start += step
    return chunks
49
+
50
+
51
def load_file_text(path: str) -> str:
    """Extract plain text from a .pdf, .docx/.doc, or plain-text file."""
    suffix = os.path.splitext(path)[1].lower()

    if suffix == ".pdf":
        pages = PdfReader(path).pages
        return "\n".join(page.extract_text() or "" for page in pages)

    if suffix in (".docx", ".doc"):
        paragraphs = docx.Document(path).paragraphs
        return "\n".join(paragraph.text for paragraph in paragraphs)

    # Anything else is read verbatim as UTF-8 text.
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()
62
+
63
+
64
def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Collect (filename, text) pairs for every supported file in *kb_dir*.

    Unreadable files are skipped with a console message; files whose
    extracted text is blank are ignored.
    """
    documents: List[Tuple[str, str]] = []
    if not os.path.isdir(kb_dir):
        return documents

    candidate_paths = []
    for pattern in ("*.txt", "*.md", "*.pdf", "*.docx"):
        candidate_paths.extend(glob.glob(os.path.join(kb_dir, pattern)))

    for candidate in candidate_paths:
        try:
            content = load_file_text(candidate)
        except Exception as exc:  # best-effort ingestion: skip bad files
            print(f"Could not read {candidate}: {exc}")
            continue
        if content.strip():
            documents.append((os.path.basename(candidate), content))
    return documents
79
+
80
+
81
+ # -----------------------------
82
+ # KB INDEX (FAISS)
83
+ # -----------------------------
84
+
85
class RAGIndex:
    """Retrieval-augmented extractive QA over a local knowledge base.

    KB documents are chunked, embedded with a sentence-transformers model,
    and stored in a FAISS inner-product index over L2-normalized vectors
    (i.e. cosine similarity). Questions are answered by running an
    extractive QA pipeline over the retrieved chunks.
    """

    def __init__(self):
        print("Loading embedding model...")
        self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
        print("Loading QA model...")
        # handle_impossible_answer lets the QA model return an empty span
        # when the context does not contain an answer.
        self.qa_pipeline = pipeline(
            "question-answering",
            model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
            tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
            handle_impossible_answer=True,
        )
        # Parallel sequences: chunks[i] came from file chunk_sources[i].
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        # FAISS index, or None when the KB is empty.
        self.index = None
        self._build_or_load_index()

    def _build_or_load_index(self):
        """Load the cached FAISS index from disk, or build it from the KB.

        NOTE(review): the cache is never invalidated — if KB files change,
        INDEX_DIR must be deleted manually to force a rebuild.
        """
        os.makedirs(INDEX_DIR, exist_ok=True)
        idx_path = os.path.join(INDEX_DIR, "kb.index")
        meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")

        if os.path.exists(idx_path) and os.path.exists(meta_path):
            print("Loading existing FAISS index...")
            self.index = faiss.read_index(idx_path)
            # The metadata file is a pickled dict of object arrays (written
            # by np.save below), so allow_pickle is required — only load
            # index files this app produced itself.
            meta = np.load(meta_path, allow_pickle=True).item()
            self.chunks = meta["chunks"]
            self.chunk_sources = meta["sources"]
            print("Index loaded.")
            return

        print("Building new FAISS index...")
        docs = load_kb_documents(KB_DIR)
        all_chunks = []
        all_sources = []
        for source, text in docs:
            for chunk in chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP):
                all_chunks.append(chunk)
                all_sources.append(source)

        if not all_chunks:
            print("⚠️ No KB documents found, index will stay empty.")
            self.index = None
            return

        embeddings = self.embedder.encode(all_chunks, show_progress_bar=True, convert_to_numpy=True)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)

        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Persist both the vectors and the chunk/source metadata so later
        # runs can skip embedding entirely.
        faiss.write_index(index, idx_path)
        np.save(meta_path, {"chunks": np.array(all_chunks, dtype=object), "sources": np.array(all_sources, dtype=object)})

        self.index = index
        self.chunks = all_chunks
        self.chunk_sources = all_sources
        print("FAISS index ready.")

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *top_k* (chunk, source, score) hits above SIM_THRESHOLD."""
        if not query.strip() or self.index is None:
            return []
        q_emb = self.embedder.encode([query], convert_to_numpy=True)
        # Normalize the query too, so inner-product scores are cosine similarities.
        faiss.normalize_L2(q_emb)
        scores, idxs = self.index.search(q_emb, top_k)
        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx == -1:
                # FAISS pads results with -1 when fewer than top_k vectors exist.
                continue
            if score < SIM_THRESHOLD:
                continue
            results.append((self.chunks[idx], self.chunk_sources[idx], float(score)))
        return results

    def answer(self, question: str) -> str:
        """Answer *question* from the KB; NO_ANSWER_MSG when nothing matches."""
        contexts = self.retrieve(question, top_k=3)
        if not contexts:
            return NO_ANSWER_MSG

        # Run extractive QA against each retrieved chunk independently.
        answers = []
        for ctx, source, score in contexts:
            qa_input = {"question": question, "context": ctx}
            try:
                result = self.qa_pipeline(qa_input)
                text = result.get("answer", "").strip()
                if text:
                    answers.append((text, source, result.get("score", 0.0)))
            except Exception as e:
                # Best-effort: a failure on one chunk shouldn't kill the answer.
                print(f"QA error: {e}")

        if not answers:
            return NO_ANSWER_MSG

        # Pick best answer by the QA model's confidence, not retrieval score.
        answers.sort(key=lambda x: x[2], reverse=True)
        best_answer, best_source, best_score = answers[0]

        return (
            f"**Answer:** {best_answer}\n\n"
            f"**Source:** {best_source} (confidence: {best_score:.2f})"
        )
187
+
188
+
189
# Built once at import time: this loads both models and builds or loads the
# FAISS index, so first startup can take a while.
rag_index = RAGIndex()
190
+
191
+
192
+ # -----------------------------
193
+ # GRADIO CHAT
194
+ # -----------------------------
195
+
196
def rag_respond(message: str, history):
    """Gradio chat callback: answer *message* from the KB; *history* is unused."""
    answer_text = rag_index.answer(message)
    return answer_text
198
+
199
+
200
# CONSISTENCY: reuse the WELCOME_MSG constant already extracted from CONFIG
# instead of re-reading CONFIG["messages"]["welcome"] a second time.
description = WELCOME_MSG

# Chat UI wired to the RAG pipeline. quick_actions is optional in the
# config; each entry is expected to carry a "query" string used as a
# clickable example.
chat = gr.ChatInterface(
    fn=rag_respond,
    title=CONFIG["client"]["name"],
    description=description,
    type="messages",
    examples=[qa["query"] for qa in CONFIG.get("quick_actions", [])],
    cache_examples=False,
)
210
+
211
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    chat.launch()