# app.py import gradio as gr import torch import time import os import logging from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig from langchain_community.document_loaders import PyPDFLoader, TextLoader from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Configuration MODEL_CONFIG = { "phi-3": { "model_name": "microsoft/phi-3-mini-4k-instruct", "template": """<|user|> Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly. Context: {context} Question: {question}<|end|> <|assistant|> Based on the provided context, I'll answer your question:""" }, "llama3-8b": { "model_name": "NousResearch/Meta-Llama-3-8B-Instruct", "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|> Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly. Context: {context} Question: {question}<|eot_id|> <|start_header_id|>assistant<|end_header_id|> Based on the provided context, I'll answer your question:""" } } bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True ) class ChatModel: def __init__(self): self.models = {} self.tokenizers = {} self.current_store = None self.current_vectorstore = None # Use the same embedding model as in vector store creation # self.embeddings = HuggingFaceEmbeddings( # model_name="sentence-transformers/all-MiniLM-L6-v2" # ) def load_model(self, model_name): """Load and cache the model and tokenizer""" if model_name not in self.models: logger.info(f"Loading model: {model_name}") try: config = MODEL_CONFIG[model_name] tokenizer = AutoTokenizer.from_pretrained(config["model_name"]) tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( config["model_name"], quantization_config=bnb_config, device_map="auto", torch_dtype=torch.float16, ) self.models[model_name] = model self.tokenizers[model_name] = tokenizer logger.info(f"Successfully loaded model: {model_name}") except Exception as e: logger.error(f"Error loading model {model_name}: {str(e)}") raise def load_vector_store(self, store_name): """Load vector store with cache invalidation""" try: # Check if we need to load a new store if self.current_store != store_name: logger.info(f"Loading new vector store: {store_name}") vector_store_path = f"vector_stores_index/{store_name}" logger.info(f"vector store path: {vector_store_path}") if not os.path.exists(vector_store_path): raise ValueError(f"Vector store not found at: {vector_store_path}") embeddings = HuggingFaceEmbeddings( model_name="sentence-transformers/all-MiniLM-L6-v2" ) # Load new vector store self.current_vectorstore = FAISS.load_local( vector_store_path, embeddings, allow_dangerous_deserialization=True ) self.current_store = store_name # Verify the new store self.check_vectorstore() logger.info(f"Successfully loaded vector store: {store_name}") return self.current_vectorstore except Exception as e: logger.error(f"Error loading vector store {store_name}: {str(e)}") # Reset state on error self.current_store = None self.current_vectorstore = None raise def check_vectorstore(self): """Verify current vector store content""" try: if self.current_vectorstore is None: raise ValueError("No vector store currently loaded") # Use a generic query to test retrieval sample_query = "what is this document about" docs = self.current_vectorstore.similarity_search(sample_query, k=1) logger.info(f"Vector store {self.current_store} content sample:") logger.info(f"Document content: {docs[0].page_content[:200]}...") logger.info(f"Document source: {docs[0].metadata.get('source', 'unknown')}") except Exception as e: logger.error(f"Error checking vector store: {str(e)}") raise def generate(self, message, model_name, vector_store_name, history): """Generate response using RAG""" start_time = time.time() try: # Load model and vector store self.load_model(model_name) self.load_vector_store(vector_store_name) config = MODEL_CONFIG[model_name] # Retrieve relevant context logger.info(f"Retrieving context for query: {message}") docs = self.current_vectorstore.similarity_search(message, k=3) # Log retrieved documents for debugging for i, doc in enumerate(docs): logger.info(f"Retrieved document {i + 1}:") logger.info(f"Source: {doc.metadata.get('source', 'unknown')}") logger.info(f"Content: {doc.page_content[:200]}...") context = "\n\n".join([d.page_content for d in docs]) # Format prompt prompt = config["template"].format( context=context, question=message ) logger.info(f"Generated prompt: {prompt[:200]}...") # Generate response pipe = pipeline( "text-generation", model=self.models[model_name], tokenizer=self.tokenizers[model_name], max_new_tokens=384, temperature=0.3, # Lower temperature for more focused responses top_p=0.9, repetition_penalty=1.1, do_sample=True, return_full_text=False ) response = pipe(prompt)[0]['generated_text'] # Calculate metrics elapsed_time = time.time() - start_time tokens = len(self.tokenizers[model_name].encode(response)) tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0 logger.info(f"Generated response in {elapsed_time:.2f}s") return response, elapsed_time, tokens_per_sec except Exception as e: logger.error(f"Error in generate: {str(e)}") raise # Initialize model handler model_handler = ChatModel() def chat(message, history, model_choice, vector_store_choice): """Chat interface function""" logger.info(f"Received message: {message}") logger.info(f"Using model: {model_choice}") logger.info(f"Using vector store: {vector_store_choice}") try: response, response_time, token_speed = model_handler.generate( message, model_choice, vector_store_choice, history ) # Format response with metrics and source context formatted_response = ( f"{response}\n\n" f"⏱️ Response Time: {response_time:.2f}s | " f"🚀 Speed: {token_speed:.2f} tokens/s" ) return [(message, formatted_response)] except Exception as e: logger.error(f"Error in chat: {str(e)}") error_message = f"Error: {str(e)}\n\nPlease try again or contact support if the issue persists." return [(message, error_message)] # Gradio interface with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("""# 🚀 Enhanced RAG Chatbot with Performance Metrics This chatbot uses Retrieval-Augmented Generation (RAG) to provide informed responses based on your documents. """) with gr.Row(): model_choice = gr.Dropdown( choices=["phi-3", "llama3-8b"], label="Select Model", value="phi-3" ) vector_store_choice = gr.Dropdown( ["llm", "scoliosis"], # Update these choices based on your vector stores value="llm", label="Knowledge Base", interactive=True ) with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot(height=500) msg = gr.Textbox( label="Message", placeholder="Type your question here...", scale=4 ) with gr.Row(): submit_btn = gr.Button("Send", variant="primary") clear_btn = gr.ClearButton([msg, chatbot]) # Event handlers msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot) submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot) if __name__ == "__main__": demo.launch()