|
|
import os |
|
|
import sys |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_groq import ChatGroq |
|
|
import plotly.express as px |
|
|
import time |
|
|
import plotly.io as pio |
|
|
import traceback |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
import importlib.util |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Make the project root importable so the `backend.*` package resolves when
# this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# Shared assistant instance used for document (PDF/TXT/DOCX) indexing and Q&A.
document_assistant = DocumentAssistant()

# Groq-hosted Llama 3 chat model used for SQL generation and result
# interpretation; temperature=0 keeps SQL generation deterministic.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite database file where uploaded CSV data is stored; the containing
# directory is created on demand.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Tracks which upload is currently active so queries can be routed either to
# the SQL path (CSV) or the document-assistant path (everything else).
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}

# Placeholder for the most recently generated plot (not currently read).
current_plot = None

# Prompt that turns a natural-language question into a bare SQLite query.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
""")

# Prompt that summarizes SQL results into a natural-language answer.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)

# Prompt that produces the data-retrieval SQL for a requested visualization.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.

Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
   - Year: strftime('%Y', date_column)
   - Month: strftime('%m', date_column)
   - Day: strftime('%d', date_column)
   - Hour: strftime('%H', date_column)

2. For histograms and binning:
   - Use: CAST((column / bin_size) AS INT) * bin_size
   - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin

3. For box plots:
   - SQLite doesn't support PERCENTILE_CONT or window functions
   - Simply return the raw data column: SELECT column_name FROM data_tab
   - The application will calculate quartiles and outliers

4. For heatmaps:
   - Return raw data for correlation analysis
   - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab

5. Always use 'data_tab' as the table name

6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
Visualization type: {viz_type}
""")
|
|
|
|
|
|
|
|
def clean_sql_query(query_text): |
|
|
"""Clean SQL query text by removing markdown formatting and comments""" |
|
|
|
|
|
if not query_text: |
|
|
return "SELECT * FROM data_tab LIMIT 10;" |
|
|
|
|
|
|
|
|
if "```" in query_text: |
|
|
|
|
|
pattern = r"```(?:sql)?(.*?)```" |
|
|
matches = re.findall(pattern, query_text, re.DOTALL) |
|
|
if matches: |
|
|
query_text = matches[0].strip() |
|
|
|
|
|
|
|
|
prefixes = [ |
|
|
"here is the sql query", |
|
|
"here is the sqlite query", |
|
|
"here is a query", |
|
|
"here's the sql query", |
|
|
"the sql query is", |
|
|
"sql query:" |
|
|
] |
|
|
|
|
|
for prefix in prefixes: |
|
|
if query_text.lower().startswith(prefix): |
|
|
|
|
|
sql_keywords = ["select", "with", "create", "insert", "update", "delete"] |
|
|
positions = [query_text.lower().find(keyword) for keyword in sql_keywords] |
|
|
positions = [pos for pos in positions if pos != -1] |
|
|
|
|
|
if positions: |
|
|
start_pos = min(positions) |
|
|
query_text = query_text[start_pos:] |
|
|
|
|
|
|
|
|
query_text = re.sub(r'--.*?(\n|$)', ' ', query_text) |
|
|
|
|
|
|
|
|
query_text = query_text.strip().rstrip(';') |
|
|
|
|
|
|
|
|
if not query_text.strip(): |
|
|
return "SELECT * FROM data_tab LIMIT 10;" |
|
|
|
|
|
return query_text |
|
|
|
|
|
def process_text_query(query, history):
    """Process a text query and update chat history.

    Routes the question either to the SQL pipeline (when a CSV is loaded in
    `current_context`) or to the document assistant otherwise.  Returns a
    pair ("", history) suitable for Gradio's (textbox, chatbot) outputs —
    the empty string clears the input box.

    history is assumed to be a Gradio "messages"-format list of
    {"role": ..., "content": ...} dicts (mutated in place).
    """
    if not query:
        return "", history

    # Record the user's turn before doing any work.
    history.append({"role": "user", "content": query})

    start_time = time.time()

    # Keyword lookup used to map a free-text request to a chart type.
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }

    # Coarse check: does the question ask for any kind of visualization?
    is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])

    # If so, pick the first chart type whose keywords appear in the query.
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in query.lower() for keyword in keywords):
                viz_type = vtype
                break

    if current_context["file_type"] == "csv" and current_context["table_name"]:
        # SQL path: a CSV has been loaded into the SQLite database.
        try:
            conn = sqlite3.connect(DB_PATH)

            # Fetch column names so the LLM knows the table schema.
            cursor = conn.cursor()
            cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
            columns = [info[1] for info in cursor.fetchall()]
            columns_str = ", ".join(columns)

            question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"

            # Box plots and heatmaps need raw column data, which SQLite can't
            # aggregate itself — build those queries directly instead of
            # asking the LLM.
            if is_visualization and viz_type in ['box', 'heatmap']:

                if viz_type == 'box':
                    # Find numeric columns via the table-info pragma.
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]

                    if numeric_cols:
                        # Prefer a numeric column mentioned by name in the query.
                        target_col = None
                        for col in numeric_cols:
                            if col.lower() in query.lower():
                                target_col = col
                                break

                        # Otherwise default to the first numeric column.
                        if not target_col and numeric_cols:
                            target_col = numeric_cols[0]

                        sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
                    else:
                        # No numeric columns: fall back to a preview query.
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"

                elif viz_type == 'heatmap':
                    # Same numeric-column discovery as the box-plot branch.
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]

                    if len(numeric_cols) >= 2:
                        # Cap at 10 columns / 1000 rows to keep the correlation cheap.
                        cols_to_use = numeric_cols[:10]
                        cols_str = ", ".join(cols_to_use)
                        sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
                    else:
                        # Not enough numeric columns for a correlation heatmap.
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
            else:
                # Ask the LLM to write the SQL, then strip any chat wrapping.
                ai_msg = query_prompt | llm
                raw_sql_query = ai_msg.invoke({"question": question_with_context}).content.strip()

                sql_query = clean_sql_query(raw_sql_query)

            print(f"Generated SQL Query: {sql_query}")

            try:
                # Execute the query against the loaded CSV data.
                result_df = pd.read_sql_query(sql_query, conn)

                # Build a compact text summary for the interpretation prompt.
                if not result_df.empty:
                    data_summary = result_df.describe(include='all').to_string()

                    # Include full rows when small, otherwise just a preview.
                    if len(result_df) <= 10:
                        data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
                    else:
                        data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"
                else:
                    data_summary = "No relevant data found."

                # Second LLM pass: explain the results in plain language.
                answer_chain = interpret_prompt | llm
                interpretation = answer_chain.invoke({
                    "question": query,
                    "sql_query": sql_query,
                    "data_summary": data_summary
                }).content.strip()

                # Assemble the markdown response shown in the chat.
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"

                if not result_df.empty:
                    if len(result_df) > 10:
                        response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
                    else:
                        response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
                else:
                    response += "**No results found.**\n\n"

                response += f"**Analysis:**\n{interpretation}"

                # Append an inline chart (as an <img> tag) when one was requested.
                if is_visualization and not result_df.empty:
                    try:
                        viz_html = generate_visualization(result_df, query)

                        if viz_html:
                            response += f"\n\n{viz_html}"

                            response += "\n\n**A visualization has been generated and is displayed above.**"
                        else:
                            response += "\n\n**Could not generate visualization due to an error.**"

                    except Exception as viz_error:
                        # Visualization failures are non-fatal: log and keep
                        # the textual answer.
                        print(f"Visualization error: {str(viz_error)}")
                        import traceback
                        traceback.print_exc()

            except Exception as e:
                # Query execution failed: surface the SQL and the error.
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"

            conn.close()

        except Exception as e:
            response = f"Error processing query: {str(e)}"

    else:
        # Document path: no CSV is active, delegate to the document assistant.
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"

    # Append timing info so users can see how long the query took.
    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"

    history.append({"role": "assistant", "content": response})

    return "", history
