Wenye He committed on
Commit 68053b5 · verified · 1 Parent(s): 46a9cde

Update app.py

Files changed (1)
  1. app.py +177 -120
app.py CHANGED
@@ -1,144 +1,201 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from langchain.chains import ConversationalRetrievalChain
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import FAISS
- from langchain.memory import ConversationBufferMemory
- from langchain.llms import HuggingFacePipeline
  import torch

- # Model Configuration
  MODEL_CONFIG = {
-     "phi-3-mini": {
-         "name": "microsoft/phi-3-mini-4k-instruct",
-         "max_tokens": 1024,
-         "temperature": 0.8
      },
-     "Mistral-7B": {
-         "name": "mistralai/Mistral-7B-Instruct-v0.3",
-         "max_tokens": 512,
-         "temperature": 0.7
      }
  }

- # Cache Stores
- vector_store_cache = {}
- model_pipeline_cache = {}
- embedder = HuggingFaceEmbeddings()
-
- def load_vector_store(store_name):
-     """Cache vector stores in memory"""
-     if store_name not in vector_store_cache:
-         vector_store_cache[store_name] = FAISS.load_local(
-             f"vector_store/{store_name}",
-             embedder
-         )
-     return vector_store_cache[store_name]

- def get_model_pipeline(model_choice):
-     """Cache model pipelines in memory"""
-     if model_choice not in model_pipeline_cache:
-         cfg = MODEL_CONFIG[model_choice]

-         tokenizer = AutoTokenizer.from_pretrained(cfg["name"])
-         model = AutoModelForCausalLM.from_pretrained(
-             cfg["name"],
-             device_map="auto",
-             torch_dtype="auto" if "phi-3" in model_choice else torch.float16
-         )

-         model_pipeline_cache[model_choice] = pipeline(
-             "text-generation",
-             model=model,
-             tokenizer=tokenizer,
-             max_new_tokens=cfg["max_tokens"],
-             temperature=cfg["temperature"]
-         )
-     return model_pipeline_cache[model_choice]

- class SessionChain:
-     """Per-session chain manager with memory"""
-     def __init__(self):
-         self.current_model = None
-         self.current_vector_store = None
-         self.chain = None
-
-     def get_chain(self, model_choice, vector_store_name):
-         """Get or create chain with proper configuration"""
-         if self.current_model != model_choice or self.current_vector_store != vector_store_name:
-             self._create_new_chain(model_choice, vector_store_name)
-         return self.chain
-
-     def _create_new_chain(self, model_choice, vector_store_name):
-         """Create new chain with updated configuration"""
-         vector_store = load_vector_store(vector_store_name)
-         pipe = get_model_pipeline(model_choice)

-         self.chain = ConversationalRetrievalChain.from_llm(
-             llm=HuggingFacePipeline(pipeline=pipe),
-             retriever=vector_store.as_retriever(),
-             memory=ConversationBufferMemory(),
-             verbose=False
          )
-         self.current_model = model_choice
-         self.current_vector_store = vector_store_name
-
- def respond(message, history, model_choice, vector_store, session_state):
-     """Handle message with cached resources and session chain"""
-     # Initialize session chain if not exists
-     if session_state is None:
-         session_state = SessionChain()
-
-     # Get the appropriate chain for this session
-     chain = session_state.get_chain(model_choice, vector_store)
-
-     try:
-         # Convert Gradio history to LangChain format
-         for human, ai in history[-5:]:  # Keep last 5 exchanges as memory
-             chain.memory.save_context({"input": human}, {"output": ai})

          # Generate response
-         result = chain.invoke({"question": message})
-         response = result["answer"]

-         return "", history + [(message, response)], session_state
-
      except Exception as e:
-         return "", history + [(message, f"⚠️ Error: {str(e)}")], session_state

- with gr.Blocks() as demo:
-     gr.Markdown("# 🚀 Optimized Chat with Session Management")

-     # UI Components
-     model_dropdown = gr.Dropdown(
-         list(MODEL_CONFIG.keys()),
-         value="phi-3-mini",
-         label="Select Model"
-     )
-     vector_store_dropdown = gr.Dropdown(
-         ["llm", "scoliosis"],
-         value="scoliosis",
-         label="Knowledge Base"
-     )
-
-     # Session state stored in the browser
-     session = gr.State()
-
-     chatbot = gr.Chatbot(height=400)
-     msg = gr.Textbox(label="Your Message")
-     clear = gr.Button("Clear History")

-     # Chat handlers
-     msg.submit(
-         respond,
-         [msg, chatbot, model_dropdown, vector_store_dropdown, session],
-         [msg, chatbot, session]
      )

-     clear.click(
-         lambda: ([], None),
-         [],
-         [chatbot, session]
-     )

- demo.launch()
+ # app.py
  import gradio as gr
  import torch
+ import time
+ import os
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS

+ # Configuration
  MODEL_CONFIG = {
+     "phi-3": {
+         "model_name": "microsoft/phi-3-mini-4k-instruct",
+         "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
      },
+     "llama3-8b": {
+         "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
+         "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+ {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
      }
  }
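+ # Each template wraps a single user turn in the model's own chat markup
+ # (Phi-3's <|user|>/<|end|> tags, Llama-3's header tokens), since the raw
+ # text-generation pipeline used below does not apply a chat template itself.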

+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_use_double_quant=True
+ )
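+ # The config above stores weights as 4-bit NormalFloat (NF4) with float16
+ # compute and double quantization, cutting weight memory to roughly a quarter
+ # of plain float16; this is what lets an 8B model fit on a single modest GPU.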
+
+ class ChatModel:
+     def __init__(self):
+         self.models = {}
+         self.tokenizers = {}
+         self.vectorstore = {}  # store_name -> FAISS index cache
+
+     def load_model(self, model_name):
+         if model_name not in self.models:
+             config = MODEL_CONFIG[model_name]
+             tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+             # These instruct tokenizers ship without a pad token; reuse EOS
+             tokenizer.pad_token = tokenizer.eos_token
+             model = AutoModelForCausalLM.from_pretrained(
+                 config["model_name"],
+                 quantization_config=bnb_config,
+                 device_map="auto",
+                 torch_dtype=torch.float16,
+             )
+             self.models[model_name] = model
+             self.tokenizers[model_name] = tokenizer
+
+     def load_vector_store(self, store_name):
+         """Cache vector stores in memory"""
+         if store_name not in self.vectorstore:
+             embeddings = HuggingFaceEmbeddings(
+                 model_name="BAAI/bge-small-en-v1.5"
+             )
+             self.vectorstore[store_name] = FAISS.load_local(
+                 f"vector_stores/{store_name}",
+                 embeddings
+             )
+         return self.vectorstore[store_name]
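+     # load_local() assumes prebuilt indices on disk under vector_stores/<name>/
+     # (here "llm" and "scoliosis"). A minimal sketch of how such an index could
+     # be built offline (assumed; the build script is not part of this file):
+     #
+     #     docs = PyPDFLoader("scoliosis.pdf").load()
+     #     store = FAISS.from_documents(docs, HuggingFaceEmbeddings(
+     #         model_name="BAAI/bge-small-en-v1.5"))
+     #     store.save_local("vector_stores/scoliosis")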
+
+     def process_documents(self, files, progress=gr.Progress()):
+         """Process uploaded documents into vector embeddings"""
+         try:
+             progress(0, desc="Starting document processing")
+             documents = []
+
+             # Load documents
+             for file_path in progress.tqdm(files, desc="Loading files"):
+                 if file_path.endswith(".pdf"):
+                     loader = PyPDFLoader(file_path)
+                 elif file_path.endswith(".txt"):
+                     loader = TextLoader(file_path)
+                 else:
+                     continue
+                 documents.extend(loader.load())
+
+             # Split documents
+             progress(0.3, desc="Processing documents")
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=512,
+                 chunk_overlap=50
+             )
+             texts = text_splitter.split_documents(documents)
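+             # 512-character chunks with 50 characters of overlap: chunks stay
+             # well within bge-small's input window, and the overlap keeps
+             # sentences that straddle a boundary visible to both neighbors.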
+
+             # Create embeddings
+             progress(0.6, desc="Generating embeddings")
+             embeddings = HuggingFaceEmbeddings(
+                 model_name="BAAI/bge-small-en-v1.5"
+             )
+
+             # Create vector store and cache it under a fixed key
+             # ("uploaded" is an assumed name; the dropdown does not list it yet)
+             progress(0.8, desc="Building vector database")
+             self.vectorstore["uploaded"] = FAISS.from_documents(texts, embeddings)
+
+             return "✅ Documents processed successfully! Ready for queries."

+         except Exception as e:
+             return f"❌ Error processing documents: {str(e)}"

+     def generate(self, message, model_name, vector_store_name, history):
+         start_time = time.time()
+         self.load_model(model_name)
+         self.load_vector_store(vector_store_name)
+         config = MODEL_CONFIG[model_name]
+
+         # Retrieve relevant context
+         context = ""
+         if self.vectorstore[vector_store_name]:
+             docs = self.vectorstore[vector_store_name].similarity_search(message, k=3)
+             context = "\n\n".join([d.page_content for d in docs])
+
+         # Format prompt with context
+         prompt = config["template"].format(
+             message=f"Context:\n{context}\n\nQuestion: {message}"
          )

          # Generate response
+         pipe = pipeline(
+             "text-generation",
+             model=self.models[model_name],
+             tokenizer=self.tokenizers[model_name],
+             max_new_tokens=384,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.1,
+             do_sample=True,
+             return_full_text=False
+         )
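+         # Building the pipeline wrapper here is cheap: the heavy model and
+         # tokenizer objects come from the caches populated by load_model().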

+         response = pipe(prompt)[0]['generated_text']
+
+         # Calculate metrics (end to end: the timer also covers model/index
+         # loading and retrieval, so this understates pure decode speed)
+         elapsed_time = time.time() - start_time
+         tokens = len(self.tokenizers[model_name].encode(response))
+         tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
+
+         return response, elapsed_time, tokens_per_sec
+
+ # Initialize model handler
+ model_handler = ChatModel()
+
+ def chat(message, history, model_choice, vector_store_choice):
+     try:
+         response, response_time, token_speed = model_handler.generate(
+             message, model_choice, vector_store_choice, history
+         )
+         formatted_response = (
+             f"{response}\n\n⏱️ Response Time: {response_time:.2f}s"
+             f" | 🚀 Speed: {token_speed:.2f} tokens/s"
+         )
+         # Append to the running history rather than replacing it
+         return history + [(message, formatted_response)]
      except Exception as e:
+         return history + [(message, f"Error: {str(e)}")]

+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")

+     with gr.Row():
+         model_choice = gr.Dropdown(
+             choices=["phi-3", "llama3-8b"],
+             label="Select Model",
+             value="phi-3"
+         )
+         vector_store_choice = gr.Dropdown(
+             ["llm", "scoliosis"],
+             value="scoliosis",
+             label="Knowledge Base"
+         )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             file_upload = gr.File(
+                 label="Upload Documents (PDF/TXT)",
+                 file_count="multiple",
+                 file_types=[".pdf", ".txt"],
+                 type="filepath"
+             )
+             status = gr.Textbox(label="Processing Status", interactive=False)
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(height=500)
+             msg = gr.Textbox(label="Message", placeholder="Type your question here...")

+     with gr.Row():
+         submit_btn = gr.Button("Send", variant="primary")
+         clear_btn = gr.ClearButton([msg, chatbot, file_upload])
+
+     # Event handlers
+     file_upload.upload(
+         fn=model_handler.process_documents,
+         inputs=file_upload,
+         outputs=status,
+         show_progress="full"
      )

+     msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
+     submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)

+ demo.launch()
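+ # Assumed runtime dependencies for the imports above (not pinned in this file):
+ # gradio, torch, transformers, accelerate, bitsandbytes, langchain-community,
+ # langchain-text-splitters, faiss-cpu (or faiss-gpu), sentence-transformers, pypdf.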