brahmanarisetty committed on
Commit
5bdd4be
·
verified ·
1 Parent(s): 05a645e

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +342 -0
  3. data.csv +3 -0
  4. requirements.txt +20 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ IT Support Chatbot (Hugging Face Spaces)
4
+ - Matches Colab pipeline with Hybrid Retrieval (Dense + BM25) and Reranking
5
+ - Uses Qdrant as vector store (build or serve depending on BUILD_MODE)
6
+ - Embeddings kept consistent across build & query via EMBED_MODEL_ID
7
+ - GPU/CPU-safe LLaMA loading (4-bit on GPU, smaller instruct model on CPU)
8
+ - Minimal Gradio UI (Chat + Clear), optional context viewer
9
+
10
+ Environment variables (Spaces → Settings → Variables):
11
+ QDRANT_HOST, QDRANT_API_KEY, HF_TOKEN
12
+ EMBED_MODEL_ID (default: BAAI/bge-large-en-v1.5)
13
+ QDRANT_COLLECTION (default: it_support_rag)
14
+ MODEL_ID (default: meta-llama/Llama-3.1-8B-Instruct)
15
+ CPU_MODEL_ID (default: meta-llama/Llama-3.2-3B-Instruct)
16
+ BUILD_MODE ("true" to build/rebuild from data.csv; default: "false")
17
+ OMP_NUM_THREADS (default: "1")
18
+ SHOW_CONTEXT ("true" to show retrieved context; default: "true")
19
+ """
20
+
21
+ # --- Imports & setup ---
22
+ import os
23
+ import random
24
+ import logging
25
+ import numpy as np
26
+ import torch
27
+ import nest_asyncio
28
+ import pandas as pd
29
+ import gradio as gr
30
+ from typing import List
31
+
32
+ from huggingface_hub import login
33
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
34
+
35
+ from llama_index.core import (
36
+ VectorStoreIndex, StorageContext, Settings, QueryBundle, Document
37
+ )
38
+ from llama_index.core.node_parser import SentenceSplitter
39
+ from llama_index.core.retrievers import BaseRetriever
40
+ from llama_index.core.postprocessor import SentenceTransformerRerank
41
+ from llama_index.core.query_engine import RetrieverQueryEngine
42
+
43
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
44
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
45
+ from llama_index.retrievers.bm25 import BM25Retriever
46
+
47
+ import qdrant_client
48
+
49
# --- Logging ---
logging.basicConfig(
    format="%(asctime)s %(levelname)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("it_support_app")

# --- Reproducibility & asyncio ---
# Seed every RNG we rely on so retrieval/generation runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Allow nested event loops (llama_index uses asyncio internally).
nest_asyncio.apply()
57
+
58
# --- Env vars & sane defaults ---
# setdefault already leaves an existing value untouched, so a plain literal
# default is sufficient (the original re-read the variable via os.getenv
# first, which was redundant).
os.environ.setdefault("OMP_NUM_THREADS", "1")

# Required secrets (validated below).
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Tunables with defaults matching the Colab pipeline.
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "BAAI/bge-large-en-v1.5")
COLLECTION_NAME = os.getenv("QDRANT_COLLECTION", "it_support_rag")
BUILD_MODE = os.getenv("BUILD_MODE", "false").lower() == "true"
SHOW_CONTEXT = os.getenv("SHOW_CONTEXT", "true").lower() == "true"

GPU_MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
CPU_MODEL_ID = os.getenv("CPU_MODEL_ID", "meta-llama/Llama-3.2-3B-Instruct")

# Fail fast if any required secret is absent.
if not all([QDRANT_HOST, QDRANT_API_KEY, HF_TOKEN]):
    raise EnvironmentError("Set QDRANT_HOST, QDRANT_API_KEY, and HF_TOKEN in Space variables.")
75
+
76
# --- Auth & clients ---
# Authenticate to the Hugging Face Hub (gated LLaMA weights), then open a
# REST (non-gRPC) connection to the managed Qdrant instance.
login(token=HF_TOKEN)
qdrant = qdrant_client.QdrantClient(
    url=QDRANT_HOST,
    api_key=QDRANT_API_KEY,
    prefer_grpc=False,
)
79
+
80
# --- Embeddings (keep consistent across build & serve) ---
# The same embedding model MUST be used at index-build time and query time,
# otherwise vector similarity in Qdrant is meaningless.
Settings.embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
logger.info(f"✅ Embedding model set: {EMBED_MODEL_ID}")

# --- Node parser (token-ish chunks) ---
node_parser = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=100,
    paragraph_separator="\n\n",
)
86
+
87
# --- Optional: load CSV for BM25 and/or BUILD_MODE ---
CSV_PATH = "data.csv"
case_docs: List[Document] = []
bm25_retriever = None


def _row_to_document(row) -> Document:
    """Convert one CSV row to a Document; metadata fields are truncated to keep payloads small."""
    return Document(
        text=str(row.get("text_chunk", "")),
        metadata={
            "source_dataset": str(row.get("source_dataset", ""))[:50],
            "category": str(row.get("category", ""))[:100],
            "orig_query": str(row.get("original_query", ""))[:200],
            "orig_solution": str(row.get("original_solution", ""))[:200],
        },
    )


if os.path.exists(CSV_PATH):
    try:
        df = pd.read_csv(CSV_PATH, encoding="ISO-8859-1")
        for _, row in df.iterrows():
            case_docs.append(_row_to_document(row))
        logger.info(f"Loaded {len(case_docs)} documents from {CSV_PATH}.")

        # BM25 (optional; uses local docs only)
        bm25_nodes = node_parser.get_nodes_from_documents(case_docs)
        bm25_retriever = BM25Retriever.from_defaults(nodes=bm25_nodes, similarity_top_k=10)
        logger.info("✅ BM25 retriever initialized.")
    except Exception as e:
        logger.warning(f"BM25 setup skipped due to error: {e}")
else:
    logger.warning("data.csv not found — proceeding WITHOUT BM25 (dense-only).")
114
+
115
# --- Qdrant vector store & index ---
vector_store = QdrantVectorStore(client=qdrant, collection_name=COLLECTION_NAME, prefer_grpc=False)
storage_context = StorageContext.from_defaults(vector_store=vector_store)


def _build_or_load_index():
    """Index data.csv into Qdrant (BUILD_MODE) or attach to the existing collection."""
    if not BUILD_MODE:
        # Serve mode: the collection is assumed to be already populated.
        idx = VectorStoreIndex.from_vector_store(vector_store=vector_store)
        logger.info(f"✅ Loaded existing index from Qdrant collection '{COLLECTION_NAME}'")
        return idx
    if not case_docs:
        raise FileNotFoundError(
            "BUILD_MODE=true but data.csv is missing or empty. "
            "Commit data.csv to the Space repo or disable BUILD_MODE."
        )
    logger.info(f"BUILD_MODE=true → indexing {len(case_docs)} docs into Qdrant collection '{COLLECTION_NAME}'")
    return VectorStoreIndex.from_documents(
        documents=case_docs,
        storage_context=storage_context,
        embed_model=Settings.embed_model,
        node_parser=node_parser,
    )


index = _build_or_load_index()

# --- Dense retriever + hybrid wrapper ---
dense_retriever = index.as_retriever(similarity_top_k=10)
138
+
139
class HybridRetriever(BaseRetriever):
    """Union of dense (vector) and optional BM25 hits, de-duplicated by node id.

    Dense hits come first, so on a node-id collision the dense result wins.
    """

    def __init__(self, dense, bm25=None, top_k=10):
        super().__init__()
        self.dense = dense
        self.bm25 = bm25
        self.top_k = top_k

    def _retrieve(self, query_bundle: QueryBundle):
        hits = []
        try:
            hits.extend(self.dense.retrieve(query_bundle))
        except Exception as e:
            logger.error(f"Dense retrieval error: {e}")

        if self.bm25:
            try:
                hits.extend(self.bm25.retrieve(query_bundle))
            except Exception as e:
                logger.warning(f"BM25 retrieval error: {e}")

        # Merge & de-duplicate by node_id, keeping first occurrence
        # (dict preserves insertion order).
        deduped = {}
        for hit in hits:
            deduped.setdefault(hit.node.node_id, hit)
        return list(deduped.values())[: self.top_k]
168
+
169
hybrid_retriever = HybridRetriever(dense=dense_retriever, bm25=bm25_retriever, top_k=10)

# --- Reranker ---
# Cross-encoder reranker trims the merged candidate pool down to the best 4.
_rerank_device = "cuda" if torch.cuda.is_available() else "cpu"
reranker = SentenceTransformerRerank(
    model="cross-encoder/ms-marco-MiniLM-L-2-v2",
    top_n=4,
    device=_rerank_device,
)

# --- Query Engine (use the hybrid retriever) ---
query_engine = RetrieverQueryEngine(retriever=hybrid_retriever, node_postprocessors=[reranker])
180
+
181
# --- LLM loading (GPU: 4-bit 8B; CPU: smaller instruct model) ---
use_cuda = torch.cuda.is_available()

if use_cuda:
    # NF4 double-quantized 4-bit weights keep the 8B model within GPU memory.
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(GPU_MODEL_ID, use_fast=True)
    llm = AutoModelForCausalLM.from_pretrained(
        GPU_MODEL_ID,
        quantization_config=bnb_cfg,
        device_map="auto",
    )
    generator = pipeline("text-generation", model=llm, tokenizer=tokenizer)
    logger.info(f"✅ Loaded GPU model in 4-bit: {GPU_MODEL_ID}")
else:
    # CPU fallback: full-precision but smaller instruct model.
    tokenizer = AutoTokenizer.from_pretrained(CPU_MODEL_ID, use_fast=True)
    llm = AutoModelForCausalLM.from_pretrained(CPU_MODEL_ID)
    generator = pipeline("text-generation", model=llm, tokenizer=tokenizer, device=-1)
    logger.info(f"✅ Loaded CPU model: {CPU_MODEL_ID}")
199
+
200
# --- Prompt scaffolding ---
SYSTEM_PROMPT = (
    "You are a friendly and helpful Level 0 IT Support Assistant. "
    "Use a conversational tone and guide users step-by-step. "
    "If the user's question lacks details or clarity, ask a concise follow-up question "
    "to gather the information you need before providing a solution. "
    "Once clarified, then:\n"
    "1) Diagnose the problem.\n"
    "2) Provide step-by-step solutions with bullet points.\n"
    "3) Offer additional recommendations or safety warnings.\n"
    "4) End with a polite closing.\n"
    "5) If it is out of level 0 IT support, direct users to contact IT support."
)

# LLaMA-3 chat-template control tokens used to assemble raw prompts.
HDR = {
    "sys": "<|start_header_id|>system<|end_header_id|>",
    "usr": "<|start_header_id|>user<|end_header_id|>",
    "ast": "<|start_header_id|>assistant<|end_header_id|>",
    "eot": "<|eot_id|>",
}

# (user, assistant) turns accumulated across the session.
chat_history = []
GREETINGS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}


def format_history(history):
    """Render (user, assistant) pairs as consecutive LLaMA-3 chat turns."""
    turns = []
    for user_msg, assistant_msg in history:
        turns.append(f"{HDR['usr']}\n{user_msg}{HDR['eot']}")
        turns.append(f"{HDR['ast']}\n{assistant_msg}{HDR['eot']}")
    return "".join(turns)
229
+
230
+ def _nodes_to_text(nodes):
231
+ parts = []
232
+ for i, n in enumerate(nodes or []):
233
+ score = getattr(n, "score", 0.0)
234
+ text = n.node.get_content() if hasattr(n, "node") else n.get_content()
235
+ parts.append(f"**Source {i+1} (Score: {score:.4f})**\n{text}")
236
+ return "\n\n---\n\n".join(parts) if parts else ""
237
+
238
def build_prompt(query, context_nodes, history):
    """Classify the query and build the generation prompt.

    Returns a (payload, mode) pair:
      - ("greeting" mode): payload is None; caller replies with a canned hello.
      - ("clarify" mode): payload is a follow-up question for terse queries.
      - ("rag" mode): payload is the full LLaMA-3 prompt with context/history.
    """
    q = query.strip()
    if q.lower() in GREETINGS:
        return None, "greeting"
    # Fewer than three words is too terse to diagnose — ask for details.
    if len(q.split()) < 3:
        return (
            "Could you provide more detail about what you're experiencing? "
            "Any error messages or steps you've tried will help me assist you."
        ), "clarify"

    chunks = [
        (n.node.get_content() if hasattr(n, "node") else n.get_content())
        for n in (context_nodes or [])
    ]
    ctx_text = "\n---\n".join(chunks) or "No context provided."
    recent = format_history(history[-3:])
    pieces = [
        "<|begin_of_text|>",
        f"{HDR['sys']}\n{SYSTEM_PROMPT}{HDR['eot']}",
        recent,
        f"{HDR['usr']}\nContext:\n{ctx_text}{HDR['eot']}",
        f"{HDR['usr']}\nQuestion: {q}{HDR['eot']}",
        f"{HDR['ast']}\n",
    ]
    return "".join(pieces), "rag"
262
+
263
def chat(query, temperature=0.7, top_p=0.9, max_new_tokens=350):
    """Answer one user turn.

    Returns (reply_text, context_nodes). Greetings and too-short queries are
    short-circuited before any retrieval; otherwise the pipeline is
    retrieve → rerank → prompt → generate, and the turn is appended to the
    module-level chat_history.
    """
    global chat_history
    # Pre-check (greeting/clarify) with empty context — no retrieval needed.
    prompt, mode = build_prompt(query, [], chat_history)
    if mode == "greeting":
        reply = "Hello there! How can I help with your IT support question today?"
        chat_history.append((query, reply))
        return reply, []
    if mode == "clarify":
        reply = prompt
        chat_history.append((query, reply))
        return reply, []

    # Retrieve → Rerank → Build prompt with context → Generate.
    # FIX: the original called query_engine.query(), which also runs
    # llama_index response synthesis through Settings.llm — never configured
    # in this app (it would fall back to the OpenAI default and fail without
    # an OPENAI_API_KEY). Only the reranked source nodes are needed, so
    # retrieve and rerank directly; generation stays on the local pipeline.
    retrieved = hybrid_retriever.retrieve(query)
    context_nodes = reranker.postprocess_nodes(retrieved, query_str=query)
    prompt, _ = build_prompt(query, context_nodes, chat_history)
    gen_args = {
        "do_sample": True,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "eos_token_id": tokenizer.eos_token_id,
    }
    out = generator(prompt, **gen_args)
    text = out[0]["generated_text"]
    # The HF pipeline echoes the prompt; keep only what follows the last
    # assistant header.
    answer = text.split(HDR["ast"])[-1].strip()
    chat_history.append((query, answer))
    return answer, context_nodes
292
+
293
# --- Gradio UI (minimal; optional context viewer) ---
with gr.Blocks(theme=gr.themes.Soft(), title="💬 Level 0 IT Support Chatbot") as demo:
    gr.Markdown("### 🤖 Level 0 IT Support Chatbot (RAG + Qdrant + LLaMA3)")

    with gr.Row():
        with gr.Column(scale=3):
            # NOTE(review): `bubble_full_width` was deprecated and later
            # removed from gr.Chatbot; requirements.txt pins no Gradio
            # version (latest is installed), so the argument is dropped to
            # avoid a TypeError at startup — confirm against the pinned
            # version if one is added.
            chatbot = gr.Chatbot(label="Chat", height=500)
            inp = gr.Textbox(placeholder="Ask your IT support question...", label="Your Message", lines=2)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
        if SHOW_CONTEXT:
            with gr.Column(scale=1):
                with gr.Accordion("Show Retrieved Context", open=False):
                    context_box = gr.Markdown(value="")

    def respond(message, history):
        """Handle one submit: run chat() with fixed sampling defaults and append the turn."""
        reply, context_nodes = chat(message, temperature=0.7, top_p=0.9)
        history = history or []
        history.append([message, reply])
        if SHOW_CONTEXT:
            return "", history, _nodes_to_text(context_nodes)
        return "", history

    def clear_chat():
        """Reset both the model-side history and the UI widgets."""
        global chat_history
        chat_history = []
        if SHOW_CONTEXT:
            return [], ""
        return []

    # Output lists differ depending on whether the context viewer exists.
    if SHOW_CONTEXT:
        inp.submit(respond, [inp, chatbot], [inp, chatbot, context_box])
        send_btn.click(respond, [inp, chatbot], [inp, chatbot, context_box])
        clear_btn.click(clear_chat, None, [chatbot, context_box], queue=False)
    else:
        inp.submit(respond, [inp, chatbot], [inp, chatbot])
        send_btn.click(respond, [inp, chatbot], [inp, chatbot])
        clear_btn.click(clear_chat, None, [chatbot], queue=False)

# Keep the UI responsive on Spaces.
# FIX: `concurrency_count` was removed in Gradio 4; the per-event
# equivalent is `default_concurrency_limit`.
demo.queue(default_concurrency_limit=2, max_size=32)
338
+
339
if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    # Spaces provides host/port automatically; being explicit is still safe.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53c181a92f7d7a203f66e535021210625cc7bf34afb56ccab94d2a5daf537215
3
+ size 21023207
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ llama-index-core
2
+ llama-index-vector-stores-qdrant
3
+ llama-index-embeddings-huggingface
4
+ llama-index-retrievers-bm25
5
+ llama-index-llms-huggingface
6
+ sentence-transformers
7
+ transformers
8
+ accelerate
9
+ gradio
10
+ qdrant-client
11
+ bitsandbytes
12
+ rouge-score
13
+ bert-score
14
+ evaluate
15
+ nest_asyncio
16
+ torch
17
+ pandas
18
+ numpy
19
+ tf-keras
20
+ python-dotenv