# app.py
import gradio as gr
import torch
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Configuration
MODEL_CONFIG = {
"phi-3": {
"model_name": "microsoft/phi-3-mini-4k-instruct",
"template": "<|user|>\n{message}<|end|>\n<|assistant|>"
},
"llama3-8b": {
"model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
"template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
}
}
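# 4-bit NF4 quantization so both models fit in limited GPU memory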
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
class ChatModel:
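    """Cache tokenizers, quantized models, and FAISS vector stores; run RAG generation."""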
def __init__(self):
self.models = {}
self.tokenizers = {}
self.vectorstore = {}
def load_model(self, model_name):
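        """Load and cache the tokenizer and 4-bit quantized model on first use."""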
if model_name not in self.models:
config = MODEL_CONFIG[model_name]
tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
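            # Reuse the EOS token for padding (these instruct tokenizers may not define a pad token)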
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
config["model_name"],
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.float16,
)
self.models[model_name] = model
self.tokenizers[model_name] = tokenizer
def load_vector_store(self, store_name):
"""Cache vector stores in memory"""
if store_name not in self.vectorstore:
embeddings = HuggingFaceEmbeddings(
model_name="BAAI/bge-small-en-v1.5"
)
self.vectorstore[store_name] = FAISS.load_local(
f"vector_store/{store_name}",
embeddings,
allow_dangerous_deserialization=True
)
return self.vectorstore[store_name]
def process_documents(self, files, progress=gr.Progress()):
"""Process uploaded documents into vector embeddings"""
try:
progress(0, desc="Starting document processing")
documents = []
# Load documents
for file_path in progress.tqdm(files, desc="Loading files"):
if file_path.endswith(".pdf"):
loader = PyPDFLoader(file_path)
elif file_path.endswith(".txt"):
loader = TextLoader(file_path)
else:
continue
documents.extend(loader.load())
# Split documents
progress(0.3, desc="Processing documents")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=512,
chunk_overlap=50
)
texts = text_splitter.split_documents(documents)
# Create embeddings
progress(0.6, desc="Generating embeddings")
embeddings = HuggingFaceEmbeddings(
model_name="BAAI/bge-small-en-v1.5"
)
            # Create vector store; cache it under a key rather than overwriting the cache dict
            progress(0.8, desc="Building vector database")
            self.vectorstore["uploaded"] = FAISS.from_documents(texts, embeddings)
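            # "uploaded" is an arbitrary cache key; optionally persist the index for reuse, e.g.:
            #   self.vectorstore["uploaded"].save_local("vector_store/uploaded")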
return "βœ… Documents processed successfully! Ready for queries."
except Exception as e:
return f"❌ Error processing documents: {str(e)}"
def generate(self, message, model_name, vector_store_name, history):
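        """Retrieve context from the chosen knowledge base, build the prompt, and generate a timed reply."""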
start_time = time.time()
self.load_model(model_name)
vectorstore = self.load_vector_store(vector_store_name)
config = MODEL_CONFIG[model_name]
        # Retrieve relevant context from the selected knowledge base
        docs = vectorstore.similarity_search(message, k=3)
        context = "\n\n".join(d.page_content for d in docs)
# Format prompt with context
prompt = config["template"].format(
message=f"Context:\n{context}\n\nQuestion: {message}"
)
# Generate response
pipe = pipeline(
"text-generation",
model=self.models[model_name],
tokenizer=self.tokenizers[model_name],
max_new_tokens=384,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
do_sample=True,
return_full_text=False
)
response = pipe(prompt)[0]['generated_text']
# Calculate metrics
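        # Note: the first request for a model also includes its load time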
elapsed_time = time.time() - start_time
tokens = len(self.tokenizers[model_name].encode(response))
tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
return response, elapsed_time, tokens_per_sec
# Initialize model handler
model_handler = ChatModel()
def chat(message, history, model_choice, vector_store_choice):
    try:
        response, response_time, token_speed = model_handler.generate(
            message, model_choice, vector_store_choice, history
        )
        formatted_response = (
            f"{response}\n\n"
            f"⏱️ Response Time: {response_time:.2f}s | πŸš€ Speed: {token_speed:.2f} tokens/s"
        )
        # Append to the existing history instead of replacing it
        return history + [(message, formatted_response)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]
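# Build the Gradio UI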
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸš€ LLM Chatbot with RAG & Performance Metrics")
with gr.Row():
model_choice = gr.Dropdown(
choices=["phi-3", "llama3-8b"],
label="Select Model",
value="phi-3"
)
vector_store_choice = gr.Dropdown(
["llm", "scoliosis"],
value="scoliosis",
label="Knowledge Base"
)
with gr.Row():
with gr.Column(scale=1):
file_upload = gr.File(
label="Upload Documents (PDF/TXT)",
file_count="multiple",
file_types=[".pdf", ".txt"],
type="filepath"
)
status = gr.Textbox(label="Processing Status", interactive=False)
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(label="Message", placeholder="Type your question here...")
with gr.Row():
submit_btn = gr.Button("Send", variant="primary")
clear_btn = gr.ClearButton([msg, chatbot, file_upload])
# Event handlers
file_upload.upload(
fn=model_handler.process_documents,
inputs=file_upload,
outputs=status,
show_progress="full"
)
msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
demo.launch()