admin08077 commited on
Commit
96e74be
Β·
verified Β·
1 Parent(s): 9e910c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -91
app.py CHANGED
@@ -2,18 +2,24 @@ import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import json
4
  import os
 
 
 
 
 
 
5
 
6
- """
7
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
8
- """
9
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
10
 
11
  # Persistent memory and knowledge base setup
12
  memory_file = "chat_memory.json"
13
- knowledge_base = {
14
- "AI": "Artificial Intelligence is a branch of computer science that focuses on creating systems capable of performing tasks that typically require human intelligence.",
15
- "Quantum Computing": "Quantum computing is a type of computation that uses quantum mechanics to process information in ways classical computers cannot.",
16
- }
 
17
 
18
  # Load memory from file
19
  def load_memory():
@@ -25,109 +31,223 @@ def load_memory():
25
  # Save memory to file
26
  def save_memory(memory):
27
  with open(memory_file, "w") as f:
28
- json.dump(memory, f)
29
 
30
  # Append to memory
31
  def update_memory(conversation):
32
  memory = load_memory()
33
  memory.append(conversation)
 
 
 
34
  save_memory(memory)
35
 
36
- # Response generation with memory and knowledge base integration
37
- def respond(
38
- message,
39
- history: list[tuple[str, str]],
40
- system_message,
41
- max_tokens,
42
- temperature,
43
- top_p,
44
- ):
45
- messages = [{"role": "system", "content": system_message}]
46
-
47
- for val in history:
48
- if val[0]:
49
- messages.append({"role": "user", "content": val[0]})
50
- if val[1]:
51
- messages.append({"role": "assistant", "content": val[1]})
52
-
53
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # Check for answers in the knowledge base
56
- if message in knowledge_base:
57
- response = knowledge_base[message]
58
- update_memory((message, response))
59
- yield response
60
- return
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- # Generate response from AI
63
  response = ""
64
- for message in client.chat_completion(
65
- messages,
66
- max_tokens=max_tokens,
67
- stream=True,
68
- temperature=temperature,
69
- top_p=top_p,
70
- ):
71
- token = message.choices[0].delta.content
72
- response += token
73
- yield response
 
 
 
 
74
 
75
  # Update memory
