# example_test / app.py
# Author: Wenye He — "Update app.py" (commit 53bce17, verified; ~4.92 kB)
# NOTE: Hugging Face Spaces page header ("raw / history / blame") preserved
# here as a comment so the file remains valid Python.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import time # Added for timing
# New imports
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Document processing function
def process_documents(files):
    """Build a FAISS vector store from uploaded .pdf / .txt files.

    Files with any other extension are skipped. (Previously an unsupported
    file silently reused the loader from the preceding iteration — ingesting
    that document twice — or raised NameError if it was the first file.)

    Args:
        files: iterable of uploaded file objects exposing a `.name` path.

    Returns:
        A FAISS vectorstore over 512-character chunks (50-char overlap),
        embedded with BAAI/bge-small-en-v1.5.
    """
    documents = []
    for file in files:
        if file.name.endswith(".pdf"):
            loader = PyPDFLoader(file.name)
        elif file.name.endswith(".txt"):
            loader = TextLoader(file.name)
        else:
            # Unsupported type: skip instead of reusing a stale loader.
            continue
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    vectorstore = FAISS.from_documents(texts, embeddings)
    return vectorstore
# Registry of selectable models: maps the UI dropdown key to the Hugging Face
# repo id and that model's chat prompt template ({message} placeholder).
MODEL_CONFIG = {
    "phi-3": {
        "model_name": "microsoft/phi-3-mini-4k-instruct",
        # Phi-3 instruct format: user turn, end marker, then assistant cue.
        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
    },
    "llama3-8b": {
        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
        # Llama-3 instruct format; the trailing newline is part of the prompt.
        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    }
}
# 4-bit NF4 double quantization so the larger model fits in limited GPU
# memory; fp16 compute dtype for the dequantized matmuls.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
class ChatModel:
    """Lazily loads, caches, and runs the quantized chat models.

    Each model/tokenizer is loaded at most once per key in MODEL_CONFIG and
    kept for the lifetime of the process. The text-generation pipeline is
    also cached per model: the original rebuilt it on every request with
    identical parameters, which was pure overhead.
    """

    def __init__(self):
        self.models = {}       # model key -> loaded AutoModelForCausalLM
        self.tokenizers = {}   # model key -> matching tokenizer
        self._pipelines = {}   # model key -> cached text-generation pipeline

    def load_model(self, model_name):
        """Load model + tokenizer for ``model_name`` if not already cached.

        Raises:
            KeyError: if ``model_name`` is not a key of MODEL_CONFIG.
        """
        if model_name in self.models:
            return
        config = MODEL_CONFIG[model_name]
        tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
        # These checkpoints ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            config["model_name"],
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer

    def _get_pipeline(self, model_name):
        """Return the cached text-generation pipeline, building it once.

        Sampling parameters are constant across calls, so the pipeline is
        safe to reuse.
        """
        if model_name not in self._pipelines:
            self._pipelines[model_name] = pipeline(
                "text-generation",
                model=self.models[model_name],
                tokenizer=self.tokenizers[model_name],
                max_new_tokens=384,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                return_full_text=False
            )
        return self._pipelines[model_name]

    def generate(self, message, model_name, history, vectorstore=None):
        """Generate a reply to ``message`` with the selected model.

        When ``vectorstore`` is given, the 3 most similar chunks are
        prepended to the question as RAG context. ``history`` is accepted
        for interface compatibility but is not folded into the prompt.

        Returns:
            (response_text, elapsed_seconds, tokens_per_second) — the token
            rate is measured over the generated response only.
        """
        if vectorstore:
            docs = vectorstore.similarity_search(message, k=3)
            context = "\n".join(d.page_content for d in docs)
            message = f"Context: {context}\n\nQuestion: {message}"
        start_time = time.time()  # timing includes a first-call model load
        self.load_model(model_name)
        config = MODEL_CONFIG[model_name]
        # Format the model-specific chat prompt.
        prompt = config["template"].format(message=message)
        pipe = self._get_pipeline(model_name)
        response = pipe(prompt)[0]['generated_text']
        elapsed_time = time.time() - start_time
        tokens = len(self.tokenizers[model_name].encode(response))
        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
        return response, elapsed_time, tokens_per_sec
# Single shared model manager; loaded model weights persist across requests.
model_handler = ChatModel()
def chat(message, history, model_choice):
    """Gradio chat callback: generate a reply and append the turn to history.

    Returns the FULL updated history (list of (user, assistant) tuples).
    The original returned only the newest pair, which replaced the entire
    Chatbot contents and wiped the visible conversation on every message.
    On any failure the exception text is shown as the assistant reply.
    """
    history = history or []  # gradio may pass None on a fresh session
    try:
        response, response_time, token_speed = model_handler.generate(message, model_choice, history)
        formatted_response = f"{response}\n\nโฑ๏ธ Response Time: {response_time:.2f}s | ๐Ÿš€ Speed: {token_speed:.2f} tokens/s"
        return history + [(message, formatted_response)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]
# --- Gradio UI ----------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐Ÿš€ LLM Chatbot with RAG & Performance Metrics")
    # Document upload section.
    # NOTE(review): the uploaded files are never passed to
    # process_documents(), so the RAG branch of ChatModel.generate is
    # unreachable from this UI — an upload handler needs to be wired up.
    with gr.Row():
        # NOTE(review): `max_size` does not look like a supported gr.File
        # argument — confirm against the installed Gradio version.
        file_output = gr.File(label="Upload Documents", file_count="multiple",
                              file_types=[".pdf", ".txt"], max_size=10)
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["phi-3", "llama3-8b"],
            label="Select Model",
            value="phi-3"
        )
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Type here...")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])
    # Both Enter in the textbox and the Send button route through chat();
    # note the textbox is not cleared after sending.
    msg.submit(chat, [msg, chatbot, model_choice], chatbot)
    submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)
demo.launch()