import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import time  # response-time metrics
# LangChain components for the RAG pipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Build a FAISS vector store from uploaded PDF/TXT files.
def process_documents(files):
    documents = []
    for file in files:
        # Gradio may hand back plain filepaths or tempfile-like objects.
        path = file if isinstance(file, str) else file.name
        if path.endswith(".pdf"):
            loader = PyPDFLoader(path)
        elif path.endswith(".txt"):
            loader = TextLoader(path)
        else:
            continue  # skip unsupported file types
        documents.extend(loader.load())

    # Split into overlapping chunks so retrieval returns focused passages.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    # Embed each chunk and index the vectors for similarity search.
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    vectorstore = FAISS.from_documents(texts, embeddings)
    return vectorstore
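# Usage sketch (outside the app flow): index a local file and query it.
# "notes.pdf" is a hypothetical path standing in for a real document.
#
#   store = process_documents(["notes.pdf"])
#   hits = store.similarity_search("What does the introduction cover?", k=3)
#   print(hits[0].page_content)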


MODEL_CONFIG = {
    "phi-3": {
        "model_name": "microsoft/phi-3-mini-4k-instruct",
        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
    },
    "llama3-8b": {
        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    }
}
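# Sketch of how a template is applied (see generate() below): for phi-3,
# MODEL_CONFIG["phi-3"]["template"].format(message="Hi") yields
#   "<|user|>\nHi<|end|>\n<|assistant|>"
# i.e. the user turn wrapped in the model's chat markers, ready for the pipeline.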

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
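# Sizing note (rough estimate, an assumption rather than a documented figure):
# with 4-bit NF4 weights the 8B Llama-3 model needs on the order of 5-6 GB of
# GPU memory, versus roughly 16 GB in fp16.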

class ChatModel:
    """Lazily loads and caches one quantized model/tokenizer pair per config entry."""

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        
    def load_model(self, model_name):
        if model_name not in self.models:
            config = MODEL_CONFIG[model_name]
            
            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
            tokenizer.pad_token = tokenizer.eos_token
            
            model = AutoModelForCausalLM.from_pretrained(
                config["model_name"],
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            
            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer

    def generate(self, message, model_name, history, vectorstore=None):
        # RAG context retrieval
        if vectorstore:
            docs = vectorstore.similarity_search(message, k=3)
            context = "\n".join([d.page_content for d in docs])
            message = f"Context: {context}\n\nQuestion: {message}"

        self.load_model(model_name)  # one-off load/caching, excluded from timing
        config = MODEL_CONFIG[model_name]
        start_time = time.time()  # time only prompt formatting and generation
        
        # Format prompt
        prompt = config["template"].format(message=message)
        
        # Create pipeline
        pipe = pipeline(
            "text-generation",
            model=self.models[model_name],
            tokenizer=self.tokenizers[model_name],
            max_new_tokens=384,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            return_full_text=False
        )
        
        response = pipe(prompt)[0]['generated_text']
        
        # Calculate metrics
        elapsed_time = time.time() - start_time
        tokens = len(self.tokenizers[model_name].encode(response))
        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
        
        return response, elapsed_time, tokens_per_sec

model_handler = ChatModel()
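# Direct-call sketch (assumes the weights can be downloaded and a CUDA GPU is
# available for 4-bit loading):
#
#   text, secs, tps = model_handler.generate("Hello!", "phi-3", history=[])
#   print(f"{text}\n({secs:.1f}s, {tps:.1f} tok/s)")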

def chat(message, history, model_choice, vectorstore):
    try:
        response, response_time, token_speed = model_handler.generate(
            message, model_choice, history, vectorstore=vectorstore
        )
        formatted_response = (
            f"{response}\n\n"
            f"⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
        )
        # Append to the running history instead of replacing it.
        return history + [(message, formatted_response)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")
    
    # Document upload: index once per upload and keep the vector store in
    # per-session state so chat() can retrieve context from it.
    with gr.Row():
        file_output = gr.File(label="Upload Documents", file_count="multiple",
                              file_types=[".pdf", ".txt"])
    vectorstore_state = gr.State(None)
    file_output.upload(process_documents, file_output, vectorstore_state)
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["phi-3", "llama3-8b"],
            label="Select Model",
            value="phi-3"
        )
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Type here...")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])

    msg.submit(chat, [msg, chatbot, model_choice, vectorstore_state], chatbot)
    submit_btn.click(chat, [msg, chatbot, model_choice, vectorstore_state], chatbot)

demo.launch()