|
|
|
|
|
def process_file_upload(files):
    """Process uploaded files and index them.

    CSV files are loaded into the SQLite table ``data_tab`` (replacing any
    previous contents); every other file type is handed to the document
    assistant for chunking and indexing.  Updates the module-level
    ``current_context`` to point at the last successfully processed file and
    returns a newline-joined status report.
    """
    if not files:
        return "No files uploaded"

    global current_context

    # Reset the context before processing a new batch of uploads.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                # Derived table name from the filename.
                # NOTE(review): this value is computed but never used — the
                # data is always written to 'data_tab'.
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()

                conn = sqlite3.connect(DB_PATH)

                # Speed up the bulk insert; durability is not a concern for
                # this throwaway analysis database.
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")

                # Load the CSV and replace the single working table.
                df = pd.read_csv(file_path)
                df.to_sql('data_tab', conn, if_exists='replace', index=False)

                # Point subsequent queries at the SQL path.
                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab"
                }

                # Collect schema info ("name (TYPE)") for the status report.
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]

                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]

                conn.close()

                file_info.append("β CSV File Successfully Loaded")
                file_info.append(f"π Table Name: data_tab")
                file_info.append(f"π Source File: {file_name}")
                file_info.append(f"π Total Rows: {row_count:,}")
                file_info.append(f"π Columns: {', '.join(columns)}")

            except Exception as e:
                file_info.append(f"β Error loading CSV {file_name}: {str(e)}")

        else:
            # Non-CSV files go to the document assistant for indexing.
            try:
                result = document_assistant.upload_document(file_path)

                # NOTE(review): file_type is labelled "pdf" even for TXT/DOCX
                # uploads; downstream code only distinguishes "csv" vs. not.
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }

                file_info.append("β Document Successfully Processed")
                file_info.append(f"π File: {file_name}")
                file_info.append(f"π Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"β Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)
