import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import time  # Added for timing

# New imports for RAG
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Document processing function: build a FAISS index from the uploaded PDF/TXT files
def process_documents(files):
    if not files:  # Nothing uploaded (or uploads cleared)
        return None
    documents = []
    for file in files:
        if file.name.endswith(".pdf"):
            loader = PyPDFLoader(file.name)
        elif file.name.endswith(".txt"):
            loader = TextLoader(file.name)
        else:
            continue  # Skip unsupported types instead of reusing a stale loader
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    vectorstore = FAISS.from_documents(texts, embeddings)
    return vectorstore
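# Optional: the index can be persisted between restarts instead of being rebuilt
# on every upload. A minimal sketch, assuming this langchain_community version
# exposes FAISS.save_local/load_local (the "faiss_index" path is illustrative):
#   vectorstore.save_local("faiss_index")
#   vectorstore = FAISS.load_local("faiss_index", embeddings,
#                                  allow_dangerous_deserialization=True)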
MODEL_CONFIG = {
    "phi-3": {
        "model_name": "microsoft/Phi-3-mini-4k-instruct",
        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
    },
    "llama3-8b": {
        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
        # Llama 3 expects a blank line after each header tag
        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    }
}
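# For reference, MODEL_CONFIG["phi-3"]["template"].format(message="Hello") renders as:
#   <|user|>
#   Hello<|end|>
#   <|assistant|>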
# 4-bit NF4 quantization so both models fit on a single consumer GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
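# Back-of-the-envelope memory: at 4 bits per weight, an 8B-parameter model needs
# roughly 8e9 * 0.5 bytes ≈ 4 GB for weights, versus ~16 GB in float16; double
# quantization shaves a further ~0.4 bits/param off the quantization constants.
# These are estimates only; real usage also includes the KV cache and activations.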
class ChatModel:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}

    def load_model(self, model_name):
        # Lazily load and cache each model/tokenizer pair on first use
        if model_name not in self.models:
            config = MODEL_CONFIG[model_name]
            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
            tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                config["model_name"],
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
    def generate(self, message, model_name, history, vectorstore=None):
        # RAG context retrieval: prepend the top-k most similar chunks
        if vectorstore:
            docs = vectorstore.similarity_search(message, k=3)
            context = "\n".join([d.page_content for d in docs])
            message = f"Context: {context}\n\nQuestion: {message}"

        start_time = time.time()  # Start timing
        self.load_model(model_name)
        config = MODEL_CONFIG[model_name]

        # Format prompt with the model-specific chat template
        prompt = config["template"].format(message=message)

        # Create pipeline
        pipe = pipeline(
            "text-generation",
            model=self.models[model_name],
            tokenizer=self.tokenizers[model_name],
            max_new_tokens=384,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            return_full_text=False
        )
        response = pipe(prompt)[0]["generated_text"]

        # Calculate metrics
        elapsed_time = time.time() - start_time
        tokens = len(self.tokenizers[model_name].encode(response))
        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
        return response, elapsed_time, tokens_per_sec
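# Illustrative direct call (note that the very first invocation also pays the
# model download/load cost, so its reported time and tokens/s understate
# steady-state performance):
#   handler = ChatModel()
#   text, secs, tps = handler.generate("What is RAG?", "phi-3", history=[])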
model_handler = ChatModel()

def chat(message, history, model_choice, vectorstore):
    try:
        response, response_time, token_speed = model_handler.generate(
            message, model_choice, history, vectorstore
        )
        formatted_response = (
            f"{response}\n\n"
            f"⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
        )
        return history + [(message, formatted_response)]  # Append rather than overwrite history
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 LLM Chatbot with RAG & Performance Metrics")

    # Holds the FAISS index built from the user's uploads
    vectorstore_state = gr.State(None)

    # Document upload section
    with gr.Row():
        file_output = gr.File(label="Upload Documents", file_count="multiple",
                              file_types=[".pdf", ".txt"])

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["phi-3", "llama3-8b"],
            label="Select Model",
            value="phi-3"
        )

    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Type here...")

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])

    # Rebuild the vector store whenever the uploaded files change
    file_output.change(process_documents, inputs=file_output, outputs=vectorstore_state)

    msg.submit(chat, [msg, chatbot, model_choice, vectorstore_state], chatbot)
    submit_btn.click(chat, [msg, chatbot, model_choice, vectorstore_state], chatbot)

demo.launch()
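# On Spaces or any multi-user deployment, enabling Gradio's request queue before
# launching serializes GPU-bound requests (a common pattern, not required here):
#   demo.queue().launch()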