shaheerawan3 committed on
Commit 8b06363 · verified · 1 Parent(s): 06ac293

Update app.py

Files changed (1):
  1. app.py +414 -465
app.py CHANGED
@@ -1,207 +1,206 @@
- import streamlit as st
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  import time
  import os
- from threading import Thread
- import matplotlib.pyplot as plt
- import pandas as pd
  import json
  import re
  from io import StringIO

- # Set page configuration
- st.set_page_config(
-     page_title="Advanced Mistral-7B Chatbot",
-     page_icon="🤖",
-     layout="wide",
-     initial_sidebar_state="expanded"
- )

- # Custom CSS for better appearance
- st.markdown("""
- <style>
- .chat-message {
-     padding: 1.5rem;
-     border-radius: 0.8rem;
-     margin-bottom: 1rem;
-     display: flex;
-     flex-direction: column;
- }
- .chat-message.user {
-     background-color: #e0f7fa;
-     border-left: 5px solid #039be5;
- }
- .chat-message.assistant {
-     background-color: #f1f8e9;
-     border-left: 5px solid #7cb342;
- }
- .chat-message .content {
-     margin-top: 0.5rem;
- }
- .sidebar-content {
-     padding: 1rem;
- }
- .sidebar-header {
-     font-weight: bold;
-     margin-bottom: 0.5rem;
- }
- </style>
- """, unsafe_allow_html=True)

- # Initialize session state variables
- if 'messages' not in st.session_state:
-     st.session_state.messages = []
- if 'model' not in st.session_state:
-     st.session_state.model = None
- if 'tokenizer' not in st.session_state:
-     st.session_state.tokenizer = None
- if 'pipe' not in st.session_state:
-     st.session_state.pipe = None
- if 'loading_model' not in st.session_state:
-     st.session_state.loading_model = False
- if 'model_loaded' not in st.session_state:
-     st.session_state.model_loaded = False
- if 'system_prompt' not in st.session_state:
-     st.session_state.system_prompt = "You are a helpful AI assistant based on the Mistral-7B-Instruct model. Answer the user's questions to the best of your ability."
- if 'generate_config' not in st.session_state:
-     st.session_state.generate_config = {
-         "max_new_tokens": 512,
-         "temperature": 0.7,
-         "top_p": 0.95,
-         "top_k": 50,
-         "repetition_penalty": 1.1,
-         "do_sample": True
-     }
- if 'chats' not in st.session_state:
-     st.session_state.chats = {"Main Chat": []}
- if 'current_chat' not in st.session_state:
-     st.session_state.current_chat = "Main Chat"
- if 'file_data' not in st.session_state:
-     st.session_state.file_data = None
- if 'analyzed_data' not in st.session_state:
-     st.session_state.analyzed_data = None

  # Function to load the model in background
  def load_model_in_background():
      try:
-         st.session_state.loading_model = True
          # Model identifier
          model_id = "mistralai/Mistral-7B-Instruct-v0.3"

          # Load tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         st.session_state.tokenizer = tokenizer

          # Configure model loading (with lower precision for efficiency)
-         model = AutoModelForCausalLM.from_pretrained(
              model_id,
              torch_dtype=torch.float16,
              device_map="auto",
              low_cpu_mem_usage=True
          )
-         st.session_state.model = model

          # Create text generation pipeline
-         st.session_state.pipe = pipeline(
              "text-generation",
-             model=model,
-             tokenizer=tokenizer,
              return_full_text=False
          )

-         st.session_state.loading_model = False
-         st.session_state.model_loaded = True
      except Exception as e:
-         st.session_state.loading_model = False
-         st.error(f"Error loading model: {str(e)}")
-
- # Function to format and display chat messages
- def display_messages():
-     for message in st.session_state.chats[st.session_state.current_chat]:
-         with st.container():
-             col1, col2 = st.columns([1, 12])
-
-             if message["role"] == "user":
-                 with st.chat_message("user"):
-                     st.markdown(message["content"])
-             elif message["role"] == "assistant":
-                 with st.chat_message("assistant"):
-                     st.markdown(message["content"])
-             elif message["role"] == "system":
-                 with st.chat_message("system"):
-                     st.markdown(f"🔧 **System:** {message['content']}")

  # Function to generate response using the model
- def generate_response(prompt):
-     if not st.session_state.model_loaded:
          return "Model is still loading. Please wait a moment before sending messages."

      try:
-         messages = st.session_state.chats[st.session_state.current_chat]

          # Format conversation history in Mistral's chat format
          conversation = []

          # Add system prompt if it exists
-         if st.session_state.system_prompt:
-             conversation.append({"role": "system", "content": st.session_state.system_prompt})

          # Add previous messages
          for msg in messages:
              if msg["role"] != "system": # Skip system messages in the history
                  conversation.append({"role": msg["role"], "content": msg["content"]})

          # Add current prompt
-         conversation.append({"role": "user", "content": prompt})

          # Convert to Mistral's chat format
-         formatted_prompt = st.session_state.tokenizer.apply_chat_template(
              conversation,
              tokenize=False,
              add_generation_prompt=True
          )

-         # Generate response
-         generation_config = st.session_state.generate_config

-         response = st.session_state.pipe(
              formatted_prompt,
-             max_new_tokens=generation_config["max_new_tokens"],
-             temperature=generation_config["temperature"],
-             top_p=generation_config["top_p"],
-             top_k=generation_config["top_k"],
-             repetition_penalty=generation_config["repetition_penalty"],
-             do_sample=generation_config["do_sample"]
          )

-         # Extract and return generated text
          generated_text = response[0]["generated_text"]
-         return generated_text

      except Exception as e:
-         return f"Error generating response: {str(e)}"

  # Function to create a new chat
  def create_new_chat(chat_name):
-     if chat_name and chat_name not in st.session_state.chats:
-         st.session_state.chats[chat_name] = []
-         st.session_state.current_chat = chat_name
-         return True
-     return False

  # Function to handle file upload and analysis
  def analyze_uploaded_file(file):
-     if file is None:
-         return None

-     file_extension = file.name.split('.')[-1].lower()

      try:
          if file_extension == 'csv':
-             data = pd.read_csv(file)
-             st.session_state.file_data = data
-             return {
                  "type": "csv",
                  "data": data,
                  "summary": {
@@ -214,9 +213,10 @@ def analyze_uploaded_file(file):
              }

          elif file_extension in ['txt', 'md']:
-             content = file.getvalue().decode('utf-8')
-             st.session_state.file_data = content
-             return {
                  "type": "text",
                  "data": content,
                  "summary": {
@@ -228,10 +228,11 @@ def analyze_uploaded_file(file):
              }

          elif file_extension == 'json':
-             content = file.getvalue().decode('utf-8')
-             data = json.loads(content)
-             st.session_state.file_data = data
-             return {
                  "type": "json",
                  "data": data,
                  "summary": {
@@ -243,9 +244,9 @@ def analyze_uploaded_file(file):
              }

          elif file_extension in ['xls', 'xlsx']:
-             data = pd.read_excel(file)
-             st.session_state.file_data = data
-             return {
                  "type": "excel",
                  "data": data,
                  "summary": {
@@ -258,349 +259,297 @@ def analyze_uploaded_file(file):
              }

          else:
-             return {
-                 "type": "unsupported",
-                 "error": f"File type .{file_extension} is not supported."
-             }

      except Exception as e:
-         return {
-             "type": "error",
-             "error": str(e)
-         }
-
- # Sidebar - Model Control and Settings
- with st.sidebar:
-     st.title("🤖 Mistral-7B Chatbot")
-
-     # Model loading section
-     st.header("Model Control")
-     if not st.session_state.model_loaded and not st.session_state.loading_model:
-         if st.button("Load Mistral-7B Model"):
-             loading_thread = Thread(target=load_model_in_background)
-             loading_thread.start()
-
-     if st.session_state.loading_model:
-         st.info("Loading model... This may take a few minutes.")
-         progress_bar = st.progress(0)
-         for i in range(100):
-             if not st.session_state.loading_model:
-                 progress_bar.progress(100)
-                 st.success("Model loaded successfully!")
-                 break
-             progress_bar.progress(i + 1)
-             time.sleep(0.1)
-
-     if st.session_state.model_loaded:
-         st.success("Model loaded and ready!")
-
-     # Chat management
-     st.header("Chat Management")
-
-     # Select chat
-     selected_chat = st.selectbox(
-         "Select Chat",
-         options=list(st.session_state.chats.keys()),
-         index=list(st.session_state.chats.keys()).index(st.session_state.current_chat)
-     )
-
-     if selected_chat != st.session_state.current_chat:
-         st.session_state.current_chat = selected_chat
-
-     # Create new chat
-     new_chat_name = st.text_input("New Chat Name")
-     if st.button("Create New Chat"):
-         if create_new_chat(new_chat_name):
-             st.success(f"Created new chat: {new_chat_name}")
-         else:
-             st.error("Please enter a unique chat name")
-
-     # Clear current chat
-     if st.button("Clear Current Chat"):
-         st.session_state.chats[st.session_state.current_chat] = []
-         st.success(f"Cleared chat: {st.session_state.current_chat}")
-
-     # Advanced settings
-     st.header("Model Settings")
-
-     with st.expander("System Prompt"):
-         system_prompt = st.text_area(
-             "Set the AI's behavior",
-             value=st.session_state.system_prompt,
-             height=100
-         )
-         if st.button("Update System Prompt"):
-             st.session_state.system_prompt = system_prompt
-             st.success("System prompt updated!")
-
-     with st.expander("Generation Parameters"):
-         st.slider(
-             "Temperature",
-             min_value=0.1,
-             max_value=2.0,
-             value=st.session_state.generate_config["temperature"],
-             step=0.1,
-             key="temp_slider",
-             help="Higher = more creative, Lower = more focused"
-         )
-         st.session_state.generate_config["temperature"] = st.session_state["temp_slider"]
-
-         st.slider(
-             "Max Tokens",
-             min_value=64,
-             max_value=2048,
-             value=st.session_state.generate_config["max_new_tokens"],
-             step=64,
-             key="max_tokens_slider",
-             help="Maximum number of tokens in the response"
-         )
-         st.session_state.generate_config["max_new_tokens"] = st.session_state["max_tokens_slider"]
-
-         st.slider(
-             "Top P",
-             min_value=0.1,
-             max_value=1.0,
-             value=st.session_state.generate_config["top_p"],
-             step=0.05,
-             key="top_p_slider",
-             help="Nucleus sampling parameter"
-         )
-         st.session_state.generate_config["top_p"] = st.session_state["top_p_slider"]
-
-         st.slider(
-             "Repetition Penalty",
-             min_value=1.0,
-             max_value=2.0,
-             value=st.session_state.generate_config["repetition_penalty"],
-             step=0.1,
-             key="rep_pen_slider",
-             help="Penalizes repetition (higher = less repetition)"
-         )
-         st.session_state.generate_config["repetition_penalty"] = st.session_state["rep_pen_slider"]
-
-     # File upload and analysis
-     st.header("File Analysis")
-     uploaded_file = st.file_uploader("Upload a file to analyze",
-                                      type=["csv", "txt", "json", "xlsx", "xls", "md"])
-
-     if uploaded_file is not None:
-         if st.button("Analyze File"):
-             with st.spinner("Analyzing file..."):
-                 analysis_result = analyze_uploaded_file(uploaded_file)
-                 st.session_state.analyzed_data = analysis_result
-
-                 if analysis_result:
-                     if "error" in analysis_result:
-                         st.error(f"Error analyzing file: {analysis_result['error']}")
-                     else:
-                         st.success(f"Successfully analyzed {analysis_result['type']} file")
-
-                         # Add file summary to chat as system message
-                         file_summary = f"File analyzed: {uploaded_file.name}\n"
-                         if analysis_result['type'] == 'csv' or analysis_result['type'] == 'excel':
-                             file_summary += f"- {analysis_result['summary']['rows']} rows, {analysis_result['summary']['columns']} columns\n"
-                             file_summary += f"- Columns: {', '.join(analysis_result['summary']['column_names'])}"
-                         elif analysis_result['type'] == 'text':
-                             file_summary += f"- {analysis_result['summary']['word_count']} words, {analysis_result['summary']['line_count']} lines"
-
-                         # Add file summary to current chat
-                         st.session_state.chats[st.session_state.current_chat].append({
-                             "role": "system",
-                             "content": file_summary
-                         })

- # Main chat interface
- st.header(f"💬 Chat: {st.session_state.current_chat}")

- # Display chat messages
- display_messages()

- # User input
- if user_prompt := st.chat_input("Type your message here...", disabled=not st.session_state.model_loaded and not st.session_state.loading_model):
-     # Add user message to chat history
-     st.session_state.chats[st.session_state.current_chat].append({
-         "role": "user",
-         "content": user_prompt
-     })

-     # Display user message
-     with st.chat_message("user"):
-         st.markdown(user_prompt)

-     # Generate and display assistant response
-     with st.chat_message("assistant"):
-         message_placeholder = st.empty()
-         message_placeholder.markdown("Thinking...")

-         # Handle file-related queries by including context
-         if st.session_state.analyzed_data is not None and any(keyword in user_prompt.lower()
-                 for keyword in ["file", "data", "analyze", "show", "tell me about"]):
-             file_context = ""
-             if st.session_state.analyzed_data["type"] in ["csv", "excel"]:
-                 # For structured data, provide summary info
-                 summary = st.session_state.analyzed_data["summary"]
-                 file_context = f"""I've analyzed the uploaded {st.session_state.analyzed_data["type"]} file with {summary["rows"]} rows and {summary["columns"]} columns.
- The columns are: {', '.join(summary["column_names"])}.
- Here's a sample of the data (first 5 rows): {json.dumps(summary["sample"])}
- """
-             elif st.session_state.analyzed_data["type"] == "text":
-                 # For text data, provide the content if not too large
-                 summary = st.session_state.analyzed_data["summary"]
-                 file_context = f"""I've analyzed the uploaded text file with {summary["word_count"]} words and {summary["line_count"]} lines.
- Here's a preview of the content: {summary["preview"]}
- """
-             elif st.session_state.analyzed_data["type"] == "json":
-                 # For JSON data
-                 summary = st.session_state.analyzed_data["summary"]
-                 file_context = f"""I've analyzed the uploaded JSON file which contains a {summary["type"]}.
- {"Keys: " + ', '.join(summary["keys"]) if summary["keys"] else ""}
- {"Items: " + str(summary["length"]) if summary["length"] else ""}
- Here's a preview: {summary["preview"]}
- """
-
-             # Enhance the user's query with file context
-             enhanced_prompt = f"{user_prompt}\n\nContext about the file: {file_context}"
-         else:
-             enhanced_prompt = user_prompt
-
-         # Generate response
-         response = generate_response(enhanced_prompt)

-         # Update the placeholder with the complete response
-         message_placeholder.markdown(response)

-         # Add assistant response to chat history
-         st.session_state.chats[st.session_state.current_chat].append({
-             "role": "assistant",
-             "content": response
-         })

- # Display data analysis tools when file is uploaded
- if st.session_state.analyzed_data is not None and st.session_state.file_data is not None:
-     st.header("📊 Data Analysis")

-     if st.session_state.analyzed_data["type"] in ["csv", "excel"]:
-         data = st.session_state.file_data

-         # Data exploration tabs
-         tab1, tab2, tab3 = st.tabs(["Data Preview", "Statistics", "Visualization"])
-
-         with tab1:
-             st.dataframe(data.head(10))
-
-             # Simple search/filter
-             search_col = st.selectbox("Search column", [""] + list(data.columns))
-             if search_col:
-                 search_term = st.text_input("Search term")
-                 if search_term:
-                     filtered_data = data[data[search_col].astype(str).str.contains(search_term, case=False)]
-                     st.dataframe(filtered_data)
-
-         with tab2:
-             st.write("### Numerical Statistics")
-             st.dataframe(data.describe())
-
-             st.write("### Column Information")
-             col_info = pd.DataFrame({
-                 "Data Type": data.dtypes,
-                 "Non-Null Count": data.count(),
-                 "Null Count": data.isna().sum(),
-                 "Unique Values": [data[col].nunique() for col in data.columns]
-             })
-             st.dataframe(col_info)
-
-         with tab3:
-             st.write("### Data Visualization")
-
-             # Simple chart creation
-             chart_type = st.selectbox("Chart Type", ["Bar Chart", "Line Chart", "Scatter Plot", "Histogram"])
-
-             if chart_type in ["Bar Chart", "Line Chart"]:
-                 x_axis = st.selectbox("X-Axis", list(data.columns), key="x_axis_1")
-                 y_axis = st.selectbox("Y-Axis", list(data.columns), key="y_axis_1")
-
-                 if x_axis and y_axis:
-                     fig, ax = plt.subplots(figsize=(10, 6))

-                     if chart_type == "Bar Chart":
-                         data.groupby(x_axis)[y_axis].mean().plot(kind='bar', ax=ax)
-                     else: # Line Chart
-                         data.groupby(x_axis)[y_axis].mean().plot(kind='line', ax=ax)

-                     plt.xlabel(x_axis)
-                     plt.ylabel(y_axis)
-                     plt.title(f"{chart_type} of {y_axis} by {x_axis}")
-                     plt.xticks(rotation=45)
-                     plt.tight_layout()
-                     st.pyplot(fig)
-
-             elif chart_type == "Scatter Plot":
-                 x_axis = st.selectbox("X-Axis", list(data.columns), key="x_axis_2")
-                 y_axis = st.selectbox("Y-Axis", list(data.columns), key="y_axis_2")

-                 if x_axis and y_axis:
-                     fig, ax = plt.subplots(figsize=(10, 6))
-                     ax.scatter(data[x_axis], data[y_axis])
-                     plt.xlabel(x_axis)
-                     plt.ylabel(y_axis)
-                     plt.title(f"Scatter Plot of {y_axis} vs {x_axis}")
-                     plt.tight_layout()
-                     st.pyplot(fig)
-
-             elif chart_type == "Histogram":
-                 column = st.selectbox("Column", list(data.columns))
-                 bins = st.slider("Bins", min_value=5, max_value=100, value=20)

-                 if column:
-                     fig, ax = plt.subplots(figsize=(10, 6))
-                     ax.hist(data[column].dropna(), bins=bins)
-                     plt.xlabel(column)
-                     plt.ylabel("Frequency")
-                     plt.title(f"Histogram of {column}")
-                     plt.tight_layout()
-                     st.pyplot(fig)
-
-     elif st.session_state.analyzed_data["type"] == "text":
-         text_data = st.session_state.file_data

-         # Text analysis tools
-         tab1, tab2 = st.tabs(["Text Preview", "Text Analysis"])

-         with tab1:
-             st.text_area("Content", text_data, height=300)

-         with tab2:
-             # Basic text analysis
-             word_count = len(text_data.split())
-             char_count = len(text_data)
-             line_count = len(text_data.splitlines())
-
-             col1, col2, col3 = st.columns(3)
-             col1.metric("Word Count", word_count)
-             col2.metric("Character Count", char_count)
-             col3.metric("Line Count", line_count)
-
-             # Word frequency
-             if st.checkbox("Show Word Frequency"):
-                 words = re.findall(r'\b\w+\b', text_data.lower())
-                 word_freq = pd.Series(words).value_counts().head(20)
-
-                 st.write("### Most Common Words")
-                 fig, ax = plt.subplots(figsize=(10, 6))
-                 word_freq.plot(kind='bar', ax=ax)
-                 plt.xlabel("Word")
-                 plt.ylabel("Frequency")
-                 plt.title("Top 20 Most Common Words")
-                 plt.xticks(rotation=45)
-                 plt.tight_layout()
-                 st.pyplot(fig)
-
-     elif st.session_state.analyzed_data["type"] == "json":
-         json_data = st.session_state.file_data

-         # JSON explorer
-         st.json(json_data)

- # Footer
- st.markdown("---")
- st.markdown("Advanced Mistral-7B Chatbot | Built with Streamlit & Hugging Face")

+ import gradio as gr
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  import time
  import os
  import json
+ import pandas as pd
+ import matplotlib.pyplot as plt
  import re
+ from threading import Thread
+ import numpy as np
  from io import StringIO

+ # Global variables to store model, tokenizer and pipe
+ MODEL = None
+ TOKENIZER = None
+ PIPE = None
+ MODEL_LOADING = False
+ MODEL_LOADED = False
+
+ # Store chat history for different chat sessions
+ CHATS = {"Main Chat": []}
+ CURRENT_CHAT = "Main Chat"

+ # System prompt and generation config
+ SYSTEM_PROMPT = "You are a helpful AI assistant based on the Mistral-7B-Instruct model. You specialize in creating structured JSON data for automation workflows like n8n. When asked, format JSON properly with correct indentation and structure."
+ GENERATE_CONFIG = {
+     "max_new_tokens": 512,
+     "temperature": 0.7,
+     "top_p": 0.95,
+     "top_k": 50,
+     "repetition_penalty": 1.1,
+     "do_sample": True
+ }
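+ # How these sampling knobs interact: temperature rescales the next-token
+ # distribution, top_p keeps the smallest token set whose cumulative probability
+ # exceeds 0.95, top_k caps that set at 50 candidates, and repetition_penalty
+ # values above 1.0 down-weight tokens that were already generated.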

+ # File data storage
+ FILE_DATA = None
+ ANALYZED_DATA = None

  # Function to load the model in background
  def load_model_in_background():
+     global MODEL, TOKENIZER, PIPE, MODEL_LOADING, MODEL_LOADED
      try:
+         MODEL_LOADING = True
+
          # Model identifier
          model_id = "mistralai/Mistral-7B-Instruct-v0.3"

          # Load tokenizer
+         TOKENIZER = AutoTokenizer.from_pretrained(model_id)

          # Configure model loading (with lower precision for efficiency)
+         MODEL = AutoModelForCausalLM.from_pretrained(
              model_id,
              torch_dtype=torch.float16,
              device_map="auto",
              low_cpu_mem_usage=True
          )
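+         # float16 halves the memory of full precision (about 2 bytes per
+         # parameter, so roughly 14-15 GB for a 7B model); device_map="auto"
+         # needs the accelerate package and may offload layers to CPU if the
+         # GPU is too small.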
 
          # Create text generation pipeline
+         PIPE = pipeline(
              "text-generation",
+             model=MODEL,
+             tokenizer=TOKENIZER,
              return_full_text=False
          )

+         MODEL_LOADING = False
+         MODEL_LOADED = True
+         return "Model loaded successfully!"
      except Exception as e:
+         MODEL_LOADING = False
+         return f"Error loading model: {str(e)}"

  # Function to generate response using the model
+ def generate_response(prompt, chat_history, progress=gr.Progress()):
+     global MODEL, TOKENIZER, PIPE, CHATS, CURRENT_CHAT, SYSTEM_PROMPT, GENERATE_CONFIG, FILE_DATA, ANALYZED_DATA
+
+     if not MODEL_LOADED:
+         # Return the history with a notice so the output type matches the Chatbot
+         chat_history.append((prompt, "Model is still loading. Please wait a moment before sending messages."))
+         return chat_history

      try:
+         # Use the current chat's history
+         messages = CHATS[CURRENT_CHAT]

          # Format conversation history in Mistral's chat format
          conversation = []

          # Add system prompt if it exists
+         if SYSTEM_PROMPT:
+             conversation.append({"role": "system", "content": SYSTEM_PROMPT})

          # Add previous messages
          for msg in messages:
              if msg["role"] != "system": # Skip system messages in the history
                  conversation.append({"role": msg["role"], "content": msg["content"]})

+         # Handle file-related queries by including context
+         if ANALYZED_DATA is not None and any(keyword in prompt.lower()
+                 for keyword in ["file", "data", "analyze", "show", "tell me about", "json"]):
+             file_context = ""
+             if ANALYZED_DATA["type"] in ["csv", "excel"]:
+                 # For structured data, provide summary info
+                 summary = ANALYZED_DATA["summary"]
+                 file_context = f"""I've analyzed the uploaded {ANALYZED_DATA["type"]} file with {summary["rows"]} rows and {summary["columns"]} columns.
+ The columns are: {', '.join(summary["column_names"])}.
+ Here's a sample of the data (first 5 rows): {json.dumps(summary["sample"])}
+ """
+             elif ANALYZED_DATA["type"] == "text":
+                 # For text data, provide the content if not too large
+                 summary = ANALYZED_DATA["summary"]
+                 file_context = f"""I've analyzed the uploaded text file with {summary["word_count"]} words and {summary["line_count"]} lines.
+ Here's a preview of the content: {summary["preview"]}
+ """
+             elif ANALYZED_DATA["type"] == "json":
+                 # For JSON data
+                 summary = ANALYZED_DATA["summary"]
+                 file_context = f"""I've analyzed the uploaded JSON file which contains a {summary["type"]}.
+ {"Keys: " + ', '.join(summary["keys"]) if summary["keys"] else ""}
+ {"Items: " + str(summary["length"]) if summary["length"] else ""}
+ Here's a preview: {summary["preview"]}
+ """
+
+             # Enhance the user's query with file context
+             enhanced_prompt = f"{prompt}\n\nContext about the file: {file_context}"
+         else:
+             enhanced_prompt = prompt
+
          # Add current prompt
+         conversation.append({"role": "user", "content": enhanced_prompt})

          # Convert to Mistral's chat format
+         formatted_prompt = TOKENIZER.apply_chat_template(
              conversation,
              tokenize=False,
              add_generation_prompt=True
          )
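+         # apply_chat_template renders the turns with the tokenizer's built-in
+         # Mistral template ([INST] ... [/INST] around user turns). Some Mistral
+         # tokenizer versions reject a separate "system" role, in which case the
+         # system prompt would need to be folded into the first user message.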

+         # Generate response with progress reporting
+         progress(0, desc="Generating response...")

+         # Generate response
+         response = PIPE(
              formatted_prompt,
+             max_new_tokens=GENERATE_CONFIG["max_new_tokens"],
+             temperature=GENERATE_CONFIG["temperature"],
+             top_p=GENERATE_CONFIG["top_p"],
+             top_k=GENERATE_CONFIG["top_k"],
+             repetition_penalty=GENERATE_CONFIG["repetition_penalty"],
+             do_sample=GENERATE_CONFIG["do_sample"]
          )

+         progress(1, desc="Response generated!")
+
+         # Extract generated text
          generated_text = response[0]["generated_text"]
+
+         # Add user message to chat history
+         CHATS[CURRENT_CHAT].append({
+             "role": "user",
+             "content": prompt
+         })
+
+         # Add assistant response to chat history
+         CHATS[CURRENT_CHAT].append({
+             "role": "assistant",
+             "content": generated_text
+         })
+
+         # Update the chat history for the Gradio component
+         chat_history.append((prompt, generated_text))
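+         # gr.Chatbot's default tuple format is a list of
+         # (user_message, assistant_message) pairs, one per exchange.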
+
+         return chat_history

      except Exception as e:
+         error_message = f"Error generating response: {str(e)}"
+         chat_history.append((prompt, error_message))
+         return chat_history

  # Function to create a new chat
  def create_new_chat(chat_name):
+     global CHATS, CURRENT_CHAT
+
+     if chat_name and chat_name not in CHATS:
+         CHATS[chat_name] = []
+         CURRENT_CHAT = chat_name
+         return f"Created new chat: {chat_name}"
+     return "Please enter a unique chat name"

  # Function to handle file upload and analysis
  def analyze_uploaded_file(file):
+     global FILE_DATA, ANALYZED_DATA, CHATS, CURRENT_CHAT

+     if file is None:
+         return "No file uploaded."

      try:
+         file_extension = file.name.split('.')[-1].lower()
+
          if file_extension == 'csv':
+             data = pd.read_csv(file.name)
+             FILE_DATA = data
+             ANALYZED_DATA = {
                  "type": "csv",
                  "data": data,
                  "summary": {

              }

          elif file_extension in ['txt', 'md']:
+             with open(file.name, 'r', encoding='utf-8') as f:
+                 content = f.read()
+             FILE_DATA = content
+             ANALYZED_DATA = {
                  "type": "text",
                  "data": content,
                  "summary": {

              }

          elif file_extension == 'json':
+             with open(file.name, 'r', encoding='utf-8') as f:
+                 content = f.read()
+             data = json.loads(content)
+             FILE_DATA = data
+             ANALYZED_DATA = {
                  "type": "json",
                  "data": data,
                  "summary": {

              }

          elif file_extension in ['xls', 'xlsx']:
+             data = pd.read_excel(file.name)
+             FILE_DATA = data
+             ANALYZED_DATA = {
                  "type": "excel",
                  "data": data,
                  "summary": {

              }

          else:
+             return f"File type .{file_extension} is not supported."
+
+         # Add file summary to current chat as system message
+         file_summary = f"File analyzed: {file.name}\n"
+         if ANALYZED_DATA['type'] == 'csv' or ANALYZED_DATA['type'] == 'excel':
+             file_summary += f"- {ANALYZED_DATA['summary']['rows']} rows, {ANALYZED_DATA['summary']['columns']} columns\n"
+             file_summary += f"- Columns: {', '.join(ANALYZED_DATA['summary']['column_names'])}"
+         elif ANALYZED_DATA['type'] == 'text':
+             file_summary += f"- {ANALYZED_DATA['summary']['word_count']} words, {ANALYZED_DATA['summary']['line_count']} lines"
+         elif ANALYZED_DATA['type'] == 'json':
+             file_summary += f"- Type: {ANALYZED_DATA['summary']['type']}"
+             if ANALYZED_DATA['summary']['keys']:
+                 file_summary += f"\n- Keys: {', '.join(ANALYZED_DATA['summary']['keys'])}"
+
+         # Add system message to current chat
+         CHATS[CURRENT_CHAT].append({
+             "role": "system",
+             "content": file_summary
+         })
+
+         return f"Successfully analyzed {ANALYZED_DATA['type']} file: {file.name}"

      except Exception as e:
+         return f"Error analyzing file: {str(e)}"

+ # Function to update system prompt
+ def update_system_prompt(new_prompt):
+     global SYSTEM_PROMPT
+     SYSTEM_PROMPT = new_prompt
+     return "System prompt updated!"

+ # Function to update generation parameters
+ def update_generation_params(temp, max_tokens, top_p, rep_penalty):
+     global GENERATE_CONFIG
+     GENERATE_CONFIG["temperature"] = temp
+     GENERATE_CONFIG["max_new_tokens"] = max_tokens
+     GENERATE_CONFIG["top_p"] = top_p
+     GENERATE_CONFIG["repetition_penalty"] = rep_penalty
+     return "Generation parameters updated!"

+ # Function to display file data information
+ def display_file_info():
+     global ANALYZED_DATA

+     if ANALYZED_DATA is None:
+         return "No file has been analyzed yet."

+     file_info = "## File Analysis Results\n\n"
+     file_info += f"**File Type:** {ANALYZED_DATA['type']}\n\n"
+
+     if ANALYZED_DATA['type'] in ['csv', 'excel']:
+         summary = ANALYZED_DATA['summary']
+         file_info += f"**Rows:** {summary['rows']}\n"
+         file_info += f"**Columns:** {summary['columns']}\n"
+         file_info += f"**Column Names:** {', '.join(summary['column_names'])}\n\n"

+         # Sample data preview
+         file_info += "**Sample Data (First 5 rows):**\n"
+         sample_df = pd.DataFrame(summary['sample'])
+         file_info += sample_df.to_markdown()

+     elif ANALYZED_DATA['type'] == 'text':
+         summary = ANALYZED_DATA['summary']
+         file_info += f"**Length:** {summary['length']} characters\n"
+         file_info += f"**Word Count:** {summary['word_count']}\n"
+         file_info += f"**Line Count:** {summary['line_count']}\n\n"
+         file_info += "**Preview:**\n```\n" + summary['preview'] + "\n```"

+     elif ANALYZED_DATA['type'] == 'json':
+         summary = ANALYZED_DATA['summary']
+         file_info += f"**Type:** {summary['type']}\n"
+         if summary['keys']:
+             file_info += f"**Keys:** {', '.join(summary['keys'])}\n"
+         if summary['length'] is not None:
+             file_info += f"**Length:** {summary['length']} items\n"
+         file_info += "\n**Preview:**\n```json\n" + summary['preview'] + "\n```"
+
+     return file_info
+
+ # Function to select current chat
+ def select_chat(chat_name):
+     global CURRENT_CHAT
+     CURRENT_CHAT = chat_name
+     return f"Switched to chat: {chat_name}"
+
+ # Function to clear current chat
+ def clear_current_chat():
+     global CHATS, CURRENT_CHAT
+     CHATS[CURRENT_CHAT] = []
+     return f"Cleared chat: {CURRENT_CHAT}"
+
+ # Function to load model and return status
+ def load_model_button():
+     if MODEL_LOADED:
+         return "Model is already loaded and ready!"
+     elif MODEL_LOADING:
+         return "Model is currently loading... Please wait."
+     else:
+         thread = Thread(target=load_model_in_background)
+         thread.start()
+         return "Started loading the model. This may take a few minutes..."
+
+ # Function to get available chats
+ def get_available_chats():
+     global CHATS
+     return list(CHATS.keys())

+ # Main Gradio app
+ def create_gradio_interface():
+     # CSS for better styling
+     css = """
+     .gradio-container {max-width: 100% !important; padding: 0}
+     .chat-message-user {background-color: #e0f7fa; padding: 12px; border-radius: 8px; margin-bottom: 8px}
+     .chat-message-bot {background-color: #f1f8e9; padding: 12px; border-radius: 8px; margin-bottom: 8px}
+     .file-info {border: 1px solid #ddd; padding: 15px; border-radius: 5px; margin-top: 10px}
+     """

+     # Setup tabs for different functionalities
+     with gr.Blocks(css=css) as app:
+         gr.Markdown("# 🤖 Advanced Mistral-7B-Instruct Chatbot")

+         with gr.Tab("Chat"):
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     chatbot = gr.Chatbot(
+                         [],
+                         elem_id="chatbot",
+                         height=500,
+                         bubble_full_width=False,
+                         avatar_images=(None, None),
+                     )

+                     with gr.Row():
+                         msg = gr.Textbox(
+                             placeholder="Type your message here...",
+                             container=False,
+                             scale=8,
+                             autofocus=True
+                         )
+                         send_btn = gr.Button("Send", scale=1, variant="primary")

+                     with gr.Row():
+                         chat_selector = gr.Dropdown(
+                             choices=get_available_chats(),
+                             value=CURRENT_CHAT,
+                             label="Select Chat",
+                             interactive=True
+                         )
+
+                 with gr.Column(scale=1):
+                     new_chat_name = gr.Textbox(
+                         placeholder="New chat name",
+                         label="Create New Chat",
+                         container=True
+                     )
+                     with gr.Row():
+                         create_chat_btn = gr.Button("Create", variant="secondary")
+                         clear_chat_btn = gr.Button("Clear Current Chat", variant="secondary")

+                 with gr.Column(scale=1):
+                     # Model Loading and Settings
+                     load_model_btn = gr.Button("Load Mistral-7B Model", variant="primary")
+                     model_status = gr.Textbox(label="Model Status", value="Not loaded", interactive=False)
+
+                     # System Prompt
+                     system_prompt_input = gr.Textbox(
+                         label="System Prompt",
+                         value=SYSTEM_PROMPT,
+                         lines=4,
+                         placeholder="Enter system prompt to guide the AI's behavior..."
+                     )
+                     update_prompt_btn = gr.Button("Update System Prompt", variant="secondary")
+
+                     # Model Parameters
+                     gr.Markdown("### Generation Parameters")
+                     temperature = gr.Slider(
+                         minimum=0.1,
+                         maximum=2.0,
+                         value=GENERATE_CONFIG["temperature"],
+                         step=0.1,
+                         label="Temperature"
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=64,
+                         maximum=2048,
+                         value=GENERATE_CONFIG["max_new_tokens"],
+                         step=64,
+                         label="Max Tokens"
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=GENERATE_CONFIG["top_p"],
+                         step=0.05,
+                         label="Top P"
+                     )
+                     rep_penalty = gr.Slider(
+                         minimum=1.0,
+                         maximum=2.0,
+                         value=GENERATE_CONFIG["repetition_penalty"],
+                         step=0.1,
+                         label="Repetition Penalty"
+                     )
+                     update_params_btn = gr.Button("Update Parameters", variant="secondary")
+
+         with gr.Tab("File Analysis"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     file_upload = gr.File(label="Upload a file to analyze")
+                     analyze_btn = gr.Button("Analyze File", variant="primary")

+                 with gr.Column(scale=2):
+                     file_analysis_output = gr.Markdown(label="File Analysis Results")

+         # Set up event handlers
+         send_btn.click(
+             generate_response,
+             inputs=[msg, chatbot],
+             outputs=[chatbot],
+             api_name="chat"
+         )

+         msg.submit(
+             generate_response,
+             inputs=[msg, chatbot],
+             outputs=[chatbot],
+             api_name=False
+         )

+         load_model_btn.click(
+             load_model_button,
+             outputs=model_status,
+             api_name="load_model"
+         )
+
+         update_prompt_btn.click(
+             update_system_prompt,
+             inputs=system_prompt_input,
+             outputs=model_status,
+             api_name="update_prompt"
+         )
+
+         update_params_btn.click(
+             update_generation_params,
+             inputs=[temperature, max_tokens, top_p, rep_penalty],
+             outputs=model_status,
+             api_name="update_params"
+         )
+
+         analyze_btn.click(
+             analyze_uploaded_file,
+             inputs=file_upload,
+             outputs=model_status,
+             api_name="analyze_file"
+         ).then(
+             display_file_info,
+             outputs=file_analysis_output,
+             api_name=False
+         )
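+         # .then() sequences the two steps: display_file_info only runs after
+         # analyze_uploaded_file has returned and ANALYZED_DATA is populated.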

+         chat_selector.change(
+             select_chat,
+             inputs=chat_selector,
+             outputs=model_status,
+             api_name="select_chat"
+         )
+
+         create_chat_btn.click(
+             create_new_chat,
+             inputs=new_chat_name,
+             outputs=model_status,
+             api_name="create_chat"
+         ).then(
+             # A bare get_available_chats would only set the dropdown's value, so wrap
+             # it in gr.update to refresh the choices (assumes gr.update is available)
+             lambda: gr.update(choices=get_available_chats(), value=CURRENT_CHAT),
+             outputs=chat_selector,
+             api_name=False
+         )
+
+         clear_chat_btn.click(
+             clear_current_chat,
+             outputs=model_status,
+             api_name="clear_chat"
+         )
+
+         # Clear the visible chat window when a different chat is selected
+         chat_selector.change(lambda: [], outputs=chatbot)
+
+     return app
+
+ # Launch the app
+ demo = create_gradio_interface()

+ if __name__ == "__main__":
+     demo.launch()
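+     # demo.launch(share=True) would additionally expose a temporary public URL.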