Wenye He committed
Commit 7167bd9 · verified · 1 Parent(s): 2499b05

Update app.py

Files changed (1)
  1. app.py +118 -152
app.py CHANGED
@@ -1,178 +1,144 @@
-# app.py
 import gradio as gr
 import torch
-import time
-import os
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
-from langchain_community.document_loaders import PyPDFLoader, TextLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.vectorstores import FAISS

-# Configuration
 MODEL_CONFIG = {
-    "phi-3": {
-        "model_name": "microsoft/phi-3-mini-4k-instruct",
-        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
     },
-    "llama3-8b": {
-        "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
-        "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
-{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
     }
 }

-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_use_double_quant=True
-)

-class ChatModel:
-    def __init__(self):
-        self.models = {}
-        self.tokenizers = {}
-        self.vectorstore = None
-
-    def load_model(self, model_name):
-        if model_name not in self.models:
-            config = MODEL_CONFIG[model_name]
-            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
-            tokenizer.pad_token = tokenizer.eos_token
-            model = AutoModelForCausalLM.from_pretrained(
-                config["model_name"],
-                quantization_config=bnb_config,
-                device_map="auto",
-                torch_dtype=torch.float16,
-            )
-            self.models[model_name] = model
-            self.tokenizers[model_name] = tokenizer
-
-    def process_documents(self, files, progress=gr.Progress()):
-        """Process uploaded documents into vector embeddings"""
-        try:
-            progress(0, desc="Starting document processing")
-            documents = []
-
-            # Load documents
-            for file_path in progress.tqdm(files, desc="Loading files"):
-                if file_path.endswith(".pdf"):
-                    loader = PyPDFLoader(file_path)
-                elif file_path.endswith(".txt"):
-                    loader = TextLoader(file_path)
-                else:
-                    continue
-                documents.extend(loader.load())
-
-            # Split documents
-            progress(0.3, desc="Processing documents")
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=512,
-                chunk_overlap=50
-            )
-            texts = text_splitter.split_documents(documents)
-
-            # Create embeddings
-            progress(0.6, desc="Generating embeddings")
-            embeddings = HuggingFaceEmbeddings(
-                model_name="BAAI/bge-small-en-v1.5"
-            )
-
-            # Create vector store
-            progress(0.8, desc="Building vector database")
-            self.vectorstore = FAISS.from_documents(texts, embeddings)
-
-            return "✅ Documents processed successfully! Ready for queries."
-
-        except Exception as e:
-            return f"❌ Error processing documents: {str(e)}"

-    def generate(self, message, model_name, history):
-        start_time = time.time()
-        self.load_model(model_name)
-        config = MODEL_CONFIG[model_name]
-
-        # Retrieve relevant context
-        context = ""
-        if self.vectorstore:
-            docs = self.vectorstore.similarity_search(message, k=3)
-            context = "\n\n".join([d.page_content for d in docs])

-        # Format prompt with context
-        prompt = config["template"].format(
-            message=f"Context:\n{context}\n\nQuestion: {message}"
         )

-        # Generate response
-        pipe = pipeline(
             "text-generation",
-            model=self.models[model_name],
-            tokenizer=self.tokenizers[model_name],
-            max_new_tokens=384,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.1,
-            do_sample=True,
-            return_full_text=False
         )
-
-        response = pipe(prompt)[0]['generated_text']
-
-        # Calculate metrics
-        elapsed_time = time.time() - start_time
-        tokens = len(self.tokenizers[model_name].encode(response))
-        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
-
-        return response, elapsed_time, tokens_per_sec

-# Initialize model handler
-model_handler = ChatModel()

-def chat(message, history, model_choice):
     try:
-        response, response_time, token_speed = model_handler.generate(message, model_choice, history)
-        formatted_response = f"{response}\n\n⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
-        return [(message, formatted_response)]
     except Exception as e:
-        return [(message, f"Error: {str(e)}")]

-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")

-    with gr.Row():
-        model_choice = gr.Dropdown(
-            choices=["phi-3", "llama3-8b"],
-            label="Select Model",
-            value="phi-3"
-        )

-    with gr.Row():
-        with gr.Column(scale=1):
-            file_upload = gr.File(
-                label="Upload Documents (PDF/TXT)",
-                file_count="multiple",
-                file_types=[".pdf", ".txt"],
-                type="filepath"
-            )
-            status = gr.Textbox(label="Processing Status", interactive=False)
-        with gr.Column(scale=3):
-            chatbot = gr.Chatbot(height=500)
-            msg = gr.Textbox(label="Message", placeholder="Type your question here...")

-    with gr.Row():
-        submit_btn = gr.Button("Send", variant="primary")
-        clear_btn = gr.ClearButton([msg, chatbot, file_upload])
-
-    # Event handlers
-    file_upload.upload(
-        fn=model_handler.process_documents,
-        inputs=file_upload,
-        outputs=status,
-        show_progress="full"
     )

-    msg.submit(chat, [msg, chatbot, model_choice], chatbot)
-    submit_btn.click(chat, [msg, chatbot, model_choice], chatbot)

-demo.launch()
 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from langchain.chains import ConversationalRetrievalChain
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.memory import ConversationBufferMemory
+from langchain.llms import HuggingFacePipeline
 import torch

+# Model Configuration
 MODEL_CONFIG = {
+    "phi-3-mini": {
+        "name": "microsoft/phi-3-mini-128k-instruct",
+        "max_tokens": 1024,
+        "temperature": 0.8
     },
+    "Mistral-7B": {
+        "name": "mistralai/Mistral-7B-Instruct-v0.3",
+        "max_tokens": 512,
+        "temperature": 0.7
     }
 }

+# Cache Stores
+vector_store_cache = {}
+model_pipeline_cache = {}
+embedder = HuggingFaceEmbeddings()

+def load_vector_store(store_name):
+    """Cache vector stores in memory"""
+    if store_name not in vector_store_cache:
+        vector_store_cache[store_name] = FAISS.load_local(
+            f"vector_stores/{store_name}",
+            embedder
+        )
+    return vector_store_cache[store_name]

+def get_model_pipeline(model_choice):
+    """Cache model pipelines in memory"""
+    if model_choice not in model_pipeline_cache:
+        cfg = MODEL_CONFIG[model_choice]
+
+        tokenizer = AutoTokenizer.from_pretrained(cfg["name"])
+        model = AutoModelForCausalLM.from_pretrained(
+            cfg["name"],
+            device_map="auto",
+            torch_dtype="auto" if "phi-3" in model_choice else torch.float16
         )
+
+        model_pipeline_cache[model_choice] = pipeline(
             "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=cfg["max_tokens"],
+            temperature=cfg["temperature"]
         )
+    return model_pipeline_cache[model_choice]

+class SessionChain:
+    """Per-session chain manager with memory"""
+    def __init__(self):
+        self.current_model = None
+        self.current_vector_store = None
+        self.chain = None
+
+    def get_chain(self, model_choice, vector_store_name):
+        """Get or create chain with proper configuration"""
+        if self.current_model != model_choice or self.current_vector_store != vector_store_name:
+            self._create_new_chain(model_choice, vector_store_name)
+        return self.chain
+
+    def _create_new_chain(self, model_choice, vector_store_name):
+        """Create new chain with updated configuration"""
+        vector_store = load_vector_store(vector_store_name)
+        pipe = get_model_pipeline(model_choice)
+
+        self.chain = ConversationalRetrievalChain.from_llm(
+            llm=HuggingFacePipeline(pipeline=pipe),
+            retriever=vector_store.as_retriever(),
+            memory=ConversationBufferMemory(),
+            verbose=False
+        )
+        self.current_model = model_choice
+        self.current_vector_store = vector_store_name

+def respond(message, history, model_choice, vector_store, session_state):
+    """Handle message with cached resources and session chain"""
+    # Initialize session chain if not exists
+    if session_state is None:
+        session_state = SessionChain()
+
+    # Get the appropriate chain for this session
+    chain = session_state.get_chain(model_choice, vector_store)
+
     try:
+        # Convert Gradio history to LangChain format
+        for human, ai in history[-5:]:  # Keep last 5 exchanges as memory
+            chain.memory.save_context({"input": human}, {"output": ai})
+
+        # Generate response
+        result = chain.invoke({"question": message})
+        response = result["answer"]
+
+        return "", history + [(message, response)], session_state
+
     except Exception as e:
+        return "", history + [(message, f"⚠️ Error: {str(e)}")], session_state

+with gr.Blocks() as demo:
+    gr.Markdown("# 🚀 Optimized Chat with Session Management")

+    # UI Components
+    model_dropdown = gr.Dropdown(
+        list(MODEL_CONFIG.keys()),
+        value="phi-3-mini",
+        label="Select Model"
+    )
+    vector_store_dropdown = gr.Dropdown(
+        ["legal_docs", "tech_docs"],
+        value="tech_docs",
+        label="Knowledge Base"
+    )

+    # Session state stored in the browser
+    session = gr.State()

+    chatbot = gr.Chatbot(height=400)
+    msg = gr.Textbox(label="Your Message")
+    clear = gr.Button("Clear History")
+
+    # Chat handlers
+    msg.submit(
+        respond,
+        [msg, chatbot, model_dropdown, vector_store_dropdown, session],
+        [msg, chatbot, session]
     )

+    clear.click(
+        lambda: ([], None),
+        [],
+        [chatbot, session]
+    )

+demo.launch()
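
Note: the updated app.py drops the in-app document upload path and instead loads prebuilt FAISS indexes from vector_stores/<name> (the dropdown offers legal_docs and tech_docs), using the same default HuggingFaceEmbeddings it instantiates at startup. The commit does not include an ingestion script, so the following is only a minimal sketch of how such an index could be produced, reusing the loaders and splitter from the removed version; the build_store helper and the sample file paths are hypothetical.

# build_vector_store.py: hypothetical helper, not part of this commit
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

def build_store(file_paths, store_name):
    # Load PDF and plain-text sources into LangChain documents
    documents = []
    for path in file_paths:
        if path.endswith(".pdf"):
            documents.extend(PyPDFLoader(path).load())
        elif path.endswith(".txt"):
            documents.extend(TextLoader(path).load())

    # Chunk the documents before embedding
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
    chunks = splitter.split_documents(documents)

    # Embed with the same default HuggingFaceEmbeddings the app uses,
    # then save to the directory that load_vector_store() reads from
    embedder = HuggingFaceEmbeddings()
    FAISS.from_documents(chunks, embedder).save_local(f"vector_stores/{store_name}")

if __name__ == "__main__":
    # Example: index two local files (illustrative paths) as the "tech_docs" knowledge base
    build_store(["docs/manual.pdf", "docs/notes.txt"], "tech_docs")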