# C2C_Chatbot / app.py
# brahmanarisetty's picture
# Update app.py
# 37b16ea verified
# --- Imports, Logging & Reproducibility ---
import logging
import os
import random
import re
from typing import List

import gradio as gr
import nest_asyncio
import numpy as np
import pandas as pd
import qdrant_client
import torch
from huggingface_hub import login
# Llama-Index & Transformers
from llama_index.core import (
    VectorStoreIndex, StorageContext, Settings, QueryBundle
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import Document, NodeWithScore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.qdrant import QdrantVectorStore
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
# Configure logging: INFO and above, rendered as "<time> <level>: <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Reproducibility: one seed shared by Python's, NumPy's and PyTorch's RNGs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Apply the nest_asyncio patch before anything below touches an event loop.
nest_asyncio.apply()
# --- Hugging Face Spaces Configuration ---
# HF_TOKEN, QDRANT_HOST, and QDRANT_API_KEY should be set as Space Secrets
HF_TOKEN = os.getenv("HF_TOKEN")
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# NOTE(review): this is assigned after numpy/torch have already been imported;
# confirm the thread cap actually takes effect at this point.
os.environ['OMP_NUM_THREADS'] = '4'

# Fail fast when a secret is unset or empty, and name the ones that are
# missing so the operator does not have to guess which of the three to fix.
_missing_secrets = [
    name
    for name, value in (
        ("QDRANT_HOST", QDRANT_HOST),
        ("QDRANT_API_KEY", QDRANT_API_KEY),
        ("HF_TOKEN", HF_TOKEN),
    )
    if not value
]
if _missing_secrets:
    raise EnvironmentError(
        "Please set QDRANT_HOST, QDRANT_API_KEY, and HF_TOKEN as Space Secrets. "
        f"Missing: {', '.join(_missing_secrets)}"
    )

# Authenticate against the Hugging Face Hub with the Space's token.
login(token=HF_TOKEN)
# --- Qdrant Connection and Collection Setup ---
# Name of the Qdrant collection the vector store below is bound to.
COLLECTION_NAME = "C2C_RAG"
# Client for the Qdrant instance named by the Space secrets; gRPC is not preferred.
qdrant = qdrant_client.QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API_KEY, prefer_grpc=False)
# --- RAG Components Setup ---
# Run the embedding model on CUDA when torch reports it, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5", device=device)

# No documents are indexed here: the Qdrant collection is assumed to already
# exist and to have been populated ahead of time. A BM25 retriever would need
# the original nodes, which this Space does not have, so this list is only an
# empty placeholder.
bm25_nodes = []

# Qdrant-backed vector store over the existing collection (used read-only here).
vector_store = QdrantVectorStore(client=qdrant, collection_name=COLLECTION_NAME, prefer_grpc=False)