|
|
|
|
|
def list_documents():
    """Return a human-readable listing of every indexed document."""
    try:
        docs = document_assistant.get_all_documents()
        if not docs:
            return "No documents indexed yet."

        # One bullet per document: "- filename (type)".
        listing = ["Indexed Documents:\n"]
        listing.extend(f"- {doc['filename']} ({doc['file_type']})" for doc in docs)
        return "\n".join(listing) + "\n"
    except Exception as e:
        return f"Error listing documents: {str(e)}"
|
|
|
|
|
def clear_context():
    """Reset the active-file context and return a fresh chat history."""
    global current_context

    try:
        # Forget whichever file (CSV or document) was previously active.
        current_context = {
            "file_type": None,
            "file_name": None,
            "table_name": None,
        }
        confirmation = "Context cleared. You can now upload new documents or CSV files."
        return [{"role": "assistant", "content": confirmation}]
    except Exception as e:
        return [{"role": "assistant", "content": f"Error clearing context: {str(e)}"}]
|
|
|
|
|
|
|
|
""" |
|
|
def process_voice_input(audio_path): |
|
|
# I am checking if there is audio |
|
|
if audio_path is None: |
|
|
return "No audio recorded" |
|
|
|
|
|
try: |
|
|
# I am making a recognizer for the voice |
|
|
r = sr.Recognizer() |
|
|
|
|
|
# I am loading the audio file |
|
|
with sr.AudioFile(audio_path) as source: |
|
|
# I am reading the audio data |
|
|
audio_data = r.record(source) |
|
|
|
|
|
# I am recognizing speech using Google |
|
|
text = r.recognize_google(audio_data) |
|
|
|
|
|
return text |
|
|
except sr.UnknownValueError: |
|
|
return "Could not understand audio" |
|
|
except sr.RequestError as e: |
|
|
return f"Error with speech recognition service: {e}" |
|
|
except Exception as e: |
|
|
return f"Error processing audio: {str(e)}" |
|
|
""" |
|
|
|
|
|
|
|
|
""" |
|
|
def text_to_speech_output(text): |
|
|
# I am checking if there is text |
|
|
if not text or len(text) == 0: |
|
|
return None |
|
|
|
|
|
# I am finding the last message from assistant |
|
|
last_message = None |
|
|
for msg in reversed(text): |
|
|
if msg["role"] == "assistant": |
|
|
last_message = msg["content"] |
|
|
break |
|
|
|
|
|
if not last_message: |
|
|
return None |
|
|
|
|
|
try: |
|
|
# I am cleaning the text |
|
|
clean_text = re.sub(r'<.*?>', '', last_message) # I am removing HTML tags |
|
|
clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', clean_text) # I am removing bold markdown |
|
|
clean_text = re.sub(r'\n\n', ' ', clean_text) # I am replacing double newlines with space |
|
|
clean_text = re.sub(r'```.*?```', 'Code block removed for speech.', clean_text, flags=re.DOTALL) # I am replacing code blocks |
|
|
|
|
|
# I am creating a temporary file |
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") |
|
|
temp_file.close() |
|
|
|
|
|
# I am generating speech |
|
|
tts = gTTS(text=clean_text, lang='en', slow=False) |
|
|
tts.save(temp_file.name) |
|
|
|
|
|
return temp_file.name |
|
|
except Exception as e: |
|
|
print(f"Error generating speech: {str(e)}") |
|
|
return None |
|
|
""" |
|
|
|
|
|
def create_test_visualization():
    """Build a small hard-coded bar chart to confirm plotting works.

    Returns a Plotly figure, or None if figure creation fails.
    """
    try:
        # Six months of dummy values — enough to render a visible chart.
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })

        figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')

        # Fixed canvas size keeps the test output predictable.
        figure.update_layout(autosize=True, width=800, height=500)

        return figure
    except Exception as e:
        print(f"Error creating test visualization: {str(e)}")
        return None
