# app.py
import gradio as gr
import torch
import time
import os
import logging

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration: per-model prompt templates, each in that model's own chat format
MODEL_CONFIG = {
    "phi-3": {
        "model_name": "microsoft/phi-3-mini-4k-instruct",
        "template": """<|user|>
Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
Context:
{context}
Question: {question}<|end|>
<|assistant|>
Based on the provided context, I'll answer your question:"""
    },
    "llama3-8b": {
        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
Context:
{context}
Question: {question}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Based on the provided context, I'll answer your question:"""
    }
}
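
# Note: the hand-written templates above must exactly match each model's chat
# markup. An alternative (an assumption, not what this app does) is to let the
# tokenizer render its own bundled chat template, e.g.:
#
#   messages = [{"role": "user",
#                "content": f"Context:\n{context}\n\nQuestion: {question}"}]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )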
# 4-bit NF4 quantization (with nested/double quantization) so both models
# fit in modest GPU memory; compute still runs in fp16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
class ChatModel:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.current_store = None
        self.current_vectorstore = None
        # Use the same embedding model as in vector store creation
        # self.embeddings = HuggingFaceEmbeddings(
        #     model_name="sentence-transformers/all-MiniLM-L6-v2"
        # )
    def load_model(self, model_name):
        """Load and cache the model and tokenizer."""
        if model_name not in self.models:
            logger.info(f"Loading model: {model_name}")
            try:
                config = MODEL_CONFIG[model_name]
                tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
                tokenizer.pad_token = tokenizer.eos_token
                model = AutoModelForCausalLM.from_pretrained(
                    config["model_name"],
                    quantization_config=bnb_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                )
                self.models[model_name] = model
                self.tokenizers[model_name] = tokenizer
                logger.info(f"Successfully loaded model: {model_name}")
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {str(e)}")
                raise
    def load_vector_store(self, store_name):
        """Load a vector store, reusing the cached one when the name matches."""
        try:
            # Only reload when a different store is requested
            if self.current_store != store_name:
                logger.info(f"Loading new vector store: {store_name}")
                vector_store_path = f"vector_stores_index/{store_name}"
                logger.info(f"Vector store path: {vector_store_path}")
                if not os.path.exists(vector_store_path):
                    raise ValueError(f"Vector store not found at: {vector_store_path}")
                embeddings = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L6-v2"
                )
                # Load new vector store
                self.current_vectorstore = FAISS.load_local(
                    vector_store_path,
                    embeddings,
                    allow_dangerous_deserialization=True
                )
                self.current_store = store_name
                # Verify the new store
                self.check_vectorstore()
                logger.info(f"Successfully loaded vector store: {store_name}")
            return self.current_vectorstore
        except Exception as e:
            logger.error(f"Error loading vector store {store_name}: {str(e)}")
            # Reset state on error
            self.current_store = None
            self.current_vectorstore = None
            raise
    def check_vectorstore(self):
        """Verify that the current vector store returns content."""
        try:
            if self.current_vectorstore is None:
                raise ValueError("No vector store currently loaded")
            # Use a generic query to test retrieval
            sample_query = "what is this document about"
            docs = self.current_vectorstore.similarity_search(sample_query, k=1)
            if not docs:
                raise ValueError(f"Vector store {self.current_store} returned no documents")
            logger.info(f"Vector store {self.current_store} content sample:")
            logger.info(f"Document content: {docs[0].page_content[:200]}...")
            logger.info(f"Document source: {docs[0].metadata.get('source', 'unknown')}")
        except Exception as e:
            logger.error(f"Error checking vector store: {str(e)}")
            raise
    def generate(self, message, model_name, vector_store_name, history):
        """Generate a response using RAG."""
        start_time = time.time()
        try:
            # Load (or reuse cached) model and vector store
            self.load_model(model_name)
            self.load_vector_store(vector_store_name)
            config = MODEL_CONFIG[model_name]

            # Retrieve relevant context
            logger.info(f"Retrieving context for query: {message}")
            docs = self.current_vectorstore.similarity_search(message, k=3)

            # Log retrieved documents for debugging
            for i, doc in enumerate(docs):
                logger.info(f"Retrieved document {i + 1}:")
                logger.info(f"Source: {doc.metadata.get('source', 'unknown')}")
                logger.info(f"Content: {doc.page_content[:200]}...")
            context = "\n\n".join([d.page_content for d in docs])

            # Format prompt
            prompt = config["template"].format(
                context=context,
                question=message
            )
            logger.info(f"Generated prompt: {prompt[:200]}...")

            # Generate response
            pipe = pipeline(
                "text-generation",
                model=self.models[model_name],
                tokenizer=self.tokenizers[model_name],
                max_new_tokens=384,
                temperature=0.3,  # Lower temperature for more focused responses
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                return_full_text=False
            )
            response = pipe(prompt)[0]['generated_text']

            # Calculate metrics
            elapsed_time = time.time() - start_time
            tokens = len(self.tokenizers[model_name].encode(response))
            tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
            logger.info(f"Generated response in {elapsed_time:.2f}s")
            return response, elapsed_time, tokens_per_sec
        except Exception as e:
            logger.error(f"Error in generate: {str(e)}")
            raise
# Initialize model handler
model_handler = ChatModel()
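
# Example of (assumed) standalone use, bypassing the UI — handy as a smoke
# test. The question and store name below are illustrative only:
#
#   answer, secs, tps = model_handler.generate(
#       "What is this document about?", "phi-3", "llm", history=[]
#   )
#   print(answer, f"({tps:.1f} tokens/s)")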
def chat(message, history, model_choice, vector_store_choice):
    """Chat interface function."""
    logger.info(f"Received message: {message}")
    logger.info(f"Using model: {model_choice}")
    logger.info(f"Using vector store: {vector_store_choice}")
    try:
        response, response_time, token_speed = model_handler.generate(
            message,
            model_choice,
            vector_store_choice,
            history
        )
        # Format response with performance metrics
        formatted_response = (
            f"{response}\n\n"
            f"⏱️ Response Time: {response_time:.2f}s | "
            f"🚀 Speed: {token_speed:.2f} tokens/s"
        )
        # Append to the existing history so earlier turns are preserved
        return history + [(message, formatted_response)]
    except Exception as e:
        logger.error(f"Error in chat: {str(e)}")
        error_message = f"Error: {str(e)}\n\nPlease try again or contact support if the issue persists."
        return history + [(message, error_message)]
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# 🚀 Enhanced RAG Chatbot with Performance Metrics
This chatbot uses Retrieval-Augmented Generation (RAG) to provide informed responses based on your documents.
""")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["phi-3", "llama3-8b"],
            label="Select Model",
            value="phi-3"
        )
        vector_store_choice = gr.Dropdown(
            choices=["llm", "scoliosis"],  # Update these choices based on your vector stores
            value="llm",
            label="Knowledge Base",
            interactive=True
        )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your question here...",
                scale=4
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.ClearButton([msg, chatbot])

    # Event handlers
    msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
    submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
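
# The loaders and splitter imported at the top (PyPDFLoader, TextLoader,
# RecursiveCharacterTextSplitter) are only needed when building the FAISS
# indexes this app loads from vector_stores_index/<name>. The original
# indexing script isn't shown here; the function below is a minimal sketch of
# what it could look like. The chunk sizes and pdf_path argument are assumptions.
def build_vector_store(pdf_path, store_name):
    docs = PyPDFLoader(pdf_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100
    ).split_documents(docs)
    # Must match the embedding model used at query time in load_vector_store
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    FAISS.from_documents(chunks, embeddings).save_local(
        f"vector_stores_index/{store_name}"
    )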
if __name__ == "__main__":
    demo.launch()