import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from faiss import IndexFlatL2, normalize_L2
import numpy as np
# Acknowledge the license for the model before using it
# The Hugging Face token is stored as a secret in the Space settings
# It's automatically available as an environment variable
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the secret is not configured
MODEL_ID = "google/gemma-2b-it"
# Load the model and tokenizer from Hugging Face, using the HF_TOKEN for authentication
# Set device to GPU if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device, token=HF_TOKEN)
# Initialize the embedding model for RAG
# We use a Sentence-Transformer model for this
rag_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Create a FAISS index to store document embeddings.
# NOTE(review): 384 is presumably the embedding width of all-MiniLM-L6-v2 —
# confirm against rag_model.get_sentence_embedding_dimension().
# process_file() rebuilds this index on every upload.
faiss_index = IndexFlatL2(384)
# Initialize a variable to hold the processed document content
# (written by process_file, re-split and read by rag_answer)
processed_document_content = ""
# Function to generate a response from the LLM
def generate_response(question, context):
    """Generate an LLM answer to *question* grounded in *context*.

    Args:
        question: The user's question string.
        context: Retrieved document chunk to condition the answer on.

    Returns:
        The model's answer text, without the prompt.
    """
    # Construct the full prompt for the LLM
    prompt = f"Given the following context: {context}\n\nAnswer the question: {question}"
    # Use the tokenizer to prepare the prompt for the model
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Generate a response using the model
    outputs = model.generate(**inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id)
    # BUG FIX: decoding the full sequence and str.replace()-ing the prompt is
    # fragile — the decode round-trip often changes whitespace/special tokens,
    # so the prompt text survived in the answer. Decode only the tokens the
    # model generated beyond the prompt instead.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
# Function to handle file uploads and populate the FAISS index
def process_file(file_obj):
    """Read an uploaded text file, embed its sentence chunks, and rebuild the FAISS index.

    Args:
        file_obj: Gradio file object (or None) exposing a ``.name`` filesystem path.

    Returns:
        A status string: success with the chunk count, or a prompt to upload a file.
    """
    global processed_document_content, faiss_index
    if file_obj is None:
        return "Please upload a file first."
    # Read the content of the uploaded file
    with open(file_obj.name, "r", encoding="utf-8") as f:
        processed_document_content = f.read()
    # Split the document into chunks (sentences in this case).
    # NOTE: rag_answer() performs this exact same split, so FAISS row order
    # must stay aligned with this list — do not filter or reorder here.
    sentences = processed_document_content.split(".")
    # Generate embeddings for each chunk; FAISS requires a float32 matrix.
    embeddings = np.asarray(rag_model.encode(sentences), dtype="float32")
    # BUG FIX: faiss.normalize_L2 normalizes *in place* and returns None.
    # The old code assigned its return value and then indexed None.
    normalize_L2(embeddings)
    # Re-initialize the FAISS index and add the new embeddings
    # We clear the index to handle new uploads
    faiss_index = IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    return f"Successfully processed file with {len(sentences)} chunks."
# Function to answer a question with RAG
def rag_answer(question):
    """Answer *question* using the single most similar document chunk as context.

    Args:
        question: The user's question string.

    Returns:
        The LLM's answer, or a notice string when no document has been indexed.
    """
    if faiss_index.ntotal == 0:
        return "No document loaded. Please upload a file first."
    # Embed the question; FAISS expects a float32 matrix.
    question_embedding = np.asarray(rag_model.encode([question]), dtype="float32")
    # BUG FIX: faiss.normalize_L2 works in place and returns None — the old
    # code passed None to faiss_index.search().
    normalize_L2(question_embedding)
    # Search the FAISS index for the most relevant document chunk
    _, indices = faiss_index.search(question_embedding, 1)
    # Retrieve the relevant context (the chunk with the highest similarity).
    # Same split as process_file(), so positions line up with FAISS rows.
    context_sentence_index = int(indices[0][0])
    sentences = processed_document_content.split(".")
    context = sentences[context_sentence_index]
    # Generate the final response using the LLM and the retrieved context
    return generate_response(question, context)
# Gradio Interface setup
# Gradio interface: file upload + status on the left, chat on the right.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# <center>Code & Data Analysis Chatbot</center>")
    gr.Markdown("I'm a chatbot specialized in coding and data analysis. You can ask me questions or upload a `.csv` or `.txt` file for me to analyze!")
    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(label="Upload a file for analysis (.txt or .csv)")
            file_output = gr.Textbox(label="File Status")
            upload_button = gr.Button("Process File")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History")
            msg = gr.Textbox(label="Your message", placeholder="Ask a question...")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")

    def _respond(message, history):
        """Handle one chat turn: clear the input box and append (question, answer).

        BUG FIX: the old lambda called rag_answer() twice per click (two full
        LLM generations) and routed the answer into the input textbox instead
        of clearing it.
        """
        answer = rag_answer(message)
        return "", history + [[message, answer]]

    # Event handlers
    upload_button.click(
        fn=process_file,
        inputs=file_upload,
        outputs=file_output
    )
    submit_btn.click(
        fn=_respond,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot]
    )
    clear_btn.click(
        fn=lambda: (None, []),
        inputs=None,
        outputs=[msg, chatbot]
    )
demo.launch()
|