|
|
|
|
|
def create_test_html_visualization():
    """Render the hard-coded test bar chart as an embeddable HTML fragment.

    Returns an HTML string, or None if figure creation or conversion fails.
    """
    try:
        # Same dummy data as create_test_visualization.
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })

        figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')

        # full_html=False yields just the <div>/<script> snippet, not a page.
        return pio.to_html(figure, full_html=False)
    except Exception as e:
        print(f"Error creating test HTML visualization: {str(e)}")
        return None
|
|
|
|
|
def flush_databases():
    """Flush ChromaDB and SQLite databases.

    Drops every table in the SQLite CSV store, resets the document
    assistant's vector database, and clears ``current_context``.  Returns a
    newline-joined status report; failures in either store are reported but
    do not abort the other.
    """
    global document_assistant
    global current_context
    result = []

    # Step 1: drop all tables from the SQLite CSV store.
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()

        # Enumerate every user table currently in the database.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table[0]};")

        conn.commit()
        conn.close()

        result.append("β SQLite database cleared successfully")
    except Exception as e:
        result.append(f"β Error clearing SQLite database: {str(e)}")

    # Step 2: reset the vector store behind the document assistant.
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("β ChromaDB cleared successfully")
        else:
            # Reset reported failure: fall back to replacing the whole
            # assistant instance so we at least start from a clean object.
            document_assistant = DocumentAssistant()
            result.append("β οΈ ChromaDB reset partially completed - created new instance")
    except Exception as e:
        result.append(f"β Error clearing ChromaDB: {str(e)}")

    # Step 3: forget whichever file was active.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    return "\n".join(result)
