"""Gradio RAG chatbot: answers questions about an uploaded text document.

A small retrieval-augmented-generation demo.  An uploaded file is split into
sentence chunks, embedded with a Sentence-Transformer model, and stored in a
FAISS index; at question time the most similar chunk is retrieved and handed
to Gemma as context for answer generation.
"""

import os

import gradio as gr
import numpy as np
import torch
from faiss import IndexFlatL2, normalize_L2
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# The Hugging Face token is stored as a secret in the Space settings and is
# automatically exposed as an environment variable.  It is needed because the
# Gemma model is gated behind a license acknowledgement.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "google/gemma-2b-it"

# Embedding dimensionality of all-MiniLM-L6-v2.
EMBEDDING_DIM = 384

# Run the LLM on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device, token=HF_TOKEN)

# Embedding model used for retrieval.
rag_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# FAISS index holding the document-chunk embeddings.
faiss_index = IndexFlatL2(EMBEDDING_DIM)

# Raw text of the most recently processed document; re-split at query time so
# chunk indices returned by FAISS line up with the original sentences.
processed_document_content = ""


def generate_response(question, context):
    """Generate an LLM answer to ``question`` given retrieved ``context``."""
    prompt = f"Given the following context: {context}\n\nAnswer the question: {question}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text usually echoes the prompt; strip it so only the newly
    # generated answer remains.
    if prompt in response:
        response = response.replace(prompt, "").strip()
    return response


def process_file(file_obj):
    """Read an uploaded file, chunk it, and (re)build the FAISS index.

    Returns a human-readable status string for the UI.
    """
    global processed_document_content, faiss_index
    if file_obj is None:
        return "Please upload a file first."
    with open(file_obj.name, "r", encoding="utf-8") as f:
        processed_document_content = f.read()
    # Naive sentence chunking; must match the split used in rag_answer() so
    # FAISS indices map back to the right sentence.
    sentences = processed_document_content.split(".")
    embeddings = rag_model.encode(sentences)
    # BUG FIX: faiss.normalize_L2 normalizes *in place* and returns None; the
    # original assigned its return value and then crashed indexing None.
    # FAISS also requires float32 input.
    embeddings = np.asarray(embeddings, dtype="float32")
    normalize_L2(embeddings)
    # Rebuild the index from scratch so a new upload replaces the old one.
    faiss_index = IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    return f"Successfully processed file with {len(sentences)} chunks."


def rag_answer(question):
    """Answer ``question`` using the nearest document chunk as context."""
    if faiss_index.ntotal == 0:
        return "No document loaded. Please upload a file first."
    question_embedding = np.asarray(rag_model.encode([question]), dtype="float32")
    # In-place L2 normalization (see note in process_file).
    normalize_L2(question_embedding)
    # Retrieve the single most similar chunk.
    _, indices = faiss_index.search(question_embedding, 1)
    context_sentence_index = indices[0][0]
    # Same split as process_file so the index maps to the right sentence.
    sentences = processed_document_content.split(".")
    context = sentences[context_sentence_index]
    return generate_response(question, context)


def chat_submit(message, history):
    """Chat handler: answer once, append to history, clear the textbox.

    BUG FIX: the original lambda called rag_answer twice (two full LLM
    generations per question) and wrote the answer back into the message
    textbox instead of clearing it.
    """
    answer = rag_answer(message)
    return "", history + [[message, answer]]


# Gradio interface
with gr.Blocks(theme="soft") as demo:
    # BUG FIX: the heading marker "#" was on its own line and did not render.
    gr.Markdown("# Code & Data Analysis Chatbot")
    gr.Markdown(
        "I'm a chatbot specialized in coding and data analysis. You can ask me "
        "questions or upload a `.csv` or `.txt` file for me to analyze!"
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(label="Upload a file for analysis (.txt or .csv)")
            file_output = gr.Textbox(label="File Status")
            upload_button = gr.Button("Process File")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History")
            msg = gr.Textbox(label="Your message", placeholder="Ask a question...")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")

    # Event handlers
    upload_button.click(fn=process_file, inputs=file_upload, outputs=file_output)
    submit_btn.click(fn=chat_submit, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear_btn.click(fn=lambda: (None, []), inputs=None, outputs=[msg, chatbot])

demo.launch()