Wenye He committed on
Commit
262fcc6
·
verified ·
1 Parent(s): 78522bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -65
app.py CHANGED
@@ -1,43 +1,15 @@
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
3
  import torch
4
- import time # Added for timing
5
- # New imports
 
6
  from langchain_community.document_loaders import PyPDFLoader, TextLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain_community.vectorstores import FAISS
10
 
11
- # Document processing function
12
- def process_documents(files):
13
- """Process PDF/TXT files into vector embeddings"""
14
- documents = []
15
- for file_path in files:
16
- if file_path.endswith(".pdf"):
17
- loader = PyPDFLoader(file_path)
18
- elif file_path.endswith(".txt"):
19
- loader = TextLoader(file_path)
20
- else:
21
- continue
22
- documents.extend(loader.load())
23
-
24
- # Split documents into chunks
25
- text_splitter = RecursiveCharacterTextSplitter(
26
- chunk_size=512,
27
- chunk_overlap=50
28
- )
29
- texts = text_splitter.split_documents(documents)
30
-
31
- # Create embeddings
32
- embeddings = HuggingFaceEmbeddings(
33
- model_name="BAAI/bge-small-en-v1.5"
34
- )
35
-
36
- # Create vector store
37
- vectorstore = FAISS.from_documents(texts, embeddings)
38
- return vectorstore
39
-
40
-
41
  MODEL_CONFIG = {
42
  "phi-3": {
43
  "model_name": "microsoft/phi-3-mini-4k-instruct",
@@ -46,8 +18,7 @@ MODEL_CONFIG = {
46
  "llama3-8b": {
47
  "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
48
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
49
- {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
50
- """
51
  }
52
  }
53
 
@@ -62,33 +33,78 @@ class ChatModel:
62
  def __init__(self):
63
  self.models = {}
64
  self.tokenizers = {}
65
- self.vectorstore = None # Add vectorstore reference
66
 
67
- # Add this new method
68
- def update_vectorstore(self, files):
69
- """Process uploaded files and update vectorstore"""
70
- if files:
71
- self.vectorstore = process_documents(files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # Modify existing generate method
74
  def generate(self, message, model_name, history):
75
  start_time = time.time()
 
 
76
 
77
  # Retrieve relevant context
78
  context = ""
79
  if self.vectorstore:
80
  docs = self.vectorstore.similarity_search(message, k=3)
81
- context = "\n".join([d.page_content for d in docs])
82
 
83
- self.load_model(model_name)
84
- config = MODEL_CONFIG[model_name]
85
-
86
- # Update prompt with context
87
  prompt = config["template"].format(
88
- message=f"Context: {context}\n\nQuestion: {message}"
89
  )
90
 
91
- # Create pipeline
92
  pipe = pipeline(
93
  "text-generation",
94
  model=self.models[model_name],
@@ -110,6 +126,7 @@ class ChatModel:
110
 
111
  return response, elapsed_time, tokens_per_sec
112
 
 
113
  model_handler = ChatModel()
114
 
115
  def chat(message, history, model_choice):
@@ -123,33 +140,39 @@ def chat(message, history, model_choice):
123
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
124
  gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")
125
 
126
- # Add document upload section
127
- with gr.Row():
128
- file_upload = gr.File(
129
- label="Upload Documents (PDF/TXT)",
130
- file_count="multiple",
131
- file_types=[".pdf", ".txt"],
132
- type="filepath"
133
- )
134
  with gr.Row():
135
  model_choice = gr.Dropdown(
136
  choices=["phi-3", "llama3-8b"],
137
  label="Select Model",
138
  value="phi-3"
139
  )
140
- chatbot = gr.Chatbot(height=400)
141
- msg = gr.Textbox(label="Message", placeholder="Type here...")
 
 
 
 
 
 
 
 
 
 
 
 
142
  with gr.Row():
143
  submit_btn = gr.Button("Send", variant="primary")
144
- clear_btn = gr.ClearButton([msg, chatbot])
145
-
146
- msg.submit(chat, [msg, chatbot, model_choice], chatbot)
147
- submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)
148
 
 
149
  file_upload.upload(
150
- fn=model_handler.update_vectorstore,
151
  inputs=file_upload,
152
- outputs=None
 
153
  )
 
 
 
154
 
155
  demo.launch()
 
1
+ # app.py
2
  import gradio as gr
 
3
  import torch
4
+ import time
5
+ import os
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
7
  from langchain_community.document_loaders import PyPDFLoader, TextLoader
8
  from langchain_text_splitters import RecursiveCharacterTextSplitter
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import FAISS
11
 
12
+ # Configuration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  MODEL_CONFIG = {
14
  "phi-3": {
15
  "model_name": "microsoft/phi-3-mini-4k-instruct",
 
18
  "llama3-8b": {
19
  "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
20
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
21
+ {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
 
22
  }
23
  }
24
 
 
33
  def __init__(self):
34
  self.models = {}
35
  self.tokenizers = {}
36
+ self.vectorstore = None
37
 
38
+ def load_model(self, model_name):
39
+ if model_name not in self.models:
40
+ config = MODEL_CONFIG[model_name]
41
+ tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
42
+ tokenizer.pad_token = tokenizer.eos_token
43
+ model = AutoModelForCausalLM.from_pretrained(
44
+ config["model_name"],
45
+ quantization_config=bnb_config,
46
+ device_map="auto",
47
+ torch_dtype=torch.float16,
48
+ )
49
+ self.models[model_name] = model
50
+ self.tokenizers[model_name] = tokenizer
51
+
52
+ def process_documents(self, files, progress=gr.Progress()):
53
+ """Process uploaded documents into vector embeddings"""
54
+ try:
55
+ progress(0, desc="Starting document processing")
56
+ documents = []
57
+
58
+ # Load documents
59
+ for file_path in progress.tqdm(files, desc="Loading files"):
60
+ if file_path.endswith(".pdf"):
61
+ loader = PyPDFLoader(file_path)
62
+ elif file_path.endswith(".txt"):
63
+ loader = TextLoader(file_path)
64
+ else:
65
+ continue
66
+ documents.extend(loader.load())
67
+
68
+ # Split documents
69
+ progress(0.3, desc="Processing documents")
70
+ text_splitter = RecursiveCharacterTextSplitter(
71
+ chunk_size=512,
72
+ chunk_overlap=50
73
+ )
74
+ texts = text_splitter.split_documents(documents)
75
+
76
+ # Create embeddings
77
+ progress(0.6, desc="Generating embeddings")
78
+ embeddings = HuggingFaceEmbeddings(
79
+ model_name="BAAI/bge-small-en-v1.5"
80
+ )
81
+
82
+ # Create vector store
83
+ progress(0.8, desc="Building vector database")
84
+ self.vectorstore = FAISS.from_documents(texts, embeddings)
85
+
86
+ return "✅ Documents processed successfully! Ready for queries."
87
+
88
+ except Exception as e:
89
+ return f"❌ Error processing documents: {str(e)}"
90
 
 
91
  def generate(self, message, model_name, history):
92
  start_time = time.time()
93
+ self.load_model(model_name)
94
+ config = MODEL_CONFIG[model_name]
95
 
96
  # Retrieve relevant context
97
  context = ""
98
  if self.vectorstore:
99
  docs = self.vectorstore.similarity_search(message, k=3)
100
+ context = "\n\n".join([d.page_content for d in docs])
101
 
102
+ # Format prompt with context
 
 
 
103
  prompt = config["template"].format(
104
+ message=f"Context:\n{context}\n\nQuestion: {message}"
105
  )
106
 
107
+ # Generate response
108
  pipe = pipeline(
109
  "text-generation",
110
  model=self.models[model_name],
 
126
 
127
  return response, elapsed_time, tokens_per_sec
128
 
129
+ # Initialize model handler
130
  model_handler = ChatModel()
131
 
132
  def chat(message, history, model_choice):
 
140
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
141
  gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")
142
 
 
 
 
 
 
 
 
 
143
  with gr.Row():
144
  model_choice = gr.Dropdown(
145
  choices=["phi-3", "llama3-8b"],
146
  label="Select Model",
147
  value="phi-3"
148
  )
149
+
150
+ with gr.Row():
151
+ with gr.Column(scale=1):
152
+ file_upload = gr.File(
153
+ label="Upload Documents (PDF/TXT)",
154
+ file_count="multiple",
155
+ file_types=[".pdf", ".txt"],
156
+ type="filepath"
157
+ )
158
+ status = gr.Textbox(label="Processing Status", interactive=False)
159
+ with gr.Column(scale=3):
160
+ chatbot = gr.Chatbot(height=500)
161
+ msg = gr.Textbox(label="Message", placeholder="Type your question here...")
162
+
163
  with gr.Row():
164
  submit_btn = gr.Button("Send", variant="primary")
165
+ clear_btn = gr.ClearButton([msg, chatbot, file_upload])
 
 
 
166
 
167
+ # Event handlers
168
  file_upload.upload(
169
+ fn=model_handler.process_documents,
170
  inputs=file_upload,
171
+ outputs=status,
172
+ show_progress="full"
173
  )
174
+
175
+ msg.submit(chat, [msg, chatbot, model_choice], chatbot)
176
+ submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)
177
 
178
  demo.launch()