# app.py — Gradio RAG chat demo (Gemma-2b-it LLM + MiniLM sentence embeddings + FAISS retrieval)
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from faiss import IndexFlatL2, normalize_L2
import numpy as np
# Model access requires accepting the Gemma license on the Hub.
# The Hugging Face token is stored as a secret in the Space settings and is
# exposed to the process as an environment variable; may be None locally.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "google/gemma-2b-it"
# Load the chat LLM and its tokenizer, authenticating with HF_TOKEN.
# Prefer GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device, token=HF_TOKEN)
# Sentence-Transformer model used to embed document chunks and queries for RAG.
rag_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# FAISS index holding the document-chunk embeddings.
# 384 is the embedding width of all-MiniLM-L6-v2 (the index is rebuilt with the
# actual embedding width on each upload anyway — see process_file).
faiss_index = IndexFlatL2(384)
# Raw text of the most recently uploaded document; re-split into sentences
# by rag_answer to map a FAISS hit back to its source chunk.
processed_document_content = ""
def generate_response(question, context):
    """Answer *question* with the LLM, grounding it in *context*.

    Args:
        question: The user's question (plain text).
        context: Retrieved document text to condition the answer on.

    Returns:
        The model's generated answer with the prompt removed.
    """
    prompt = f"Given the following context: {context}\n\nAnswer the question: {question}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id)
    # Slice off the prompt at the TOKEN level instead of string-replacing it:
    # the tokenize/decode round-trip does not always reproduce the prompt
    # byte-for-byte, so `response.replace(prompt, "")` could silently fail
    # and leak the whole prompt into the chat window.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
def process_file(file_obj):
    """Read an uploaded text file, embed its sentences, and (re)build the FAISS index.

    Args:
        file_obj: Gradio file wrapper exposing the temp path as ``.name``,
            or None when nothing has been uploaded.

    Returns:
        A human-readable status string for the UI.
    """
    global processed_document_content, faiss_index
    if file_obj is None:
        return "Please upload a file first."
    with open(file_obj.name, "r", encoding="utf-8") as f:
        processed_document_content = f.read()
    # Naive chunking: one chunk per sentence. rag_answer performs the same
    # split, so chunk indices returned by FAISS map back to this list.
    sentences = processed_document_content.split(".")
    # FAISS expects float32 inputs; SentenceTransformer usually returns
    # float32 already, but make it explicit.
    embeddings = np.asarray(rag_model.encode(sentences), dtype=np.float32)
    # BUG FIX: faiss.normalize_L2 normalizes IN PLACE and returns None.
    # The original code assigned its return value and then crashed on
    # `None.shape[1]`.
    normalize_L2(embeddings)
    # Rebuild the index from scratch so a new upload replaces the old document.
    faiss_index = IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    return f"Successfully processed file with {len(sentences)} chunks."
def rag_answer(question):
    """Answer *question* using the single most similar chunk of the uploaded document.

    Args:
        question: The user's question.

    Returns:
        The LLM's answer string, or a notice when no document is indexed.
    """
    if faiss_index.ntotal == 0:
        return "No document loaded. Please upload a file first."
    question_embedding = np.asarray(rag_model.encode([question]), dtype=np.float32)
    # BUG FIX: faiss.normalize_L2 works in place and returns None; the
    # original assigned the (None) return value and then searched with it.
    normalize_L2(question_embedding)
    # Top-1 nearest chunk; indices mirror process_file's sentence split.
    _, indices = faiss_index.search(question_embedding, 1)
    sentences = processed_document_content.split(".")
    context = sentences[indices[0][0]]
    return generate_response(question, context)
# Gradio interface: file upload + status on the left, chat on the right.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# <center>Code & Data Analysis Chatbot</center>")
    gr.Markdown("I'm a chatbot specialized in coding and data analysis. You can ask me questions or upload a `.csv` or `.txt` file for me to analyze!")
    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(label="Upload a file for analysis (.txt or .csv)")
            file_output = gr.Textbox(label="File Status")
            upload_button = gr.Button("Process File")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History")
            msg = gr.Textbox(label="Your message", placeholder="Ask a question...")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")

    def respond(message, history):
        """Handle one chat turn: generate the answer ONCE, append it to the
        history, and clear the input textbox.

        BUG FIX: the original lambda invoked rag_answer(msg) twice (two full
        LLM generations per turn) and wrote the answer into the `msg` textbox
        instead of clearing it.
        """
        answer = rag_answer(message)
        return "", history + [[message, answer]]

    # Event handlers
    upload_button.click(
        fn=process_file,
        inputs=file_upload,
        outputs=file_output
    )
    submit_btn.click(
        fn=respond,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot]
    )
    clear_btn.click(
        fn=lambda: (None, []),
        inputs=None,
        outputs=[msg, chatbot]
    )

demo.launch()