|
|
|
|
|
|
|
|
|
|
|
# Workaround for a known import failure in backend.vector_db: if importing
# ChromaVectorDB raises a NameError mentioning a stray `response` name, strip
# that attribute from the module, reload it, and retry the import.
# NOTE(review): if the NameError message does not mention 'response', the
# exception propagates and ChromaVectorDB is never imported — confirm this is
# the intended behavior.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        import backend.vector_db

        # Remove the leftover attribute that triggered the NameError.
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')

        # Reload the cleaned module and retry the import.
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
|
|
|
|
|
|
|
|
def generate_visualization(result_df, query):
    """Generate a visualization based on the query and data.

    Infers a chart type from keywords in *query*, builds a Plotly figure
    from *result_df*, renders it to a PNG, and returns it embedded as a
    base64 ``<img>`` HTML tag.  Returns None on any failure.
    """
    try:
        print("Visualization requested, attempting to create plot...")

        # Fixed output canvas applied to every chart type.
        fig_width = 900
        fig_height = 800

        # Default chart; overridden below when the query names another type.
        viz_type = 'bar'

        if any(word in query.lower() for word in ['pie', 'distribution', 'proportion']):
            viz_type = 'pie'
        elif any(word in query.lower() for word in ['line', 'trend', 'time series']):
            viz_type = 'line'
        elif any(word in query.lower() for word in ['scatter', 'relationship']):
            viz_type = 'scatter'
        elif any(word in query.lower() for word in ['histogram', 'distribution of']):
            viz_type = 'histogram'
        elif any(word in query.lower() for word in ['box', 'boxplot', 'outliers']):
            viz_type = 'box'
        elif any(word in query.lower() for word in ['heatmap', 'correlation']):
            viz_type = 'heatmap'

        print(f"Creating {viz_type} visualization...")

        # Numeric columns drive axis selection for most chart types.
        numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()

        # Each branch builds `fig`; if a branch's preconditions fail (e.g.
        # too many rows for a pie, too few numeric columns), no figure is
        # created here and the bar-chart fallback below takes over.
        if viz_type == 'pie' and len(result_df) <= 20:
            # First column supplies labels; second supplies values (or a
            # uniform weight of 1 when there is only one column).
            labels = result_df.iloc[:, 0].tolist()
            values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)

            import plotly.graph_objects as go
            fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
            fig.update_layout(title_text='Pie Chart')

        elif viz_type == 'histogram' and len(numeric_cols) > 0:
            import plotly.express as px
            fig = px.histogram(result_df, x=numeric_cols[0])
            fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')

        elif viz_type == 'box' and len(numeric_cols) > 0:
            import plotly.express as px
            fig = px.box(result_df, y=numeric_cols[0])
            fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')

        elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
            import plotly.express as px

            # Pairwise correlations of the numeric columns.
            corr_df = result_df[numeric_cols].corr()
            fig = px.imshow(corr_df, text_auto=True)
            fig.update_layout(title_text='Correlation Heatmap')

        elif viz_type == 'scatter' and len(numeric_cols) >= 2:
            import plotly.express as px
            fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
            fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')

        elif viz_type == 'line':
            import plotly.express as px
            x_col = result_df.columns[0]
            # Prefer numeric columns for the y axis; otherwise the second
            # column; None when there is no candidate at all.
            y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None

            if y_cols:
                fig = px.line(result_df, x=x_col, y=y_cols[0])
                fig.update_layout(
                    title_text=f'Line Chart of {y_cols[0]} over {x_col}',
                    xaxis=dict(
                        tickangle=-45,
                        tickmode='auto',
                        nticks=20
                    )
                )
            else:
                # No usable y axis: fall through to the bar-chart default.
                viz_type = 'bar'

        # Fallback: explicit bar request, or any branch above that did not
        # manage to create `fig` (checked via locals()).
        if viz_type == 'bar' or 'fig' not in locals():
            import plotly.express as px
            x_col = result_df.columns[0]
            y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None

            # Many categories read better as a horizontal bar chart.
            if len(result_df) > 10:
                if y_col:
                    fig = px.bar(
                        result_df,
                        y=x_col,
                        x=y_col,
                        orientation='h',
                        title=f'Bar Chart of {y_col} by {x_col}'
                    )
                else:
                    fig = px.bar(
                        result_df,
                        y=x_col,
                        orientation='h',
                        title=f'Bar Chart of {x_col}'
                    )
            else:
                if y_col:
                    fig = px.bar(
                        result_df,
                        x=x_col,
                        y=y_col,
                        title=f'Bar Chart of {y_col} by {x_col}',
                        width=900,
                        height=800
                    )
                else:
                    fig = px.bar(
                        result_df,
                        x=x_col,
                        title=f'Bar Chart of {x_col}',
                        width=900,
                        height=800
                    )

            fig.update_layout(
                bargap=0.2,
                uniformtext_minsize=8,
                uniformtext_mode='hide'
            )

        # Common styling applied to whichever figure was built.
        fig.update_layout(
            width=fig_width,
            height=fig_height,
            template="plotly_white",
            margin=dict(l=40, r=40, t=80, b=80, pad=4),
            autosize=True,
            plot_bgcolor='rgba(240,240,240,0.2)',
            paper_bgcolor='white'
        )

        print(f"Created figure with width={fig_width}, height={fig_height}")

        # Render to a static PNG (scale=2 for crisper output), then embed it
        # as a base64 data URI so it survives inside a chat message.
        print("Converting figure to image...")
        img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=2)
        print("Image conversion successful")

        import base64
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"

        print("HTML conversion successful")

        return f"<img src='{img_src}' width='100%' style='max-width:900px; height:800px; object-fit:contain; display:block; margin:0 auto;' />"

    except Exception as e:
        import traceback
        print(f"Error generating visualization: {str(e)}")
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
with gr.Blocks(title="LLM Powered Database Chatbot") as demo: |
|
|
gr.Markdown("# π€ LLM Powered Database Chatbot") |
|
|
gr.Markdown("Upload documents, ask questions, and get AI-powered responses!") |
|
|
|
|
|
with gr.Tab("Chat"): |
|
|
|
|
|
gr.HTML(""" |
|
|
<style> |
|
|
.chatbot-container img { |
|
|
max-width: 100%; |
|
|
height: auto; |
|
|
display: block; |
|
|
margin: 10px 0; |
|
|
} |
|
|
</style> |
|
|
""") |
|
|
|
|
|
chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container") |
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
### Capabilities: |
|
|
- **Data Analysis**: Ask questions about your data and get detailed responses |
|
|
- **Visualization**: Request and view graphs and charts of your data |
|
|
- **Multiple File Types**: Upload PDFs, TXT, DOCX, CSV, and XLSX files for analysis |
|
|
- **Natural Language Queries**: Ask questions in plain English about your documents |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=8): |
|
|
msg = gr.Textbox( |
|
|
placeholder="Ask a question about your documents...", |
|
|
show_label=False |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
with gr.Row(): |
|
|
submit_btn = gr.Button("Submit") |
|
|
clear_btn = gr.Button("Clear") |
|
|
clear_context_btn = gr.Button("Clear Context") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
voice_input = gr.Audio( |
|
|
label="Voice Input", |
|
|
type="filepath", |
|
|
visible=False |
|
|
) |
|
|
""" |
|
|
|
|
|
|
|
|
submit_btn.click( |
|
|
process_text_query, |
|
|
inputs=[msg, chatbot], |
|
|
outputs=[msg, chatbot] |
|
|
) |
|
|
|
|
|
msg.submit( |
|
|
process_text_query, |
|
|
inputs=[msg, chatbot], |
|
|
outputs=[msg, chatbot] |
|
|
) |
|
|
|
|
|
clear_btn.click(lambda: None, None, [chatbot], queue=False) |
|
|
clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot]) |
|
|
|
|
|
|
|
|
""" |
|
|
voice_btn.click( |
|
|
lambda: gr.update(visible=True), |
|
|
None, |
|
|
voice_input |
|
|
) |
|
|
""" |
|
|
|
|
|
|
|
|
""" |
|
|
voice_input.change( |
|
|
process_voice_input, |
|
|
inputs=[voice_input], |
|
|
outputs=[msg] |
|
|
) |
|
|
""" |
|
|
|
|
|
|
|
|
""" |
|
|
tts_btn = gr.Button("π Speak Response") |
|
|
tts_btn.click( |
|
|
text_to_speech_output, |
|
|
inputs=[chatbot], |
|
|
outputs=[audio_output] |
|
|
) |
|
|
""" |
|
|
|
|
|
with gr.Tab("Document Upload"): |
|
|
file_upload = gr.File( |
|
|
label="Upload Documents", |
|
|
file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"], |
|
|
file_count="multiple" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
upload_button = gr.Button("Process & Index Documents", scale=2) |
|
|
flush_db_btn_doc = gr.Button("ποΈ Flush All Databases", variant="stop", scale=1) |
|
|
|
|
|
upload_output = gr.Textbox(label="Upload Status") |
|
|
|
|
|
upload_button.click( |
|
|
process_file_upload, |
|
|
inputs=[file_upload], |
|
|
outputs=[upload_output] |
|
|
) |
|
|
|
|
|
flush_db_btn_doc.click( |
|
|
flush_databases, |
|
|
inputs=[], |
|
|
outputs=[upload_output] |
|
|
) |
|
|
|
|
|
list_docs_button = gr.Button("List Indexed Documents") |
|
|
docs_output = gr.Textbox(label="Indexed Documents") |
|
|
|
|
|
list_docs_button.click( |
|
|
list_documents, |
|
|
inputs=[], |
|
|
outputs=[docs_output] |
|
|
) |
|
|
|
|
|
with gr.Tab("Settings"): |
|
|
with gr.Row(): |
|
|
gr.Markdown("## Database Management") |
|
|
flush_db_btn = gr.Button("ποΈ Flush All Databases", variant="stop", scale=1) |
|
|
|
|
|
flush_result = gr.Textbox(label="Flush Result") |
|
|
|
|
|
flush_db_btn.click( |
|
|
flush_databases, |
|
|
inputs=[], |
|
|
outputs=[flush_result] |
|
|
) |
|
|
|
|
|
gr.Markdown("## System Settings") |
|
|
api_key = gr.Textbox( |
|
|
label="Groq API Key", |
|
|
placeholder="Enter your Groq API key", |
|
|
type="password", |
|
|
value=os.getenv("GROQ_API_KEY", "") |
|
|
) |
|
|
save_btn = gr.Button("Save Settings") |
|
|
|
|
|
def save_settings(key): |
|
|
try: |
|
|
os.environ["GROQ_API_KEY"] = key |
|
|
return "Settings saved!" |
|
|
except Exception as e: |
|
|
return f"Error saving settings: {str(e)}" |
|
|
|
|
|
save_btn.click( |
|
|
save_settings, |
|
|
inputs=[api_key], |
|
|
outputs=[gr.Textbox(label="Status")] |
|
|
) |
|
|
|
|
|
gr.Markdown("## Debugging") |
|
|
test_viz_btn = gr.Button("Test Visualization") |
|
|
test_viz_output = gr.HTML(label="Test Visualization") |
|
|
|
|
|
test_viz_btn.click( |
|
|
create_test_html_visualization, |
|
|
inputs=[], |
|
|
outputs=[test_viz_output] |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |