|
|
import os |
|
|
import sys |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_groq import ChatGroq |
|
|
import plotly.express as px |
|
|
import time |
|
|
import plotly.io as pio |
|
|
import traceback |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
|
|
|
import re |
|
|
import importlib.util |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Make the project root importable so `backend` resolves when this script
# is launched from its own directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# Singleton assistant handling non-CSV (document / RAG) queries.
document_assistant = DocumentAssistant()

# Groq-hosted chat model; used for SQL generation, result interpretation,
# and general fallback answers.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,  # deterministic output — important for SQL generation
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite database that stores uploaded CSV data (always in table 'data_tab').
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# General data directory for app artifacts.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Persistent ChromaDB storage for document embeddings.
CHROMA_DB_DIR = os.path.join(DATA_DIR, "chroma_db")
os.makedirs(CHROMA_DB_DIR, exist_ok=True)

# Exported so the backend picks up the same ChromaDB location —
# presumably read by backend.vector_db; verify against that module.
os.environ["CHROMA_DB_PATH"] = CHROMA_DB_DIR

# Tracks what kind of file is currently loaded; drives query routing in
# process_text_query (CSV -> SQL path, otherwise document/LLM path).
current_context = {
    "file_type": None,   # "csv" or "pdf" (used for any non-CSV document)
    "file_name": None,   # basename of the most recently loaded file
    "table_name": None   # always "data_tab" when a CSV is loaded
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt that turns a natural-language question (plus the table's column
# list, injected by the caller) into one SQLite query. The guidelines pin
# SQLite dialect quirks; clean_sql_query() defensively strips any markdown
# the model emits despite guideline 7.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
""")

# Prompt that converts a SQL result set into a concise natural-language
# answer for the user.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_sql_query(query_text):
    """Normalize raw LLM output into a bare, executable SQLite statement.

    Strips markdown code fences (including unterminated ones), chatty
    lead-in phrases ("Here is the SQL query..."), `--` line comments, and
    trailing semicolons.

    Args:
        query_text: Raw model output, possibly wrapped in prose/markdown.

    Returns:
        A cleaned SQL string without a trailing semicolon, or the safe
        fallback "SELECT * FROM data_tab LIMIT 10;" when the input is
        empty or nothing usable remains after cleaning.
    """
    fallback = "SELECT * FROM data_tab LIMIT 10;"
    if not query_text:
        return fallback

    if "```" in query_text:
        # Prefer the contents of a complete ```sql ... ``` fence.
        matches = re.findall(r"```(?:sql)?(.*?)```", query_text, re.DOTALL)
        if matches:
            query_text = matches[0].strip()
        else:
            # BUGFIX: an unterminated fence previously left the backticks
            # in place, producing invalid SQL. Strip the fence markers.
            query_text = re.sub(r"```(?:sql)?", "", query_text).strip()

    # Drop conversational lead-ins by jumping to the first SQL keyword.
    prefixes = [
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:"
    ]

    for prefix in prefixes:
        if query_text.lower().startswith(prefix):
            sql_keywords = ["select", "with", "create", "insert", "update", "delete"]
            positions = [query_text.lower().find(keyword) for keyword in sql_keywords]
            positions = [pos for pos in positions if pos != -1]

            if positions:
                query_text = query_text[min(positions):]

    # Remove `--` line comments; a space keeps adjacent tokens separated.
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)

    # Trim whitespace and any trailing semicolon(s); the final strip also
    # removes whitespace that was sitting just before the semicolon.
    query_text = query_text.strip().rstrip(';').strip()

    if not query_text:
        return fallback

    return query_text
|
|
|
|
|
def process_text_query(query, history):
    """Answer a user query, routing between three paths: SQL over a loaded
    CSV, RAG over indexed documents, or a plain LLM answer.

    Args:
        query: The user's natural-language question.
        history: Gradio chat history (list of [user, assistant] pairs);
            mutated in place with this exchange.

    Returns:
        (response, history) — response is markdown text, or an HTML <img>
        snippet when a visualization was produced.
    """
    if not query:
        return "", history

    # Record the user turn immediately; the assistant slot is filled once
    # a response (or an error message) is available.
    history.append([query, None])

    start_time = time.time()  # NOTE(review): currently unused — timing hook?

    # Keyword groups used to classify which chart type was requested.
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }

    # Broad check: does the query ask for any visualization at all?
    is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])

    # Narrow down to a specific chart type, first keyword group wins.
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in query.lower() for keyword in keywords):
                viz_type = vtype
                break

    # --- Path 1: a CSV is loaded -> answer via SQL over SQLite ----------
    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            conn = sqlite3.connect(DB_PATH)

            # Fetch column names so the LLM knows the table schema.
            cursor = conn.cursor()
            cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
            columns = [info[1] for info in cursor.fetchall()]
            columns_str = ", ".join(columns)

            question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"

            # Box plots and heatmaps need raw column values rather than an
            # aggregate, so their SQL is built directly (no LLM involved).
            if is_visualization and viz_type in ['box', 'heatmap']:
                if viz_type == 'box':
                    # Numeric columns, detected via SQLite type affinity.
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]

                    if numeric_cols:
                        # Prefer a numeric column the user named explicitly.
                        target_col = None
                        for col in numeric_cols:
                            if col.lower() in query.lower():
                                target_col = col
                                break

                        # Otherwise fall back to the first numeric column.
                        if not target_col and numeric_cols:
                            target_col = numeric_cols[0]

                        sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
                    else:
                        # No numeric data: just preview the table.
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"

                elif viz_type == 'heatmap':
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]

                    if len(numeric_cols) >= 2:
                        # Cap at 10 columns / 1000 rows to keep the
                        # correlation matrix readable and the query fast.
                        cols_to_use = numeric_cols[:10]
                        cols_str = ", ".join(cols_to_use)
                        sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
                    else:
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
            else:
                # Everything else: let the LLM write the SQL, then strip
                # whatever markdown/prose it wrapped around it.
                sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
                sql_query = clean_sql_query(sql_query)

            result_df = pd.read_sql_query(sql_query, conn)

            conn.close()

            df_str = result_df.to_string()

            # Ask the LLM to interpret the result set in plain language.
            data_summary = result_df.to_string()
            analysis = llm.invoke(interpret_prompt.format(
                question=query,
                sql_query=sql_query,
                data_summary=data_summary
            )).content

            # Single markdown answer: query + raw results + interpretation.
            comprehensive_response = f"""
### SQL Query:
```sql
{sql_query}
```

### Results:
```
{df_str}
```

### Analysis:
{analysis}
"""

            # If a chart was requested, return the chart HTML instead of
            # the text (the text still goes into the chat history).
            if is_visualization:
                viz_html = generate_visualization(result_df, query)
                if viz_html:
                    history[-1][1] = comprehensive_response
                    return viz_html, history

            history[-1][1] = comprehensive_response
            return comprehensive_response, history

        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history

    # --- Path 2: documents are indexed -> RAG via DocumentAssistant -----
    elif document_assistant.get_all_documents():
        try:
            response = document_assistant.process_query(query)
            history[-1][1] = response
            return response, history
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history

    # --- Path 3: nothing loaded -> general-knowledge LLM answer ---------
    else:
        try:
            general_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant that provides clear, informative responses. Use your knowledge to answer the user's question concisely."),
                ("human", "{question}")
            ])

            response = llm.invoke(general_prompt.format(question=query)).content

            history[-1][1] = response
            return response, history
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history
|
|
|
|
|
def process_file_upload(files):
    """Ingest uploaded files: CSVs into SQLite, everything else into the
    document assistant's vector store.

    Each CSV replaces the single working table 'data_tab'; other file
    types are chunked and indexed by DocumentAssistant. The module-level
    `current_context` is updated to describe the last successfully
    processed file, which routes subsequent queries.

    Args:
        files: Gradio file objects (each exposes a `.name` path), or a
            falsy value when nothing was uploaded.

    Returns:
        A newline-joined, human-readable status report.
    """
    if not files:
        return "No files uploaded"

    global current_context

    # Reset routing state; the last file processed below wins.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            conn = None
            try:
                conn = sqlite3.connect(DB_PATH)

                # Speed up the bulk insert; durability is irrelevant for
                # this throwaway analysis table.
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")

                # Load the CSV and replace the working table wholesale.
                df = pd.read_csv(file_path)
                df.to_sql('data_tab', conn, if_exists='replace', index=False)

                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab"
                }

                # Gather schema and row count for the status report.
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]

                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]

                file_info.append("✅ CSV File Successfully Loaded")
                file_info.append(f"📊 Table Name: data_tab")
                file_info.append(f"📄 Source File: {file_name}")
                file_info.append(f"📈 Total Rows: {row_count:,}")
                file_info.append(f"📋 Columns: {', '.join(columns)}")

            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
            finally:
                # BUGFIX: close even when read_csv/to_sql raises — the
                # original leaked the connection on error.
                if conn is not None:
                    conn.close()
        else:
            # Non-CSV: chunk and index into the vector store.
            try:
                result = document_assistant.upload_document(file_path)

                current_context = {
                    "file_type": "pdf",  # used for any non-CSV document type
                    "file_name": file_name,
                    "table_name": None
                }

                file_info.append("✅ Document Successfully Processed")
                file_info.append(f"📄 File: {file_name}")
                file_info.append(f"📚 Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_context():
    """Forget the currently loaded file so new uploads start clean.

    Returns a single-turn chat history confirming the reset (or an error
    message entry if something goes wrong).
    """
    global current_context

    try:
        # Rebuild the context with every field blanked out.
        current_context = dict.fromkeys(("file_type", "file_name", "table_name"))

        notice = "Context cleared. You can now upload new documents or CSV files."
        return [[notice, None]]
    except Exception as exc:
        return [[f"Error clearing context: {exc}", None]]
|
|
|
|
|
def flush_databases():
    """Flush both persistence layers: the SQLite CSV store and ChromaDB.

    Drops every user table in the SQLite database, resets (or recreates)
    the DocumentAssistant, and clears the in-memory file context. Each
    step reports success/failure independently so one failing store does
    not prevent the other from being cleared.

    Returns:
        A newline-joined status report.
    """
    global document_assistant
    global current_context
    result = []

    # --- SQLite: drop every user-created table --------------------------
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            for (table_name,) in tables:
                # BUGFIX: internal tables (e.g. sqlite_sequence) cannot be
                # dropped and previously aborted the whole loop.
                if table_name.startswith("sqlite_"):
                    continue
                # Quote the identifier so unusual table names don't break
                # the statement.
                cursor.execute(f'DROP TABLE IF EXISTS "{table_name}";')

            conn.commit()
        finally:
            # BUGFIX: always release the connection, even if a DROP fails.
            conn.close()

        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")

    # --- ChromaDB: reset via the assistant, or rebuild it ---------------
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("✅ ChromaDB cleared successfully")
        else:
            # reset_database() reported failure; fall back to a brand-new
            # assistant instance so subsequent uploads still work.
            document_assistant = DocumentAssistant()
            result.append("⚠️ ChromaDB reset partially completed - created new instance")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    # Forget whatever file was previously loaded.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    return "\n".join(result)
|
|
|
|
|
|
|
|
|
|
|
# Workaround for a flaky import of backend.vector_db: importing it has —
# presumably due to a stray module-level `response` name — raised NameError.
# If that specific error occurs, scrub the attribute and reload the module
# before retrying the import. NOTE(review): confirm whether this is still
# needed once backend.vector_db is fixed.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        import backend.vector_db

        # Remove the offending leftover attribute, if present.
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')

        # Re-execute the module body and retry the import.
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
|
|
|
|
|
|
|
|
def generate_visualization(result_df, query):
    """Render a chart for `result_df` and return it as an HTML snippet.

    The chart type is inferred from keywords in `query` (bar chart is the
    default). The plotly figure is rendered to a static PNG and embedded
    as a base64 data URI so it displays directly in a Gradio HTML pane.

    Args:
        result_df: pandas DataFrame of query results to plot.
        query: The user's question, scanned for chart-type keywords.

    Returns:
        An HTML string wrapping the embedded image, or None on failure.
    """
    try:
        print("Visualization requested, attempting to create plot...")

        # Fixed render size; scale=3 in to_image() yields a high-DPI PNG.
        fig_width = 1200
        fig_height = 800

        # Infer chart type from the query; first match wins, bar is default.
        viz_type = 'bar'

        if any(word in query.lower() for word in ['pie', 'distribution', 'proportion']):
            viz_type = 'pie'
        elif any(word in query.lower() for word in ['line', 'trend', 'time series']):
            viz_type = 'line'
        elif any(word in query.lower() for word in ['scatter', 'relationship']):
            viz_type = 'scatter'
        elif any(word in query.lower() for word in ['histogram', 'distribution of']):
            viz_type = 'histogram'
        elif any(word in query.lower() for word in ['box', 'boxplot', 'outliers']):
            viz_type = 'box'
        elif any(word in query.lower() for word in ['heatmap', 'correlation']):
            viz_type = 'heatmap'

        print(f"Creating {viz_type} visualization...")

        # Numeric columns drive axis selection for most chart types.
        numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()

        # Pie charts only make sense for small category counts (<= 20).
        if viz_type == 'pie' and len(result_df) <= 20:
            # First column = labels; second column = values (or 1 each).
            labels = result_df.iloc[:, 0].tolist()
            values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)

            import plotly.graph_objects as go
            fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
            fig.update_layout(title_text='Pie Chart')

        elif viz_type == 'histogram' and len(numeric_cols) > 0:
            import plotly.express as px
            fig = px.histogram(result_df, x=numeric_cols[0])
            fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')

        elif viz_type == 'box' and len(numeric_cols) > 0:
            import plotly.express as px
            fig = px.box(result_df, y=numeric_cols[0])
            fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')

        elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
            import plotly.express as px

            # Correlation matrix across all numeric columns.
            corr_df = result_df[numeric_cols].corr()
            fig = px.imshow(corr_df, text_auto=True)
            fig.update_layout(title_text='Correlation Heatmap')

        elif viz_type == 'scatter' and len(numeric_cols) >= 2:
            import plotly.express as px
            fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
            fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')

        elif viz_type == 'line':
            import plotly.express as px
            x_col = result_df.columns[0]
            # y: first numeric column, else second column, else nothing.
            y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None

            if y_cols:
                fig = px.line(result_df, x=x_col, y=y_cols[0])
                fig.update_layout(
                    title_text=f'Line Chart of {y_cols[0]} over {x_col}',
                    xaxis=dict(
                        tickangle=-45,
                        tickmode='auto',
                        nticks=20
                    )
                )
            else:
                # No usable y column: fall through to the bar chart below.
                viz_type = 'bar'

        # Default / fallback path. The locals() check catches every branch
        # above that did not manage to build a figure.
        if viz_type == 'bar' or 'fig' not in locals():
            import plotly.express as px
            x_col = result_df.columns[0]
            y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None

            # Many categories read better as a horizontal bar chart.
            if len(result_df) > 10:
                if y_col:
                    fig = px.bar(
                        result_df,
                        y=x_col,
                        x=y_col,
                        orientation='h',
                        title=f'Bar Chart of {y_col} by {x_col}'
                    )
                else:
                    fig = px.bar(
                        result_df,
                        y=x_col,
                        orientation='h',
                        title=f'Bar Chart of {x_col}'
                    )
            else:
                if y_col:
                    fig = px.bar(
                        result_df,
                        x=x_col,
                        y=y_col,
                        title=f'Bar Chart of {y_col} by {x_col}'
                    )
                else:
                    fig = px.bar(
                        result_df,
                        x=x_col,
                        title=f'Bar Chart of {x_col}'
                    )

            # Bar-specific layout tweaks.
            fig.update_layout(
                bargap=0.2,
                uniformtext_minsize=8,
                uniformtext_mode='hide'
            )

        # Shared layout applied to whichever figure was built above.
        fig.update_layout(
            width=fig_width,
            height=fig_height,
            template="plotly_white",
            margin=dict(l=40, r=40, t=80, b=80, pad=4),
            autosize=True,
            plot_bgcolor='rgba(240,240,240,0.2)',
            paper_bgcolor='white',
            font=dict(size=12)
        )

        # Uniform hover formatting across all traces.
        fig.update_traces(
            hovertemplate="%{x}: %{y}<extra></extra>",
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family="Arial"
            )
        )

        print(f"Created figure with width={fig_width}, height={fig_height}")

        # Render the figure to a static PNG (requires the kaleido engine).
        print("Converting figure to image...")
        img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=3)
        print("Image conversion successful")

        # Embed the PNG as a base64 data URI.
        import base64
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"

        print("HTML conversion successful")

        return f"""
        <div class="visualization-wrapper">
            <img src='{img_src}'
                 style='max-width:100%; height:auto; display:block; margin:0 auto;'
                 alt='Data Visualization' />
        </div>
        """

    except Exception as e:
        # Visualization failures are non-fatal: log and let the caller
        # fall back to a text-only response.
        import traceback
        print(f"Error generating visualization: {str(e)}")
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
# --- Gradio UI definition -------------------------------------------------
# Two tabs: a chat/visualization workspace and a file-upload/indexing tab.
with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
    gr.Markdown("# 🤖 LLM Powered Database Chatbot")
    gr.Markdown("Upload documents, ask questions, and get AI-powered responses!")

    # Holds the HTML of the most recent visualization across interactions.
    current_visualization = gr.State(None)

    with gr.Tab("Chat & Visualizations"):
        # CSS so chart images scale inside the chat and visualization panes.
        gr.HTML("""
        <style>
        .chatbot-container img {
            max-width: 100%;
            height: auto;
            display: block;
            margin: 10px 0;
        }
        .visualization-container {
            min-height: 500px;
            max-height: 800px;
            overflow: auto;
            padding: 20px;
            background-color: #f8f9fa;
            border-radius: 8px;
        }
        .visualization-container img {
            max-width: 100%;
            height: auto;
            display: block;
            margin: 0 auto;
        }
        </style>
        """)

        with gr.Row():
            # Left column: chat interface.
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(height=500, elem_classes="chatbot-container")

                with gr.Row():
                    with gr.Column(scale=8):
                        msg = gr.Textbox(
                            placeholder="Ask a question about your documents...",
                            show_label=False
                        )
                    with gr.Column(scale=1):
                        pass

                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.Button("Clear")
                    clear_context_btn = gr.Button("Clear Context")

            # Right column: visualization pane and its controls.
            with gr.Column(scale=1):
                visualization_output = gr.HTML(
                    label="Visualization",
                    elem_classes="visualization-container"
                )

                with gr.Row():
                    clear_viz_btn = gr.Button("🗑️ Clear Visualization")
                    download_btn = gr.Button("📥 Download Visualization")

                # Hidden widgets used by the download flow.
                save_status = gr.Textbox(label="Save Status", visible=False)
                download_img = gr.Image(visible=False, type="pil", label="Download Image")

        gr.Markdown("""
        ### Capabilities:
        - **Data Analysis**: Ask questions about your data and get detailed responses
        - **Visualization**: Request and view graphs and charts of your data
        - **Multiple File Types**: Upload PDFs, TXT, DOCX, CSV, and XLSX files for analysis
        - **Natural Language Queries**: Ask questions in plain English about your documents
        """)

        def clear_visualization():
            # Blank both the rendered pane and the stored HTML state.
            return "", ""

        def download_visualization(viz_html):
            """Extract the base64 PNG from the stored visualization HTML and
            expose it as a PIL image for download."""
            if not viz_html:
                return None

            try:
                # Pull the data URI out of the <img src='...'> attribute.
                img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html)

                if img_data_match:
                    img_data = img_data_match.group(1)

                    import base64
                    from io import BytesIO
                    from PIL import Image

                    image_data = base64.b64decode(img_data)
                    image = Image.open(BytesIO(image_data))

                    return image, gr.update(visible=True)
                else:
                    return None, gr.update(visible=False)
            except Exception as e:
                print(f"Error downloading visualization: {str(e)}")
                return None, gr.update(visible=False)

        clear_viz_btn.click(
            clear_visualization,
            outputs=[visualization_output, current_visualization]
        )

        download_btn.click(
            download_visualization,
            inputs=[current_visualization],
            outputs=[download_img, download_img]
        )

        def process_text_query_with_visualization(query, history, current_viz):
            """Process a text query and update chat history and visualization"""
            if not query:
                return "", history, current_viz

            response, new_history = process_text_query(query, history)

            # A response containing an <img> tag is visualization HTML;
            # stash it so the .then() step can mirror it into the pane.
            if "<img src=" in response:
                viz_html = response

                current_viz = viz_html

            # First output clears the textbox.
            return "", new_history, current_viz

        submit_btn.click(
            process_text_query_with_visualization,
            inputs=[msg, chatbot, current_visualization],
            outputs=[msg, chatbot, current_visualization]
        ).then(
            # Mirror the stored visualization state into the HTML pane.
            lambda viz: viz if viz else "",
            inputs=[current_visualization],
            outputs=[visualization_output]
        )

        clear_btn.click(lambda: None, None, chatbot, queue=False)
        clear_context_btn.click(clear_context, None, chatbot, queue=False)

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )

        with gr.Row():
            upload_button = gr.Button("Process & Index Documents", scale=2)
            flush_db_btn_doc = gr.Button("🗑️ Flush All Databases", variant="stop", scale=1)

        upload_output = gr.Textbox(label="Upload Status")

        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )

        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )
|
|
|
|
|
|
|
|
# Launch the Gradio app when run as a script.
if __name__ == "__main__":
    demo.launch(
        share=True,             # also expose a public gradio.live URL
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        show_error=True,        # surface exceptions in the UI
        debug=True
    )