76
- update_memory((message, response))
77
-
78
-
79
- # Gradio interface with enhanced functionality
80
- def add_to_knowledge_base(key, value):
81
- knowledge_base[key] = value
82
- return f"Added to knowledge base: {key} -> {value}"
83
-
84
-
85
- demo = gr.Blocks()
86
-
87
- with demo:
88
- gr.Markdown("# Advanced Chatbot with Memory and Knowledge Base")
89
-
90
- with gr.Tab("Chat"):
91
- chatbot = gr.ChatInterface(
92
- respond,
93
- additional_inputs=[
94
- gr.Textbox(value="You are an advanced AI Chatbot.", label="System message"),
95
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
96
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
97
- gr.Slider(
98
- minimum=0.1,
99
- maximum=1.0,
100
- value=0.95,
101
- step=0.05,
102
- label="Top-p (nucleus sampling)",
103
- ),
104
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  )
106
 
107
- with gr.Tab("Knowledge Base"):
108
- gr.Markdown("### Manage the Knowledge Base")
109
- kb_key = gr.Textbox(label="Key", placeholder="Enter the topic or question")
110
- kb_value = gr.Textbox(label="Value", placeholder="Enter the explanation or answer")
111
- add_kb_button = gr.Button("Add to Knowledge Base")
112
- kb_output = gr.Textbox(label="Knowledge Base Output")
113
-
114
- add_kb_button.click(add_to_knowledge_base, [kb_key, kb_value], kb_output)
115
-
116
- with gr.Tab("Memory"):
117
- gr.Markdown("### Conversation Memory")
118
- memory_display = gr.Textbox(label="Conversation Memory", lines=10)
119
- refresh_memory = gr.Button("Refresh Memory")
120
- clear_memory = gr.Button("Clear Memory")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def display_memory():
123
- return json.dumps(load_memory(), indent=2)
124
 
125
  def clear_memory_func():
126
  save_memory([])
127
- return "Memory Cleared!"
128
-
129
- refresh_memory.click(display_memory, outputs=memory_display)
130
- clear_memory.click(clear_memory_func, outputs=memory_display)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  if __name__ == "__main__":
133
  demo.launch()
 
2
  from huggingface_hub import InferenceClient
3
  import json
4
  import os
5
+ import shutil
6
+ import pandas as pd
7
+ from sklearn.feature_extraction.text import CountVectorizer
8
+ from sklearn.ensemble import RandomForestClassifier
9
+ from sklearn.pipeline import Pipeline
10
+ import joblib
11
 
12
+ # Initialize the HuggingFace API Client with a valid model
13
+ # Replace 'gpt-3.5-turbo' with your desired model if different
14
+ client = InferenceClient("gpt-3.5-turbo")
 
15
 
16
  # Persistent memory and knowledge base setup
17
  memory_file = "chat_memory.json"
18
+ knowledge_base_dir = "knowledge_base"
19
+ model_file = "chat_model.pkl"
20
+
21
+ # Ensure directories exist
22
+ os.makedirs(knowledge_base_dir, exist_ok=True)
23
 
24
  # Load memory from file
25
  def load_memory():
 
31
  # Save memory to file
32
  def save_memory(memory):
33
  with open(memory_file, "w") as f:
34
+ json.dump(memory, f, indent=2)
35
 
36
  # Append to memory
37
  def update_memory(conversation):
38
  memory = load_memory()
39
  memory.append(conversation)
40
+ # Optionally limit memory size
41
+ if len(memory) > 1000:
42
+ memory = memory[-1000:]
43
  save_memory(memory)
44
 
45
+ # Load or initialize the ML model
46
+ def load_or_initialize_model():
47
+ if os.path.exists(model_file):
48
+ return joblib.load(model_file)
49
+ return Pipeline([
50
+ ("vectorizer", CountVectorizer()),
51
+ ("classifier", RandomForestClassifier(n_estimators=100, random_state=42))
52
+ ])
53
+
54
+ # Retrain model on files in the knowledge base
55
+ def train_model_on_files():
56
+ model = load_or_initialize_model()
57
+ texts, labels = [], []
58
+
59
+ # Load data from the knowledge base
60
+ for file_name in os.listdir(knowledge_base_dir):
61
+ file_path = os.path.join(knowledge_base_dir, file_name)
62
+ if file_path.endswith(".csv"):
63
+ try:
64
+ df = pd.read_csv(file_path)
65
+ if "text" in df.columns and "label" in df.columns:
66
+ texts.extend(df["text"].astype(str).tolist())
67
+ labels.extend(df["label"].astype(str).tolist())
68
+ else:
69
+ return f"File '{file_name}' does not contain required 'text' and 'label' columns."
70
+ except Exception as e:
71
+ return f"Error reading '{file_name}': {str(e)}"
72
+
73
+ if texts and labels:
74
+ try:
75
+ model.fit(texts, labels)
76
+ joblib.dump(model, model_file)
77
+ return f"Model trained on {len(texts)} samples from {len(os.listdir(knowledge_base_dir))} files."
78
+ except Exception as e:
79
+ return f"Error during model training: {str(e)}"
80
+ return "No valid training data found in the knowledge base."
81
+
82
+ # Chat response function
83
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
84
+ # Load or initialize model
85
+ model = load_or_initialize_model()
86
 
87
  # Check for answers in the knowledge base
88
+ try:
89
+ pred_label = model.predict([message])[0]
90
+ response = f"Predicted response: {pred_label}"
91
+ update_memory({"user": message, "assistant": response})
92
+ return response
93
+ except Exception:
94
+ pass # Continue with GPT model if ML model doesn't have a response
95
+
96
+ # Generate response using GPT
97
+ messages = [{"role": "system", "content": system_message}]
98
+ for turn in history:
99
+ if turn["user"]:
100
+ messages.append({"role": "user", "content": turn["user"]})
101
+ if turn["assistant"]:
102
+ messages.append({"role": "assistant", "content": turn["assistant"]})
103
+ messages.append({"role": "user", "content": message})
104
 
 
105
  response = ""
106
+ try:
107
+ for message_part in client.chat_completion(
108
+ messages,
109
+ max_tokens=max_tokens,
110
+ stream=True,
111
+ temperature=temperature,
112
+ top_p=top_p,
113
+ ):
114
+ token = message_part.get("choices", [{}])[0].get("delta", {}).get("content", "")
115
+ response += token
116
+ except Exception as e:
117
+ response = f"Error generating response: {str(e)}"
118
+ update_memory({"user": message, "assistant": response})
119
+ return response
120
 
121
  # Update memory
122
+ update_memory({"user": message, "assistant": response})
123
+ return response
124
+
125
+ # Gradio interface
126
+ with gr.Blocks() as demo:
127
+ gr.Markdown("# 🧠 Advanced AI Chatbot with Knowledge Base and Model Training")
128
+
129
+ with gr.Tab("πŸ’¬ Chat"):
130
+ chatbot = gr.Chatbot(label="AI Chatbot").style(height=600)
131
+ with gr.Row():
132
+ with gr.Column(scale=0.85):
133
+ user_input = gr.Textbox(
134
+ label="Your Message",
135
+ placeholder="Type your message here...",
136
+ )
137
+ with gr.Column(scale=0.15, min_width=100):
138
+ send_button = gr.Button("Send", variant="primary")
139
+ with gr.Row():
140
+ system_message = gr.Textbox(
141
+ value="You are an advanced AI Chatbot.",
142
+ label="System Message",
143
+ visible=False # Hidden if default system message is used
144
+ )
145
+ max_tokens = gr.Slider(
146
+ minimum=100, maximum=2048, value=512, step=100, label="Max Tokens"
147
+ )
148
+ temperature = gr.Slider(
149
+ minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
150
+ )
151
+ top_p = gr.Slider(
152
+ minimum=0.1,
153
+ maximum=1.0,
154
+ value=0.95,
155
+ step=0.05,
156
+ label="Top-p (Nucleus Sampling)",
157
+ )
158
+
159
+ def handle_message(message, history, system_message, max_tokens, temperature, top_p):
160
+ response = respond(message, history, system_message, max_tokens, temperature, top_p)
161
+ history.append({"user": message, "assistant": response})
162
+ return history, history
163
+
164
+ send_button.click(
165
+ handle_message,
166
+ inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
167
+ outputs=[chatbot, chatbot],
168
+ )
169
+ user_input.submit(
170
+ handle_message,
171
+ inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
172
+ outputs=[chatbot, chatbot],
173
  )
174
 
175
+ with gr.Tab("πŸ“š Knowledge Base"):
176
+ gr.Markdown("### Manage Knowledge Base")
177
+ file_upload = gr.File(
178
+ label="Upload CSV File",
179
+ file_types=[".csv"],
180
+ multiple=False,
181
+ interactive=True,
182
+ )
183
+ upload_output = gr.Textbox(label="Upload Result", interactive=False)
184
+ train_button = gr.Button("πŸ”„ Train Model on Knowledge Base")
185
+ train_output = gr.Textbox(label="Training Result", interactive=False)
186
+
187
+ def upload_file(file):
188
+ if file is None:
189
+ return "No file uploaded."
190
+ try:
191
+ # Validate file extension
192
+ if not file.name.endswith(".csv"):
193
+ return "Invalid file type. Please upload a CSV file."
194
+ # Save file to knowledge base directory
195
+ destination_path = os.path.join(knowledge_base_dir, file.name)
196
+ with open(destination_path, "wb") as f:
197
+ f.write(file.read())
198
+ return f"File '{file.name}' uploaded successfully."
199
+ except Exception as e:
200
+ return f"Error uploading file: {str(e)}"
201
+
202
+ file_upload.change(upload_file, inputs=file_upload, outputs=upload_output)
203
+ train_button.click(train_model_on_files, inputs=None, outputs=train_output)
204
+
205
+ with gr.Tab("🧠 Memory"):
206
+ gr.Markdown("### View and Manage Conversation Memory")
207
+ memory_display = gr.JSON(label="Conversation Memory", interactive=False)
208
+ with gr.Row():
209
+ refresh_memory = gr.Button("πŸ”„ Refresh Memory")
210
+ clear_memory = gr.Button("πŸ—‘οΈ Clear Memory")
211
+ export_memory = gr.Button("πŸ“€ Export Memory")
212
+ export_output = gr.File(label="Download Memory", visible=False)
213
 
214
  def display_memory():
215
+ return load_memory()
216
 
217
  def clear_memory_func():
218
  save_memory([])
219
+ return []
220
+
221
+ def export_memory_func():
222
+ if os.path.exists(memory_file):
223
+ return memory_file
224
+ return None
225
+
226
+ refresh_memory.click(display_memory, inputs=None, outputs=memory_display)
227
+ clear_memory.click(clear_memory_func, inputs=None, outputs=memory_display)
228
+ export_memory.click(export_memory_func, inputs=None, outputs=export_output)
229
+
230
+ with gr.Tab("πŸ’Ύ Download Model"):
231
+ gr.Markdown("### Download the Trained Model")
232
+ download_button = gr.Button("πŸ“₯ Download Model")
233
+ model_download_output = gr.File(label="Downloadable Model", interactive=False)
234
+
235
+ def download_model():
236
+ if os.path.exists(model_file):
237
+ return model_file
238
+ return None
239
+
240
+ download_button.click(download_model, inputs=None, outputs=model_download_output)
241
+
242
+ with gr.Tab("βš™οΈ Settings"):
243
+ gr.Markdown("### Application Settings")
244
+ # Additional settings can be added here
245
+ gr.Textbox(
246
+ value="",
247
+ label="Settings Placeholder",
248
+ placeholder="Add settings here...",
249
+ interactive=False,
250
+ )
251
 
252
  if __name__ == "__main__":
253
  demo.launch()