# Build the index object on top of the existing vector store; queries are
# embedded with the same embedding model defined above.
index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
# --- Define Hybrid Retriever & Reranker ---
# Set llama-index's global LLM to None: answers are generated by the
# transformers pipeline built further down, and llama-index is only used here
# for retrieval and reranking.
Settings.llm = None
class HybridRetriever(BaseRetriever):
    """Combine a dense retriever with an optional BM25 retriever.

    Both retrievers are queried with the same ``similarity_top_k``. Their
    ranked lists are merged with reciprocal rank fusion, so a node ranked
    highly by either retriever can reach the final top-k. Simply appending
    the BM25 hits after the dense hits and truncating would discard every
    BM25 hit whenever the dense retriever fills its quota.

    Attributes:
        dense: retriever queried first; its failures are logged as errors.
        bm25: optional second retriever; skipped when falsy, failures are
            logged as warnings.
        similarity_top_k: number of hits requested from each retriever and
            the maximum number of hits returned.
    """

    def __init__(self, dense, bm25, similarity_top_k=10):
        super().__init__()
        self.dense = dense
        self.bm25 = bm25
        self.similarity_top_k = similarity_top_k

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return at most ``similarity_top_k`` de-duplicated hits for the query.

        Each retriever is best-effort: if one raises, the error is logged and
        the result is built from whatever the other one returned (possibly
        an empty list).
        """
        # Callers may assign the value from a UI control; slicing needs an int.
        top_k = int(self.similarity_top_k)

        dense_hits = []
        try:
            self.dense.similarity_top_k = top_k
            dense_hits = self.dense.retrieve(query_bundle)
        except Exception as e:
            logger.error("Dense retrieval error: %s", e)

        bm25_hits = []
        if self.bm25:
            try:
                self.bm25.similarity_top_k = top_k
                bm25_hits = self.bm25.retrieve(query_bundle)
            except Exception as e:
                logger.warning("BM25 retrieval error: %s", e)

        # Reciprocal rank fusion: every list contributes 1 / (rrf_k + rank)
        # per node. Only ranks are used because dense and BM25 scores are not
        # on a comparable scale.
        rrf_k = 60
        fused_score = {}
        first_hit = {}
        for hits in (dense_hits, bm25_hits):
            counted = set()  # a node contributes once per list, at its best rank
            for rank, hit in enumerate(hits, start=1):
                nid = hit.node.node_id
                if nid in counted:
                    continue
                counted.add(nid)
                first_hit.setdefault(nid, hit)
                fused_score[nid] = fused_score.get(nid, 0.0) + 1.0 / (rrf_k + rank)

        # sorted() is stable, so ties keep first-seen order; with BM25 disabled
        # the result is exactly the de-duplicated dense ranking.
        ranked_ids = sorted(first_hit, key=fused_score.__getitem__, reverse=True)
        return [first_hit[nid] for nid in ranked_ids[:top_k]]
# Instantiate retrievers
# BM25 stays off: it needs the original documents, which the Space does not have.
bm25_retriever = None
dense_retriever = index.as_retriever(similarity_top_k=10)
logger.warning("BM25 retriever is disabled as the original data is not available in the Space.")
hybrid_retriever = HybridRetriever(dense=dense_retriever, bm25=bm25_retriever)

# Cross-encoder reranker that keeps the 4 best nodes of whatever was retrieved.
reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=4)

# The query engine is built without an explicit `llm` argument.
from llama_index.core.query_engine import RetrieverQueryEngine
query_engine = RetrieverQueryEngine(retriever=hybrid_retriever, node_postprocessors=[reranker])
# --- Load & Quantize LLaMA Model ---
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# bitsandbytes settings: 4-bit NF4 weights, double quantization, bfloat16 compute.
quant_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
llm = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quant_config,
)

# Text-generation pipeline around the already-loaded model and tokenizer.
generator = pipeline(task="text-generation", model=llm, tokenizer=tokenizer, device_map="auto")
# --- Chatbot Logic & Gradio Interface (Improved) ---
# System message placed at the start of every prompt. The pieces are joined
# without a separator, so each piece carries its own trailing space or newline.
SYSTEM_PROMPT = "".join([
    "You are a friendly and helpful Level 0 IT Support Assistant. ",
    "If the user's question lacks details or clarity, ask a concise follow-up question "
    "to gather the information you need before providing a solution. ",
    "Once clarified, then:\n",
    "Your purpose is to provide simple, step-by-step solutions for common, entry-level technical issues. ",
    "Examples of Level 0 issues include: forgotten passwords, basic printer problems, network connectivity checks, or simple software reinstallation. ",
    "Do not answer questions about booking tickets, Level 1 or Level 2 support, or advanced technical configurations. ",
    "If a user's question is beyond your scope (e.g., requires access to internal systems, involves advanced troubleshooting, or is not a basic IT issue), politely state that it's a higher-level issue and advise them to contact the dedicated IT support team directly. ",
    "Always maintain a conversational tone and end with a polite closing.",
])
# Llama 3 chat-format markers used to assemble prompts by hand.
# In the Llama 3 prompt format each role header is followed by a blank line
# ("\n\n") before the message text, so the header entries include it; the
# end-of-turn token takes no trailing whitespace.
HDR = {
    "sys": "<|start_header_id|>system<|end_header_id|>\n\n",
    "usr": "<|start_header_id|>user<|end_header_id|>\n\n",
    "ast": "<|start_header_id|>assistant<|end_header_id|>\n\n",
    "eot": "<|eot_id|>",
}
# Conversation memory used by chat(): a list of (user message, reply) pairs.
chat_history = []

# Messages that count as a bare greeting once lower-cased and stripped.
GREETINGS = {
    "hello",
    "hi",
    "hey",
    "good morning",
    "good afternoon",
    "good evening",
}

# Substrings that mark a request as outside Level 0 support.
OUT_OF_SCOPE_KEYWORDS = [
    "book tickets",
    "level 1",
    "level 2",
    "advanced configuration",
    "request a laptop",
    "purchase software",
]
def is_out_of_scope(query):
    """Return True when *query* contains any out-of-scope keyword, ignoring case."""
    lowered = query.lower()
    for keyword in OUT_OF_SCOPE_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def format_history(history):
    """Render (user, assistant) pairs as consecutive chat turns.

    Each pair becomes a user header, the user text and an end-of-turn marker,
    followed by an assistant header, the assistant text and an end-of-turn
    marker. An empty history yields an empty string.
    """
    user_hdr, assistant_hdr, end_of_turn = HDR['usr'], HDR['ast'], HDR['eot']
    turns = []
    for user_msg, assistant_msg in history:
        turns.append(
            f"{user_hdr}{user_msg}{end_of_turn}{assistant_hdr}{assistant_msg}{end_of_turn}"
        )
    return "".join(turns)
# Matches chat control tokens of the form <|...|>, e.g. <|eot_id|>.
_CONTROL_TOKEN_RE = re.compile(r"<\|[^<>|]*\|>")
# chat() replays only this many past turns, so nothing older is retained.
_MAX_HISTORY_TURNS = 3


def _strip_control_tokens(text):
    """Return *text* with every ``<|...|>`` control token removed."""
    previous = None
    # Repeat until stable: removing one token can splice a new one together.
    while previous != text:
        previous = text
        text = _CONTROL_TOKEN_RE.sub("", text)
    return text


def _remember(query, reply):
    """Record one turn in the shared history, dropping turns never replayed."""
    chat_history.append((query, reply))
    del chat_history[:-_MAX_HISTORY_TURNS]


def chat(query, k, temperature, top_p):
    """Answer one user message and record the turn in ``chat_history``.

    Greetings, messages shorter than three words and out-of-scope requests
    get a canned reply without touching the retriever or the model. Anything
    else is answered by retrieving context through ``query_engine`` and
    generating with ``generator``.

    Args:
        query: the user's message; ``None`` is treated as an empty message.
        k: number of hits to retrieve; coerced to an int of at least 1.
        temperature: sampling temperature passed to the generator.
        top_p: nucleus-sampling threshold passed to the generator.

    Returns:
        The reply text.

    Side effects:
        Sets ``query_engine.retriever.similarity_top_k`` and updates the
        module-level ``chat_history``, which is trimmed in place to the last
        ``_MAX_HISTORY_TURNS`` turns.
    """
    # The prompt is assembled by string concatenation, so control tokens in
    # the user's text could otherwise close the turn and forge new ones.
    query = _strip_control_tokens(query or "").strip()

    if query.lower() in GREETINGS:
        reply = "Hello there! How can I help with your IT support question today?"
        _remember(query, reply)
        return reply

    if len(query.split()) < 3:
        reply = "Could you provide more detail about what you're experiencing? Any error messages or steps you've tried will help me assist you."
        _remember(query, reply)
        return reply

    if is_out_of_scope(query):
        reply = "I apologize, but that seems to be a question for our dedicated IT support team. I can only assist with Level 0 issues like password resets or basic connectivity problems. Please contact them directly for help."
        _remember(query, reply)
        return reply

    # The value is later used as a slice bound, so it must be a positive int.
    query_engine.retriever.similarity_top_k = max(1, int(k))
    response = query_engine.query(query)
    context_nodes = response.source_nodes or []
    # Retrieved text is sanitised for the same reason as the query.
    context_str = "\n---\n".join(_strip_control_tokens(node.text) for node in context_nodes)
    hist_str = format_history(chat_history[-_MAX_HISTORY_TURNS:])

    prompt = (
        f"<|begin_of_text|>"
        f"{HDR['sys']}{SYSTEM_PROMPT}{HDR['eot']}"
        f"{hist_str}"
        f"{HDR['usr']}Context:\n{context_str}{HDR['eot']}"
        f"{HDR['usr']}Question: {query}{HDR['eot']}"
        f"{HDR['ast']}"
    )
    gen_args = {
        "do_sample": True,
        "max_new_tokens": 356,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # The prompt already begins with <|begin_of_text|>; ask the pipeline not
    # to add special tokens so the sequence does not start with two of them.
    output = generator(prompt, add_special_tokens=False, **gen_args)
    text = output[0]["generated_text"]
    # Keep whatever follows the final assistant header, which is how the
    # reply is separated from the prompt that precedes it.
    answer = text.split(HDR["ast"])[-1].strip()
    _remember(query, answer)
    return answer
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="💬 Level 0 IT Support Chatbot") as demo:
    gr.Markdown("### 🤖 Level 0 IT Support Chatbot (RAG + Qdrant + LLaMA3)")
    chatbot = gr.Chatbot(label="Chat", height=500)
    state = gr.State([])
    inp = gr.Textbox(placeholder="Ask your IT support question...", label="Your Message", lines=2)
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
    with gr.Accordion("Advanced Settings", open=False):
        k_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Context Hits (k)")
        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p")

    def respond(message, history, k_val, temp_val, top_p_val):
        """Answer *message* and append the exchange to the visible history.

        Returns an empty string (to clear the textbox) followed by the
        updated history twice: once for the chatbot and once for the state.
        """
        reply = chat(message, k_val, temp_val, top_p_val)
        history.append([message, reply])
        return "", history, history

    def reset_conversation():
        """Reset the UI and the conversation memory the model is prompted with.

        chat() builds its prompt from the module-level ``chat_history``, not
        from the Gradio state, so clearing only the widgets would leave the
        "cleared" turns in the next prompt.
        """
        chat_history.clear()
        # textbox, chatbot, state, then the three sliders' default values
        return "", [], [], 10, 0.7, 0.9

    inputs = [inp, state, k_slider, temp_slider, top_p_slider]
    inp.submit(respond, inputs, [inp, chatbot, state])
    send_btn.click(respond, inputs, [inp, chatbot, state])
    clear_btn.click(
        reset_conversation,
        None,
        [inp, chatbot, state, k_slider, temp_slider, top_p_slider],
        queue=False,
    )

demo.queue().launch(server_name="0.0.0.0", server_port=7860)