# app.py
import gradio as gr
import torch
import time
import os
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
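# Note: PyPDFLoader, TextLoader and RecursiveCharacterTextSplitter are not
# used at serving time; they presumably belong to the offline step that builds
# the FAISS indexes (see the sketch at the end of this file).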
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
MODEL_CONFIG = {
    "phi-3": {
        "model_name": "microsoft/phi-3-mini-4k-instruct",
        "template": """<|user|>
Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
Context:
{context}
Question: {question}<|end|>
<|assistant|>
Based on the provided context, I'll answer your question:"""
    },
    "llama3-8b": {
        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
Context:
{context}
Question: {question}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Based on the provided context, I'll answer your question:"""
    }
}
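# Both templates inline the retrieved context and use each model's native chat
# markers: <|user|>/<|assistant|> for Phi-3, header-id tags and <|eot_id|> for Llama 3.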
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
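# NF4 4-bit weights with double quantization keep weight memory near half a
# byte per parameter; matmuls still run in float16 via bnb_4bit_compute_dtype.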
class ChatModel:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.current_store = None
        self.current_vectorstore = None
        # Use the same embedding model as in vector store creation,
        # instantiated once so it is not reloaded on every store switch
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    def load_model(self, model_name):
        """Load and cache the model and tokenizer"""
        if model_name not in self.models:
            logger.info(f"Loading model: {model_name}")
            try:
                config = MODEL_CONFIG[model_name]
                tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
                tokenizer.pad_token = tokenizer.eos_token
                model = AutoModelForCausalLM.from_pretrained(
                    config["model_name"],
                    quantization_config=bnb_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                )
                self.models[model_name] = model
                self.tokenizers[model_name] = tokenizer
                logger.info(f"Successfully loaded model: {model_name}")
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {str(e)}")
                raise
    def load_vector_store(self, store_name):
        """Load vector store with cache invalidation"""
        try:
            # Check if we need to load a new store
            if self.current_store != store_name:
                logger.info(f"Loading new vector store: {store_name}")
                vector_store_path = f"vector_stores_index/{store_name}"
                logger.info(f"vector store path: {vector_store_path}")
                if not os.path.exists(vector_store_path):
                    raise ValueError(f"Vector store not found at: {vector_store_path}")
                # Load new vector store with the shared embedding model
                self.current_vectorstore = FAISS.load_local(
                    vector_store_path,
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                self.current_store = store_name
                # Verify the new store
                self.check_vectorstore()
                logger.info(f"Successfully loaded vector store: {store_name}")
            return self.current_vectorstore
        except Exception as e:
            logger.error(f"Error loading vector store {store_name}: {str(e)}")
            # Reset state on error
            self.current_store = None
            self.current_vectorstore = None
            raise
    def check_vectorstore(self):
        """Verify current vector store content"""
        try:
            if self.current_vectorstore is None:
                raise ValueError("No vector store currently loaded")
            # Use a generic query to test retrieval
            sample_query = "what is this document about"
            docs = self.current_vectorstore.similarity_search(sample_query, k=1)
            logger.info(f"Vector store {self.current_store} content sample:")
            logger.info(f"Document content: {docs[0].page_content[:200]}...")
            logger.info(f"Document source: {docs[0].metadata.get('source', 'unknown')}")
        except Exception as e:
            logger.error(f"Error checking vector store: {str(e)}")
            raise
    def generate(self, message, model_name, vector_store_name, history):
        """Generate response using RAG"""
        start_time = time.time()
        try:
            # Load model and vector store
            self.load_model(model_name)
            self.load_vector_store(vector_store_name)
            config = MODEL_CONFIG[model_name]
            # Retrieve relevant context
            logger.info(f"Retrieving context for query: {message}")
            docs = self.current_vectorstore.similarity_search(message, k=3)
            # Log retrieved documents for debugging
            for i, doc in enumerate(docs):
                logger.info(f"Retrieved document {i + 1}:")
                logger.info(f"Source: {doc.metadata.get('source', 'unknown')}")
                logger.info(f"Content: {doc.page_content[:200]}...")
            context = "\n\n".join([d.page_content for d in docs])
            # Format prompt
            prompt = config["template"].format(
                context=context,
                question=message
            )
            logger.info(f"Generated prompt: {prompt[:200]}...")
            # Generate response
            pipe = pipeline(
                "text-generation",
                model=self.models[model_name],
                tokenizer=self.tokenizers[model_name],
                max_new_tokens=384,
                temperature=0.3,  # Lower temperature for more focused responses
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                return_full_text=False
            )
            response = pipe(prompt)[0]['generated_text']
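            # Note: on the first call for a model, elapsed_time below also
            # covers model and index loading, so tokens/s is only approximate.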
            # Calculate metrics
            elapsed_time = time.time() - start_time
            tokens = len(self.tokenizers[model_name].encode(response))
            tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
            logger.info(f"Generated response in {elapsed_time:.2f}s")
            return response, elapsed_time, tokens_per_sec
        except Exception as e:
            logger.error(f"Error in generate: {str(e)}")
            raise
# Initialize model handler
model_handler = ChatModel()
def chat(message, history, model_choice, vector_store_choice):
    """Chat interface function"""
    logger.info(f"Received message: {message}")
    logger.info(f"Using model: {model_choice}")
    logger.info(f"Using vector store: {vector_store_choice}")
    try:
        response, response_time, token_speed = model_handler.generate(
            message,
            model_choice,
            vector_store_choice,
            history
        )
        # Format response with performance metrics
        formatted_response = (
            f"{response}\n\n"
            f"⏱️ Response Time: {response_time:.2f}s | "
            f"🚀 Speed: {token_speed:.2f} tokens/s"
        )
        # Append to the existing history instead of replacing it
        return history + [(message, formatted_response)]
    except Exception as e:
        logger.error(f"Error in chat: {str(e)}")
        error_message = f"Error: {str(e)}\n\nPlease try again or contact support if the issue persists."
        return history + [(message, error_message)]
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# 🚀 Enhanced RAG Chatbot with Performance Metrics
This chatbot uses Retrieval-Augmented Generation (RAG) to provide informed responses based on your documents.
""")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["phi-3", "llama3-8b"],
            label="Select Model",
            value="phi-3"
        )
        vector_store_choice = gr.Dropdown(
            ["llm", "scoliosis"],  # Update these choices based on your vector stores
            value="llm",
            label="Knowledge Base",
            interactive=True
        )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your question here...",
                scale=4
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.ClearButton([msg, chatbot])
    # Event handlers
    msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
    submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
if __name__ == "__main__":
    demo.launch()
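# The app expects prebuilt FAISS indexes under vector_stores_index/<name>.
# A minimal sketch of how such an index could be built, assuming PDF sources
# and the same MiniLM embedding model used above; the path, chunk_size and
# chunk_overlap values are illustrative assumptions, not part of the app:
#
# def build_vector_store(pdf_path, store_name):
#     docs = PyPDFLoader(pdf_path).load()
#     chunks = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=100  # assumed values
#     ).split_documents(docs)
#     embeddings = HuggingFaceEmbeddings(
#         model_name="sentence-transformers/all-MiniLM-L6-v2"
#     )
#     FAISS.from_documents(chunks, embeddings).save_local(
#         f"vector_stores_index/{store_name}"
#     )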