Update app.py
app.py
CHANGED
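This commit replaces the app's Streamlit front end with a Gradio Blocks interface: per-session st.session_state becomes module-level globals, the sidebar controls move into a settings column plus a dedicated File Analysis tab, and the chat, model-loading, prompt, parameter, and file-analysis handlers are registered as named API endpoints (api_name=...). The default system prompt now also targets structured JSON output for automation workflows such as n8n.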
Removed: the previous Streamlit implementation of app.py (606 lines). Only part of the deleted code survives in this diff view: spans that did not survive are marked "# ...", and truncated lines are reconstructed only where the surrounding code makes them unambiguous.

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import time
import os
from threading import Thread
import matplotlib.pyplot as plt
import pandas as pd
import json
import re
from io import StringIO

# ... (page configuration lost)

# Custom CSS injected via st.markdown (opening lines lost)
st.markdown("""
<style>
    /* ... (earlier rules lost) */
    .chat-message.user {
        background-color: #e0f7fa;
        border-left: 5px solid #039be5;
    }
    .chat-message.assistant {
        background-color: #f1f8e9;
        border-left: 5px solid #7cb342;
    }
    .chat-message .content {
        margin-top: 0.5rem;
    }
    .sidebar-content {
        padding: 1rem;
    }
    .sidebar-header {
        font-weight: bold;
        margin-bottom: 0.5rem;
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None
if 'pipe' not in st.session_state:
    st.session_state.pipe = None
if 'loading_model' not in st.session_state:
    st.session_state.loading_model = False
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
if 'system_prompt' not in st.session_state:
    st.session_state.system_prompt = "You are a helpful AI assistant based on the Mistral-7B-Instruct model. Answer the user's questions to the best of your ability."
if 'generate_config' not in st.session_state:
    st.session_state.generate_config = {
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 50,
        "repetition_penalty": 1.1,
        "do_sample": True
    }
if 'chats' not in st.session_state:
    st.session_state.chats = {"Main Chat": []}
if 'current_chat' not in st.session_state:
    st.session_state.current_chat = "Main Chat"
if 'file_data' not in st.session_state:
    st.session_state.file_data = None
if 'analyzed_data' not in st.session_state:
    st.session_state.analyzed_data = None

# Function to load the model in background
def load_model_in_background():
    try:
        # ... (loading-flag update lost)

        # Model identifier
        model_id = "mistralai/Mistral-7B-Instruct-v0.3"

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        st.session_state.tokenizer = tokenizer

        # Configure model loading (with lower precision for efficiency)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        st.session_state.model = model

        # Create text generation pipeline
        st.session_state.pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            return_full_text=False
        )

        # ... (loading/loaded flag updates lost)
    except Exception as e:
        # ... (error handling lost)

# Function to format and display chat messages
def display_messages():
    for message in st.session_state.chats[st.session_state.current_chat]:
        with st.container():
            col1, col2 = st.columns([1, 12])

            if message["role"] == "user":
                with st.chat_message("user"):
                    st.markdown(message["content"])
            elif message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.markdown(message["content"])
            elif message["role"] == "system":
                with st.chat_message("system"):
                    st.markdown(f"🔧 **System:** {message['content']}")

# Function to generate response using the model
def generate_response(prompt):
    if not st.session_state.model_loaded:
        return "Model is still loading. Please wait a moment before sending messages."

    try:
        messages = st.session_state.chats[st.session_state.current_chat]

        # Format conversation history in Mistral's chat format
        conversation = []

        # Add system prompt if it exists
        if st.session_state.system_prompt:
            conversation.append({"role": "system", "content": st.session_state.system_prompt})

        # Add previous messages
        for msg in messages:
            if msg["role"] != "system":  # Skip system messages in the history
                conversation.append({"role": msg["role"], "content": msg["content"]})

        # Add current prompt
        conversation.append({"role": "user", "content": prompt})

        # Convert to Mistral's chat format
        formatted_prompt = st.session_state.tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        # Generate response
        response = st.session_state.pipe(
            formatted_prompt,
            max_new_tokens=st.session_state.generate_config["max_new_tokens"],
            temperature=st.session_state.generate_config["temperature"],
            top_p=st.session_state.generate_config["top_p"],
            top_k=st.session_state.generate_config["top_k"],
            repetition_penalty=st.session_state.generate_config["repetition_penalty"],
            do_sample=st.session_state.generate_config["do_sample"]
        )

        generated_text = response[0]["generated_text"]
        # ... (return of the generated text lost)
    except Exception as e:
        # ... (error handling lost)

# Function to create a new chat
def create_new_chat(chat_name):
    # ... (body lost; the sidebar below treats the return value as truthy on success)

# Function to handle file upload and analysis
def analyze_uploaded_file(file):
    if file is None:
        return None

    # ... (file-extension detection lost)

    try:
        if file_extension == 'csv':
            data = pd.read_csv(file)
            # ...
            return {
                "type": "csv",
                "data": data,
                "summary": {
                    # ... (fields folded away by the diff view; later code reads
                    #      "rows", "columns", "column_names", "sample")
                }
            }

        elif file_extension in ['txt', 'md']:
            # ... (text read lost)
            return {
                "type": "text",
                "data": content,
                "summary": {
                    # ... (later code reads "word_count", "line_count", "preview")
                }
            }

        elif file_extension == 'json':
            # ... (JSON load lost)
            return {
                "type": "json",
                "data": data,
                "summary": {
                    # ... (later code reads "type", "keys", "length", "preview")
                }
            }

        elif file_extension in ['xls', 'xlsx']:
            data = pd.read_excel(file)
            # ...
            return {
                "type": "excel",
                "data": data,
                "summary": {
                    # ...
                }
            }

        else:
            return {
                # ... (unsupported-type payload lost)
            }

    except Exception as e:
        return {
            "type": "error",
            "error": str(e)
        }

# Sidebar - Model Control and Settings
with st.sidebar:
    st.title("🤖 Mistral-7B Chatbot")

    # Model loading section
    st.header("Model Control")
    if not st.session_state.model_loaded and not st.session_state.loading_model:
        if st.button("Load Mistral-7B Model"):
            loading_thread = Thread(target=load_model_in_background)
            loading_thread.start()

    if st.session_state.loading_model:
        st.info("Loading model... This may take a few minutes.")
        progress_bar = st.progress(0)
        for i in range(100):
            if not st.session_state.loading_model:
                progress_bar.progress(100)
                st.success("Model loaded successfully!")
                break
            progress_bar.progress(i + 1)
            time.sleep(0.1)

    if st.session_state.model_loaded:
        st.success("Model loaded and ready!")

    # Chat management
    st.header("Chat Management")

    # Select chat
    selected_chat = st.selectbox(
        "Select Chat",
        options=list(st.session_state.chats.keys()),
        index=list(st.session_state.chats.keys()).index(st.session_state.current_chat)
    )

    if selected_chat != st.session_state.current_chat:
        st.session_state.current_chat = selected_chat

    # Create new chat
    new_chat_name = st.text_input("New Chat Name")
    if st.button("Create New Chat"):
        if create_new_chat(new_chat_name):
            st.success(f"Created new chat: {new_chat_name}")
        else:
            st.error("Please enter a unique chat name")

    # Clear current chat
    if st.button("Clear Current Chat"):
        st.session_state.chats[st.session_state.current_chat] = []
        st.success(f"Cleared chat: {st.session_state.current_chat}")

    # Advanced settings
    st.header("Model Settings")

    with st.expander("System Prompt"):
        system_prompt = st.text_area(
            "Set the AI's behavior",
            value=st.session_state.system_prompt,
            height=100
        )
        if st.button("Update System Prompt"):
            st.session_state.system_prompt = system_prompt
            st.success("System prompt updated!")

    with st.expander("Generation Parameters"):
        st.slider(
            "Temperature",
            min_value=0.1,
            max_value=2.0,
            value=st.session_state.generate_config["temperature"],
            step=0.1,
            key="temp_slider",
            help="Higher = more creative, Lower = more focused"
        )
        st.session_state.generate_config["temperature"] = st.session_state["temp_slider"]

        st.slider(
            "Max Tokens",
            min_value=64,
            max_value=2048,
            value=st.session_state.generate_config["max_new_tokens"],
            step=64,
            key="max_tokens_slider",
            help="Maximum number of tokens in the response"
        )
        st.session_state.generate_config["max_new_tokens"] = st.session_state["max_tokens_slider"]

        st.slider(
            "Top P",
            min_value=0.1,
            max_value=1.0,
            value=st.session_state.generate_config["top_p"],
            step=0.05,
            key="top_p_slider",
            help="Nucleus sampling parameter"
        )
        st.session_state.generate_config["top_p"] = st.session_state["top_p_slider"]

        st.slider(
            "Repetition Penalty",
            min_value=1.0,
            max_value=2.0,
            value=st.session_state.generate_config["repetition_penalty"],
            step=0.1,
            key="rep_pen_slider",
            help="Penalizes repetition (higher = less repetition)"
        )
        st.session_state.generate_config["repetition_penalty"] = st.session_state["rep_pen_slider"]

    # File upload and analysis
    st.header("File Analysis")
    uploaded_file = st.file_uploader("Upload a file to analyze",
                                     type=["csv", "txt", "json", "xlsx", "xls", "md"])

    if uploaded_file is not None:
        if st.button("Analyze File"):
            with st.spinner("Analyzing file..."):
                analysis_result = analyze_uploaded_file(uploaded_file)
                st.session_state.analyzed_data = analysis_result

                if analysis_result:
                    if "error" in analysis_result:
                        st.error(f"Error analyzing file: {analysis_result['error']}")
                    else:
                        st.success(f"Successfully analyzed {analysis_result['type']} file")

                        # Add file summary to chat as system message
                        file_summary = f"File analyzed: {uploaded_file.name}\n"
                        if analysis_result['type'] == 'csv' or analysis_result['type'] == 'excel':
                            file_summary += f"- {analysis_result['summary']['rows']} rows, {analysis_result['summary']['columns']} columns\n"
                            file_summary += f"- Columns: {', '.join(analysis_result['summary']['column_names'])}"
                        elif analysis_result['type'] == 'text':
                            file_summary += f"- {analysis_result['summary']['word_count']} words, {analysis_result['summary']['line_count']} lines"

                        # Add file summary to current chat
                        st.session_state.chats[st.session_state.current_chat].append({
                            "role": "system",
                            "content": file_summary
                        })

# ... (main page header and display_messages() call lost)

# ... (chat-input handling lost; the typed text is available as user_prompt)
st.session_state.chats[st.session_state.current_chat].append({
    "role": "user",
    "content": user_prompt
})

# ...
st.markdown(user_prompt)

# ... (assistant-response block opening lost)

# Build file context when a file has been analyzed (enclosing condition lost)
    if st.session_state.analyzed_data["type"] in ["csv", "excel"]:
        # For structured data, provide summary info
        summary = st.session_state.analyzed_data["summary"]
        file_context = f"""I've analyzed the uploaded {st.session_state.analyzed_data["type"]} file with {summary["rows"]} rows and {summary["columns"]} columns.
The columns are: {', '.join(summary["column_names"])}.
Here's a sample of the data (first 5 rows): {json.dumps(summary["sample"])}
"""
    elif st.session_state.analyzed_data["type"] == "text":
        # For text data, provide the content if not too large
        summary = st.session_state.analyzed_data["summary"]
        file_context = f"""I've analyzed the uploaded text file with {summary["word_count"]} words and {summary["line_count"]} lines.
Here's a preview of the content: {summary["preview"]}
"""
    elif st.session_state.analyzed_data["type"] == "json":
        # For JSON data
        summary = st.session_state.analyzed_data["summary"]
        file_context = f"""I've analyzed the uploaded JSON file which contains a {summary["type"]}.
{"Keys: " + ', '.join(summary["keys"]) if summary["keys"] else ""}
{"Items: " + str(summary["length"]) if summary["length"] else ""}
Here's a preview: {summary["preview"]}
"""

    # Enhance the user's query with file context
    enhanced_prompt = f"{user_prompt}\n\nContext about the file: {file_context}"
else:
    enhanced_prompt = user_prompt

# Generate response
response = generate_response(enhanced_prompt)

# ... (assistant message append and display lost)

# ... (data-exploration expander and tab setup lost; tab1 offered search/filter,
#      tab2 statistics, tab3 visualization)

        if search_term:
            filtered_data = data[data[search_col].astype(str).str.contains(search_term, case=False)]
            st.dataframe(filtered_data)

    with tab2:
        st.write("### Numerical Statistics")
        st.dataframe(data.describe())

        st.write("### Column Information")
        col_info = pd.DataFrame({
            "Data Type": data.dtypes,
            "Non-Null Count": data.count(),
            "Null Count": data.isna().sum(),
            "Unique Values": [data[col].nunique() for col in data.columns]
        })
        st.dataframe(col_info)

    with tab3:
        st.write("### Data Visualization")

        # Simple chart creation
        chart_type = st.selectbox("Chart Type", ["Bar Chart", "Line Chart", "Scatter Plot", "Histogram"])

        if chart_type in ["Bar Chart", "Line Chart"]:
            x_axis = st.selectbox("X-Axis", list(data.columns), key="x_axis_1")
            y_axis = st.selectbox("Y-Axis", list(data.columns), key="y_axis_1")

            if x_axis and y_axis:
                fig, ax = plt.subplots(figsize=(10, 6))
                # ... (plot rendering lost)

        # ... (scatter-plot branch lost)

        # Histogram branch (opening lines and column/bins selectors lost)
            ax.hist(data[column].dropna(), bins=bins)
            plt.xlabel(column)
            plt.ylabel("Frequency")
            plt.title(f"Histogram of {column}")
            plt.tight_layout()
            st.pyplot(fig)

    elif st.session_state.analyzed_data["type"] == "text":
        text_data = st.session_state.file_data

        # ... (text statistics and preview lost)

# ... (remaining lines lost)

st.markdown("Advanced Mistral-7B Chatbot | Built with Streamlit & Hugging Face")
Added: the replacement Gradio implementation of app.py (555 lines). Dict bodies that the diff view folded away are marked "# ...". Two small fixes over the committed code: the not-yet-loaded early return now yields a chat turn (the raw string would not match the Chatbot output type), and the final comment is corrected to describe what the handler actually does.

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import time
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import re
from threading import Thread
import numpy as np
from io import StringIO

# Global variables to store model, tokenizer and pipe
MODEL = None
TOKENIZER = None
PIPE = None
MODEL_LOADING = False
MODEL_LOADED = False

# Store chat history for different chat sessions
CHATS = {"Main Chat": []}
CURRENT_CHAT = "Main Chat"

# System prompt and generation config
SYSTEM_PROMPT = "You are a helpful AI assistant based on the Mistral-7B-Instruct model. You specialize in creating structured JSON data for automation workflows like n8n. When asked, format JSON properly with correct indentation and structure."
GENERATE_CONFIG = {
    "max_new_tokens": 512,
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "do_sample": True
}

# File data storage
FILE_DATA = None
ANALYZED_DATA = None

# Function to load the model in background
def load_model_in_background():
    global MODEL, TOKENIZER, PIPE, MODEL_LOADING, MODEL_LOADED
    try:
        MODEL_LOADING = True

        # Model identifier
        model_id = "mistralai/Mistral-7B-Instruct-v0.3"

        # Load tokenizer
        TOKENIZER = AutoTokenizer.from_pretrained(model_id)

        # Configure model loading (with lower precision for efficiency)
        MODEL = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        # Create text generation pipeline
        PIPE = pipeline(
            "text-generation",
            model=MODEL,
            tokenizer=TOKENIZER,
            return_full_text=False
        )

        MODEL_LOADING = False
        MODEL_LOADED = True
        return "Model loaded successfully!"
    except Exception as e:
        MODEL_LOADING = False
        return f"Error loading model: {str(e)}"

# Function to generate response using the model
def generate_response(prompt, chat_history, progress=gr.Progress()):
    global MODEL, TOKENIZER, PIPE, CHATS, CURRENT_CHAT, SYSTEM_PROMPT, GENERATE_CONFIG, FILE_DATA, ANALYZED_DATA

    if not MODEL_LOADED:
        # Append as a chat turn so the output matches the Chatbot component
        chat_history.append((prompt, "Model is still loading. Please wait a moment before sending messages."))
        return chat_history

    try:
        # Use the current chat's history
        messages = CHATS[CURRENT_CHAT]

        # Format conversation history in Mistral's chat format
        conversation = []

        # Add system prompt if it exists
        if SYSTEM_PROMPT:
            conversation.append({"role": "system", "content": SYSTEM_PROMPT})

        # Add previous messages
        for msg in messages:
            if msg["role"] != "system":  # Skip system messages in the history
                conversation.append({"role": msg["role"], "content": msg["content"]})

        # Handle file-related queries by including context
        if ANALYZED_DATA is not None and any(keyword in prompt.lower()
                for keyword in ["file", "data", "analyze", "show", "tell me about", "json"]):
            file_context = ""
            if ANALYZED_DATA["type"] in ["csv", "excel"]:
                # For structured data, provide summary info
                summary = ANALYZED_DATA["summary"]
                file_context = f"""I've analyzed the uploaded {ANALYZED_DATA["type"]} file with {summary["rows"]} rows and {summary["columns"]} columns.
The columns are: {', '.join(summary["column_names"])}.
Here's a sample of the data (first 5 rows): {json.dumps(summary["sample"])}
"""
            elif ANALYZED_DATA["type"] == "text":
                # For text data, provide the content if not too large
                summary = ANALYZED_DATA["summary"]
                file_context = f"""I've analyzed the uploaded text file with {summary["word_count"]} words and {summary["line_count"]} lines.
Here's a preview of the content: {summary["preview"]}
"""
            elif ANALYZED_DATA["type"] == "json":
                # For JSON data
                summary = ANALYZED_DATA["summary"]
                file_context = f"""I've analyzed the uploaded JSON file which contains a {summary["type"]}.
{"Keys: " + ', '.join(summary["keys"]) if summary["keys"] else ""}
{"Items: " + str(summary["length"]) if summary["length"] else ""}
Here's a preview: {summary["preview"]}
"""

            # Enhance the user's query with file context
            enhanced_prompt = f"{prompt}\n\nContext about the file: {file_context}"
        else:
            enhanced_prompt = prompt

        # Add current prompt
        conversation.append({"role": "user", "content": enhanced_prompt})

        # Convert to Mistral's chat format
        formatted_prompt = TOKENIZER.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        # Generate response with progress reporting
        progress(0, desc="Generating response...")

        # Generate response
        response = PIPE(
            formatted_prompt,
            max_new_tokens=GENERATE_CONFIG["max_new_tokens"],
            temperature=GENERATE_CONFIG["temperature"],
            top_p=GENERATE_CONFIG["top_p"],
            top_k=GENERATE_CONFIG["top_k"],
            repetition_penalty=GENERATE_CONFIG["repetition_penalty"],
            do_sample=GENERATE_CONFIG["do_sample"]
        )

        progress(1, desc="Response generated!")

        # Extract generated text
        generated_text = response[0]["generated_text"]

        # Add user message to chat history
        CHATS[CURRENT_CHAT].append({
            "role": "user",
            "content": prompt
        })

        # Add assistant response to chat history
        CHATS[CURRENT_CHAT].append({
            "role": "assistant",
            "content": generated_text
        })

        # Update the chat history for the Gradio component
        chat_history.append((prompt, generated_text))

        return chat_history

    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        chat_history.append((prompt, error_message))
        return chat_history

# Function to create a new chat
def create_new_chat(chat_name):
    global CHATS, CURRENT_CHAT

    if chat_name and chat_name not in CHATS:
        CHATS[chat_name] = []
        CURRENT_CHAT = chat_name
        return f"Created new chat: {chat_name}"
    return "Please enter a unique chat name"

# Function to handle file upload and analysis
def analyze_uploaded_file(file):
    global FILE_DATA, ANALYZED_DATA, CHATS, CURRENT_CHAT

    if file is None:
        return "No file uploaded."

    try:
        file_extension = file.name.split('.')[-1].lower()

        if file_extension == 'csv':
            data = pd.read_csv(file.name)
            FILE_DATA = data
            ANALYZED_DATA = {
                "type": "csv",
                "data": data,
                "summary": {
                    # ... (fields folded away by the diff view; later code reads
                    #      "rows", "columns", "column_names", "sample")
                }
            }

        elif file_extension in ['txt', 'md']:
            with open(file.name, 'r', encoding='utf-8') as f:
                content = f.read()
            FILE_DATA = content
            ANALYZED_DATA = {
                "type": "text",
                "data": content,
                "summary": {
                    # ... (later code reads "length", "word_count", "line_count", "preview")
                }
            }

        elif file_extension == 'json':
            with open(file.name, 'r', encoding='utf-8') as f:
                content = f.read()
            data = json.loads(content)
            FILE_DATA = data
            ANALYZED_DATA = {
                "type": "json",
                "data": data,
                "summary": {
                    # ... (later code reads "type", "keys", "length", "preview")
                }
            }

        elif file_extension in ['xls', 'xlsx']:
            data = pd.read_excel(file.name)
            FILE_DATA = data
            ANALYZED_DATA = {
                "type": "excel",
                "data": data,
                "summary": {
                    # ...
                }
            }

        else:
            return f"File type .{file_extension} is not supported."

        # Add file summary to current chat as system message
        file_summary = f"File analyzed: {file.name}\n"
        if ANALYZED_DATA['type'] == 'csv' or ANALYZED_DATA['type'] == 'excel':
            file_summary += f"- {ANALYZED_DATA['summary']['rows']} rows, {ANALYZED_DATA['summary']['columns']} columns\n"
            file_summary += f"- Columns: {', '.join(ANALYZED_DATA['summary']['column_names'])}"
        elif ANALYZED_DATA['type'] == 'text':
            file_summary += f"- {ANALYZED_DATA['summary']['word_count']} words, {ANALYZED_DATA['summary']['line_count']} lines"
        elif ANALYZED_DATA['type'] == 'json':
            file_summary += f"- Type: {ANALYZED_DATA['summary']['type']}"
            if ANALYZED_DATA['summary']['keys']:
                file_summary += f"\n- Keys: {', '.join(ANALYZED_DATA['summary']['keys'])}"

        # Add system message to current chat
        CHATS[CURRENT_CHAT].append({
            "role": "system",
            "content": file_summary
        })

        return f"Successfully analyzed {ANALYZED_DATA['type']} file: {file.name}"

    except Exception as e:
        return f"Error analyzing file: {str(e)}"

# Function to update system prompt
def update_system_prompt(new_prompt):
    global SYSTEM_PROMPT
    SYSTEM_PROMPT = new_prompt
    return "System prompt updated!"

# Function to update generation parameters
def update_generation_params(temp, max_tokens, top_p, rep_penalty):
    global GENERATE_CONFIG
    GENERATE_CONFIG["temperature"] = temp
    GENERATE_CONFIG["max_new_tokens"] = max_tokens
    GENERATE_CONFIG["top_p"] = top_p
    GENERATE_CONFIG["repetition_penalty"] = rep_penalty
    return "Generation parameters updated!"

# Function to display file data information
def display_file_info():
    global ANALYZED_DATA

    if ANALYZED_DATA is None:
        return "No file has been analyzed yet."

    file_info = "## File Analysis Results\n\n"
    file_info += f"**File Type:** {ANALYZED_DATA['type']}\n\n"

    if ANALYZED_DATA['type'] in ['csv', 'excel']:
        summary = ANALYZED_DATA['summary']
        file_info += f"**Rows:** {summary['rows']}\n"
        file_info += f"**Columns:** {summary['columns']}\n"
        file_info += f"**Column Names:** {', '.join(summary['column_names'])}\n\n"

        # Sample data preview
        file_info += "**Sample Data (First 5 rows):**\n"
        sample_df = pd.DataFrame(summary['sample'])
        file_info += sample_df.to_markdown()

    elif ANALYZED_DATA['type'] == 'text':
        summary = ANALYZED_DATA['summary']
        file_info += f"**Length:** {summary['length']} characters\n"
        file_info += f"**Word Count:** {summary['word_count']}\n"
        file_info += f"**Line Count:** {summary['line_count']}\n\n"
        file_info += "**Preview:**\n```\n" + summary['preview'] + "\n```"

    elif ANALYZED_DATA['type'] == 'json':
        summary = ANALYZED_DATA['summary']
        file_info += f"**Type:** {summary['type']}\n"
        if summary['keys']:
            file_info += f"**Keys:** {', '.join(summary['keys'])}\n"
        if summary['length'] is not None:
            file_info += f"**Length:** {summary['length']} items\n"
        file_info += "\n**Preview:**\n```json\n" + summary['preview'] + "\n```"

    return file_info

# Function to select current chat
def select_chat(chat_name):
    global CURRENT_CHAT
    CURRENT_CHAT = chat_name
    return f"Switched to chat: {chat_name}"

# Function to clear current chat
def clear_current_chat():
    global CHATS, CURRENT_CHAT
    CHATS[CURRENT_CHAT] = []
    return f"Cleared chat: {CURRENT_CHAT}"

# Function to load model and return status
def load_model_button():
    if MODEL_LOADED:
        return "Model is already loaded and ready!"
    elif MODEL_LOADING:
        return "Model is currently loading... Please wait."
    else:
        thread = Thread(target=load_model_in_background)
        thread.start()
        return "Started loading the model. This may take a few minutes..."

# Function to get available chats
def get_available_chats():
    global CHATS
    return list(CHATS.keys())

# Main Gradio app
def create_gradio_interface():
    # CSS for better styling
    css = """
    .gradio-container {max-width: 100% !important; padding: 0}
    .chat-message-user {background-color: #e0f7fa; padding: 12px; border-radius: 8px; margin-bottom: 8px}
    .chat-message-bot {background-color: #f1f8e9; padding: 12px; border-radius: 8px; margin-bottom: 8px}
    .file-info {border: 1px solid #ddd; padding: 15px; border-radius: 5px; margin-top: 10px}
    """

    # Setup tabs for different functionalities
    with gr.Blocks(css=css) as app:
        gr.Markdown("# 🤖 Advanced Mistral-7B-Instruct Chatbot")

        with gr.Tab("Chat"):
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        elem_id="chatbot",
                        height=500,
                        bubble_full_width=False,
                        avatar_images=(None, None),
                    )

                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Type your message here...",
                            container=False,
                            scale=8,
                            autofocus=True
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")

                    with gr.Row():
                        chat_selector = gr.Dropdown(
                            choices=get_available_chats(),
                            value=CURRENT_CHAT,
                            label="Select Chat",
                            interactive=True
                        )

                        with gr.Column(scale=1):
                            new_chat_name = gr.Textbox(
                                placeholder="New chat name",
                                label="Create New Chat",
                                container=True
                            )
                            with gr.Row():
                                create_chat_btn = gr.Button("Create", variant="secondary")
                                clear_chat_btn = gr.Button("Clear Current Chat", variant="secondary")

                with gr.Column(scale=1):
                    # Model Loading and Settings
                    load_model_btn = gr.Button("Load Mistral-7B Model", variant="primary")
                    model_status = gr.Textbox(label="Model Status", value="Not loaded", interactive=False)

                    # System Prompt
                    system_prompt_input = gr.Textbox(
                        label="System Prompt",
                        value=SYSTEM_PROMPT,
                        lines=4,
                        placeholder="Enter system prompt to guide the AI's behavior..."
                    )
                    update_prompt_btn = gr.Button("Update System Prompt", variant="secondary")

                    # Model Parameters
                    gr.Markdown("### Generation Parameters")
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=GENERATE_CONFIG["temperature"],
                        step=0.1,
                        label="Temperature"
                    )
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=GENERATE_CONFIG["max_new_tokens"],
                        step=64,
                        label="Max Tokens"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=GENERATE_CONFIG["top_p"],
                        step=0.05,
                        label="Top P"
                    )
                    rep_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=GENERATE_CONFIG["repetition_penalty"],
                        step=0.1,
                        label="Repetition Penalty"
                    )
                    update_params_btn = gr.Button("Update Parameters", variant="secondary")

        with gr.Tab("File Analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload = gr.File(label="Upload a file to analyze")
                    analyze_btn = gr.Button("Analyze File", variant="primary")

                with gr.Column(scale=2):
                    file_analysis_output = gr.Markdown(label="File Analysis Results")

        # Set up event handlers
        send_btn.click(
            generate_response,
            inputs=[msg, chatbot],
            outputs=[chatbot],
            api_name="chat"
        )

        msg.submit(
            generate_response,
            inputs=[msg, chatbot],
            outputs=[chatbot],
            api_name=False
        )

        load_model_btn.click(
            load_model_button,
            outputs=model_status,
            api_name="load_model"
        )

        update_prompt_btn.click(
            update_system_prompt,
            inputs=system_prompt_input,
            outputs=model_status,
            api_name="update_prompt"
        )

        update_params_btn.click(
            update_generation_params,
            inputs=[temperature, max_tokens, top_p, rep_penalty],
            outputs=model_status,
            api_name="update_params"
        )

        analyze_btn.click(
            analyze_uploaded_file,
            inputs=file_upload,
            outputs=model_status,
            api_name="analyze_file"
        ).then(
            display_file_info,
            outputs=file_analysis_output,
            api_name=False
        )

        chat_selector.change(
            select_chat,
            inputs=chat_selector,
            outputs=model_status,
            api_name="select_chat"
        )

        create_chat_btn.click(
            create_new_chat,
            inputs=new_chat_name,
            outputs=model_status,
            api_name="create_chat"
        ).then(
            get_available_chats,
            outputs=chat_selector,
            api_name=False
        )

        clear_chat_btn.click(
            clear_current_chat,
            outputs=model_status,
            api_name="clear_chat"
        )

        # Reset the visible chat window when switching chats
        # (the per-chat history itself stays in CHATS)
        chat_selector.change(lambda: [], outputs=chatbot)

    return app

# Launch the app
demo = create_gradio_interface()

if __name__ == "__main__":
    demo.launch()
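No dependency list appears in this commit. A minimal requirements.txt sketch, inferred from the imports and not part of the commit: accelerate backs device_map="auto", openpyxl backs pd.read_excel for .xlsx, and tabulate backs DataFrame.to_markdown().

# requirements.txt (assumed; not part of this commit)
gradio
torch
transformers
accelerate
pandas
numpy
matplotlib
openpyxl
tabulate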
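Because several handlers register an api_name, the running Space can also be driven programmatically. A sketch using the gradio_client package; the Space id below is a placeholder, and the exact endpoint paths and argument shapes depend on the Gradio version the Space pins:

# Sketch only: calling the endpoints exposed above via gradio_client.
from gradio_client import Client

client = Client("your-username/your-space")  # placeholder Space id, not the real one

# Kick off model loading (registered with api_name="load_model", no inputs).
print(client.predict(api_name="/load_model"))

# Chat: generate_response takes (prompt, chat_history) and returns the updated history.
history = client.predict("Hello!", [], api_name="/chat")
print(history)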