|
|
import os |
|
|
import sys |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_groq import ChatGroq |
|
|
import plotly.express as px |
|
|
import time |
|
|
import plotly.io as pio |
|
|
import traceback |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import speech_recognition as sr |
|
|
from gtts import gTTS |
|
|
import re |
|
|
import importlib.util |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Make the project root importable so the `backend.*` package resolves when
# this script is launched directly from its own directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# Shared assistant instance used for document upload, indexing and Q&A.
document_assistant = DocumentAssistant()

# LLM used both to generate SQL from natural language and to interpret
# query results back into prose.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,  # deterministic output — important for SQL generation
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite database that receives uploaded CSV files (table 'data_tab').
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Tracks the most recently uploaded file; decides whether chat questions are
# answered via SQL (csv) or via the document assistant (anything else).
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}

# NOTE(review): never reassigned anywhere in this file — appears vestigial.
current_plot = None

# Prompt that turns a natural-language question into a SQLite query.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
""")

# Prompt that summarises the SQL result set as a natural-language answer.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)

# Prompt for visualization-oriented SQL.
# NOTE(review): defined but not referenced anywhere in this file — the chat
# pipeline uses `query_prompt` even for chart requests. Verify before removal.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.

Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
   - Year: strftime('%Y', date_column)
   - Month: strftime('%m', date_column)
   - Day: strftime('%d', date_column)
   - Hour: strftime('%H', date_column)

2. For histograms and binning:
   - Use: CAST((column / bin_size) AS INT) * bin_size
   - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin

3. For box plots:
   - SQLite doesn't support PERCENTILE_CONT or window functions
   - Simply return the raw data column: SELECT column_name FROM data_tab
   - The application will calculate quartiles and outliers

4. For heatmaps:
   - Return raw data for correlation analysis
   - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab

5. Always use 'data_tab' as the table name

6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
Visualization type: {viz_type}
""")
|
|
|
|
|
|
|
|
def clean_sql_query(query_text):
    """Normalize raw LLM output into a bare SQLite query string.

    Strips markdown code fences, conversational preambles ("here is the sql
    query ..."), ``--`` line comments and a trailing semicolon.  Falls back
    to a safe default query when nothing usable remains.

    Args:
        query_text: Raw text returned by the LLM (may be None or empty).

    Returns:
        The cleaned SQL string (no trailing semicolon), or the default
        ``"SELECT * FROM data_tab LIMIT 10;"`` when the input is empty.
    """
    fallback = "SELECT * FROM data_tab LIMIT 10;"
    if not query_text:
        return fallback

    # Prefer the contents of a fenced ```sql ... ``` block when present.
    if "```" in query_text:
        matches = re.findall(r"```(?:sql)?(.*?)```", query_text, re.DOTALL)
        if matches:
            query_text = matches[0].strip()

    # Drop a leading conversational preamble by jumping to the first SQL
    # keyword.  Word boundaries (\b) avoid latching onto keywords embedded
    # in prose (e.g. "selected"), which the previous plain-substring search
    # could match.
    prefixes = (
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:",
    )
    if query_text.lower().startswith(prefixes):
        keyword = re.search(
            r"\b(select|with|create|insert|update|delete)\b",
            query_text,
            re.IGNORECASE,
        )
        if keyword:
            query_text = query_text[keyword.start():]

    # Remove `--` line comments, then trim whitespace and the trailing ';'
    # (pandas' read_sql_query does not need it).
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)
    query_text = query_text.strip().rstrip(';')

    return query_text if query_text.strip() else fallback
|
|
|
|
|
def _detect_visualization(query):
    """Classify *query* as a chart request and guess the chart type.

    Returns:
        (is_visualization, viz_type) where viz_type is one of the keys of
        the keyword table below, or None when no specific type matched.
    """
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between'],
    }
    lowered = query.lower()
    is_visualization = any(
        word in lowered
        for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me']
    )
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in lowered for keyword in keywords):
                viz_type = vtype
                break
    return is_visualization, viz_type


def _numeric_table_columns(conn):
    """Names of 'data_tab' columns whose declared type looks numeric."""
    cursor = conn.cursor()
    cursor.execute(
        "SELECT name FROM pragma_table_info('data_tab') "
        "WHERE type LIKE '%INT%' OR type LIKE '%REAL%' "
        "OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
    )
    return [row[0] for row in cursor.fetchall()]


def _generate_sql(conn, query, is_visualization, viz_type):
    """Return the SQL to execute for *query*.

    Box plots and heatmaps need raw columns (SQLite has no percentile
    functions), so their SQL is built directly from the table schema;
    everything else is delegated to the LLM with the column list prepended
    for context.
    """
    if is_visualization and viz_type == 'box':
        numeric_cols = _numeric_table_columns(conn)
        if not numeric_cols:
            return "SELECT * FROM data_tab LIMIT 10;"
        # Prefer a numeric column the user actually named; otherwise the first.
        target_col = next(
            (col for col in numeric_cols if col.lower() in query.lower()),
            numeric_cols[0],
        )
        return f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"

    if is_visualization and viz_type == 'heatmap':
        numeric_cols = _numeric_table_columns(conn)
        if len(numeric_cols) >= 2:
            cols_str = ", ".join(numeric_cols[:10])  # cap width of the matrix
            return (
                f"SELECT {cols_str} FROM data_tab "
                f"WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
            )
        return "SELECT * FROM data_tab LIMIT 10;"

    # LLM path: give the model the schema so it can pick real column names.
    cursor = conn.cursor()
    cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
    columns_str = ", ".join(info[1] for info in cursor.fetchall())
    question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
    raw_sql_query = (query_prompt | llm).invoke(
        {"question": question_with_context}
    ).content.strip()
    return clean_sql_query(raw_sql_query)


def _format_answer(query, sql_query, result_df):
    """Summarise *result_df* through the LLM and build the markdown reply."""
    if not result_df.empty:
        data_summary = result_df.describe(include='all').to_string()
        if len(result_df) <= 10:
            data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
        else:
            data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"
    else:
        data_summary = "No relevant data found."

    interpretation = (interpret_prompt | llm).invoke({
        "question": query,
        "sql_query": sql_query,
        "data_summary": data_summary,
    }).content.strip()

    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
    if not result_df.empty:
        if len(result_df) > 10:
            response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
        else:
            response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
    else:
        response += "**No results found.**\n\n"
    response += f"**Analysis:**\n{interpretation}"
    return response


def _build_figure(result_df, viz_type, numeric_cols):
    """Create the plotly figure matching *viz_type* for *result_df*.

    *numeric_cols* are the numeric columns present in the result frame.
    Unrecognised / ill-fitting combinations fall through to a simple bar
    chart.
    """
    if viz_type == 'pie' and len(result_df) <= 20:
        category_col = result_df.columns[0]
        value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]
        if len(numeric_cols) == len(result_df.columns):
            # Every column is numeric: fall back to the index as category.
            category_col = result_df.index.name or 'index'
            result_df = result_df.reset_index()
        return px.pie(
            result_df,
            names=category_col,
            values=value_col,
            title=f"Distribution of {value_col} by {category_col}",
            hole=0.3,
            color_discrete_sequence=px.colors.qualitative.Pastel,
        )

    if viz_type == 'histogram' and len(result_df.columns) > 0:
        x_col = numeric_cols[0] if numeric_cols else result_df.columns[0]
        if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
            # SQL already binned the data — render it as a bar chart.
            bin_col = 'bin' if 'bin' in result_df.columns else 'range'
            count_col = ('count' if 'count' in result_df.columns
                         else numeric_cols[0] if numeric_cols else result_df.columns[1])
            fig = px.bar(
                result_df,
                x=bin_col,
                y=count_col,
                title=f"Histogram of {x_col}",
                labels={bin_col: x_col, count_col: 'Frequency'},
                color_discrete_sequence=['#636EFA'],
            )
        else:
            fig = px.histogram(
                result_df,
                x=x_col,
                title=f"Distribution of {x_col}",
                nbins=20,
                marginal="box",
                color_discrete_sequence=['#636EFA'],
                opacity=0.8,
            )
        fig.update_layout(bargap=0.1, xaxis_title=x_col,
                          yaxis_title='Frequency', showlegend=True)
        return fig

    if viz_type == 'box' and numeric_cols:
        x_col = numeric_cols[0]
        fig = px.box(
            result_df,
            y=x_col,
            title=f"Box Plot of {x_col}",
            points="outliers",
            color_discrete_sequence=['#636EFA'],
        )
        # Overlay the raw points so the spread is visible alongside the box.
        fig.add_trace(
            px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
        )
        return fig

    if viz_type == 'heatmap' and len(numeric_cols) >= 2:
        if len(numeric_cols) >= 3:
            clean_df = result_df[numeric_cols].dropna()
            if len(clean_df) > 1:
                corr_df = clean_df.corr().round(2)
                fig = px.imshow(
                    corr_df,
                    title="Correlation Heatmap",
                    color_continuous_scale='RdBu_r',
                    text_auto=True,
                    aspect="auto",
                    zmin=-1, zmax=1,
                )
                fig.update_layout(
                    xaxis_title="Features",
                    yaxis_title="Features",
                    coloraxis_colorbar=dict(
                        title="Correlation",
                        thicknessmode="pixels", thickness=20,
                        lenmode="pixels", len=300,
                        yanchor="top", y=1,
                        ticks="outside",
                    ),
                )
                return fig
            return px.bar(
                pd.DataFrame({'Message': ['Not enough data for heatmap']}),
                title="Cannot create heatmap - insufficient data",
            )
        # Exactly two numeric columns: a 2-D density instead of a matrix.
        x_col, y_col = numeric_cols[0], numeric_cols[1]
        fig = px.density_heatmap(
            result_df,
            x=x_col,
            y=y_col,
            title=f"Density Heatmap of {x_col} vs {y_col}",
            color_continuous_scale='Viridis',
            nbinsx=20,
            nbinsy=20,
            marginal_x="histogram",
            marginal_y="histogram",
        )
        fig.update_layout(
            xaxis_title=x_col,
            yaxis_title=y_col,
            coloraxis_colorbar=dict(
                title="Count",
                thicknessmode="pixels", thickness=20,
                lenmode="pixels", len=300,
                yanchor="top", y=1,
                ticks="outside",
            ),
        )
        return fig

    if viz_type == 'scatter' and len(numeric_cols) >= 2:
        x_col, y_col = numeric_cols[0], numeric_cols[1]
        size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
        # Colour by the first non-numeric column when one is available.
        if len(result_df.columns) > len(numeric_cols):
            categorical_cols = [c for c in result_df.columns if c not in numeric_cols]
            color_col = categorical_cols[0] if categorical_cols else None
        else:
            color_col = None
        fig = px.scatter(
            result_df,
            x=x_col,
            y=y_col,
            size=size_col,
            color=color_col,
            title=f"Relationship between {x_col} and {y_col}",
            opacity=0.7,
            size_max=15,
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
            # Corner-to-corner reference line (min,min) -> (max,max).
            fig.update_layout(
                shapes=[
                    dict(
                        type='line',
                        xref='x', yref='y',
                        x0=result_df[x_col].min(),
                        y0=result_df[y_col].min(),
                        x1=result_df[x_col].max(),
                        y1=result_df[y_col].max(),
                        line=dict(color='red', width=2, dash='dash'),
                    )
                ]
            )
        fig.update_layout(
            xaxis_title=x_col,
            yaxis_title=y_col,
            showlegend=True,
            legend=dict(
                title=color_col if color_col else "",
                orientation="h",
                yanchor="bottom", y=1.02,
                xanchor="right", x=1,
            ),
        )
        return fig

    if viz_type == 'line':
        # Use a time-like column for x when one exists, else the first column.
        time_cols = [
            col for col in result_df.columns
            if any(w in col.lower() for w in ['date', 'time', 'month', 'year', 'day'])
        ]
        x_col = time_cols[0] if time_cols else result_df.columns[0]
        y_cols = numeric_cols[:3]
        if not y_cols and len(result_df.columns) > 1:
            y_cols = [result_df.columns[1]]
        fig = px.line(
            result_df,
            x=x_col,
            y=y_cols,
            title="Time Series Analysis",
            markers=True,
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_layout(
            xaxis=dict(
                rangeslider=dict(visible=True),
                type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-',
            )
        )
        return fig

    # Fallback: simple bar chart over the first column.
    x_col = result_df.columns[0]
    if numeric_cols and x_col not in numeric_cols:
        y_cols = numeric_cols[:3]
    elif len(result_df.columns) > 1:
        y_cols = [result_df.columns[1]]
    else:
        y_cols = ['value']
        result_df['value'] = 1  # synthesise a count column for one-column data
    return px.bar(
        result_df,
        x=x_col,
        y=y_cols[0],
        title="Data Visualization",
        color_discrete_sequence=['#636EFA'],
    )


def _render_visualization(result_df, viz_type):
    """Render a chart as inline-image markdown; "" when rendering fails.

    BUGFIX: numeric columns are now derived from the result frame itself.
    The original code referenced a `numeric_cols` name that was only bound
    inside the box/heatmap SQL branches, so every other chart type raised
    NameError here and silently produced no image (and the box plot could
    target a column absent from the result set).
    """
    try:
        print("Visualization requested, attempting to create plot...")
        fig_width, fig_height = 1000, 700
        numeric_cols = result_df.select_dtypes(include='number').columns.tolist()

        fig = _build_figure(result_df, viz_type, numeric_cols)

        # Common cosmetic layout applied to every chart type.
        fig.update_layout(
            autosize=True,
            width=fig_width,
            height=fig_height,
            margin=dict(l=50, r=50, b=100, t=100, pad=4),
            template="plotly_white",
            font=dict(size=14),
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
            plot_bgcolor='rgba(240,240,240,0.2)',
            paper_bgcolor='white',
        )

        # Embed as a base64 PNG so it survives inside chat markdown.
        img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height, scale=2)
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"

        extra = f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
        extra += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
        return extra
    except Exception as viz_error:
        # Chart failure must not kill the textual answer.
        print(f"Visualization error: {str(viz_error)}")
        traceback.print_exc()
        return ""


def _answer_csv_query(query, is_visualization, viz_type):
    """Run the SQL pipeline against the loaded CSV and format a chat reply."""
    try:
        conn = sqlite3.connect(DB_PATH)
    except Exception as e:
        return f"Error processing query: {str(e)}"
    try:
        sql_query = _generate_sql(conn, query, is_visualization, viz_type)
        print(f"Generated SQL Query: {sql_query}")
        try:
            result_df = pd.read_sql_query(sql_query, conn)
        except Exception as e:
            return f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"
        response = _format_answer(query, sql_query, result_df)
        if is_visualization and not result_df.empty:
            response += _render_visualization(result_df, viz_type)
        return response
    except Exception as e:
        return f"Error processing query: {str(e)}"
    finally:
        conn.close()


def process_text_query(query, history):
    """Answer a user question against the active context and update the chat.

    For an uploaded CSV the question is translated to SQL by the LLM, run
    against SQLite table 'data_tab', summarised back through the LLM and —
    when the question asks for a chart — rendered as an inline plotly image.
    Any other context is delegated to the document assistant.

    Args:
        query: The user's question; empty input is ignored.
        history: Chat history as a list of {"role", "content"} dicts,
            mutated in place.

    Returns:
        ("", history): clears the Gradio textbox and refreshes the chat.
    """
    if not query:
        return "", history

    history.append({"role": "user", "content": query})
    start_time = time.time()

    is_visualization, viz_type = _detect_visualization(query)

    if current_context["file_type"] == "csv" and current_context["table_name"]:
        response = _answer_csv_query(query, is_visualization, viz_type)
    else:
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"

    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"

    history.append({"role": "assistant", "content": response})
    return "", history
|
|
|
|
|
def process_file_upload(files):
    """Load uploaded files into the assistant's stores.

    CSV files are imported into SQLite table 'data_tab' (replacing any
    previous table); every other extension is handed to the document
    assistant for chunking/indexing.  The module-level ``current_context``
    records the most recently processed file, which decides whether chat
    questions are answered via SQL or via document retrieval (note: when
    several files are uploaded, the last one processed wins).

    Args:
        files: Iterable of Gradio file objects, or None.

    Returns:
        A human-readable multi-line status string.
    """
    if not files:
        return "No files uploaded"

    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None,
    }

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            conn = None
            try:
                conn = sqlite3.connect(DB_PATH)
                # Speed up the bulk insert; durability does not matter for a
                # throwaway import database.
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")

                df = pd.read_csv(file_path)
                df.to_sql('data_tab', conn, if_exists='replace', index=False)

                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab",
                }

                # Gather schema and row count for the status report.
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]

                file_info.append("✅ CSV File Successfully Loaded")
                file_info.append("📊 Table Name: data_tab")
                file_info.append(f"📁 Source File: {file_name}")
                file_info.append(f"📈 Total Rows: {row_count:,}")
                file_info.append(f"📋 Columns: {', '.join(columns)}")
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
            finally:
                # BUGFIX: the connection previously leaked when the import
                # raised before reaching the close call.
                if conn is not None:
                    conn.close()
        else:
            # Everything that is not CSV goes to the document pipeline.
            try:
                result = document_assistant.upload_document(file_path)
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None,
                }
                file_info.append("✅ Document Successfully Processed")
                file_info.append(f"📄 File: {file_name}")
                file_info.append(f"🔢 Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)
|
|
|
|
|
def list_documents():
    """Return a human-readable listing of every indexed document."""
    try:
        docs = document_assistant.get_all_documents()
        if not docs:
            return "No documents indexed yet."
        # One "- name (type)" line per document, each newline-terminated.
        listing = "".join(
            f"- {doc['filename']} ({doc['file_type']})\n" for doc in docs
        )
        return "Indexed Documents:\n\n" + listing
    except Exception as e:
        return f"Error listing documents: {str(e)}"
|
|
|
|
|
def clear_context():
    """Reset the active file context and post a confirmation chat message.

    Returns:
        A one-element chat history (list of message dicts) confirming the
        reset, or an error message on failure.
    """
    global current_context
    try:
        # Back to the "nothing uploaded" state: all fields None.
        current_context = dict.fromkeys(("file_type", "file_name", "table_name"))
        confirmation = "Context cleared. You can now upload new documents or CSV files."
        return [{"role": "assistant", "content": confirmation}]
    except Exception as e:
        return [{"role": "assistant", "content": f"Error clearing context: {str(e)}"}]
|
|
|
|
|
def process_voice_input(audio_path):
    """Transcribe a recorded audio file via Google speech recognition.

    Args:
        audio_path: Path to the recorded audio file, or None.

    Returns:
        The transcribed text, or a human-readable error string.
    """
    if audio_path is None:
        return "No audio recorded"

    try:
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            audio_data = recognizer.record(source)
        # Uses Google's free web API — requires network access.
        return recognizer.recognize_google(audio_data)
    except sr.UnknownValueError:
        return "Could not understand audio"
    except sr.RequestError as e:
        return f"Error with speech recognition service: {e}"
    except Exception as e:
        return f"Error processing audio: {str(e)}"
|
|
|
|
|
def text_to_speech_output(text):
    """Voice the most recent assistant reply in a chat history.

    Note: despite its name, *text* is the chat history — a list of
    {"role", "content"} message dicts (the parameter name is kept for the
    existing Gradio wiring).

    Returns:
        Path of a temporary MP3 file with the synthesised speech, or None
        when there is nothing to say or synthesis fails.
    """
    if not text:
        return None

    # Walk backwards to find the newest assistant message.
    last_message = next(
        (msg["content"] for msg in reversed(text) if msg["role"] == "assistant"),
        None,
    )
    if not last_message:
        return None

    try:
        # Strip HTML tags, bold markers, paragraph breaks and code fences so
        # the synthesised speech stays listenable.
        speakable = re.sub(r'<.*?>', '', last_message)
        speakable = re.sub(r'\*\*(.*?)\*\*', r'\1', speakable)
        speakable = re.sub(r'\n\n', ' ', speakable)
        speakable = re.sub(r'```.*?```', 'Code block removed for speech.',
                           speakable, flags=re.DOTALL)

        # Reserve a temp file name, then let gTTS write into it.
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        handle.close()
        gTTS(text=speakable, lang='en', slow=False).save(handle.name)
        return handle.name
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return None
|
|
|
|
|
def create_test_visualization():
    """Build a small hard-coded bar chart to verify plotting works.

    Returns:
        A plotly Figure, or None when figure creation fails.
    """
    try:
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })
        chart = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        chart.update_layout(autosize=True, width=800, height=500)
        return chart
    except Exception as e:
        print(f"Error creating test visualization: {str(e)}")
        return None
|
|
|
|
|
def create_test_html_visualization():
    """Build the test bar chart and return it as an embeddable HTML snippet.

    Returns:
        An HTML fragment (not a full document), or None on failure.
    """
    try:
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })
        chart = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        # full_html=False yields a fragment suitable for a gr.HTML component.
        return pio.to_html(chart, full_html=False)
    except Exception as e:
        print(f"Error creating test HTML visualization: {str(e)}")
        return None
|
|
|
|
|
def flush_databases():
    """Drop every SQLite table, reset ChromaDB, and clear the file context.

    Returns:
        A newline-joined status report, one line per subsystem.
    """
    result = []

    # --- SQLite: drop every table in the CSV database. ---
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
            for (table_name,) in tables:
                # Quote the identifier: names come from sqlite_master but may
                # still contain spaces or other characters that need quoting.
                cursor.execute(f'DROP TABLE IF EXISTS "{table_name}";')
            conn.commit()
        finally:
            # BUGFIX: previously the connection leaked when a DROP raised.
            conn.close()
        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")

    # --- ChromaDB: delegate the reset to the document assistant. ---
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("✅ ChromaDB cleared successfully")
        else:
            result.append("⚠️ ChromaDB reset may not have been complete")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    # Forget the active file so the chat falls back to the "no context" path.
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None,
    }

    return "\n".join(result)
|
|
|
|
|
|
|
|
|
|
|
# Workaround for a stale-module-state issue in backend.vector_db: a leftover
# module-level name `response` can trigger a NameError during import.  When
# that specific failure occurs, scrub the stale attribute and reload the
# module before retrying the import.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        import backend.vector_db

        # Remove the stale attribute if it survived the failed import.
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')

        # Re-execute the module body and retry the import.
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI definition.  Three tabs: Chat, Document Upload, Settings.
# NOTE(review): several button/markdown labels below ("π€", "ποΈ", "β οΈ")
# look like mojibake of emoji (🎤, 🗑️, ⚠️) — verify against the original
# source encoding before shipping.
# ---------------------------------------------------------------------------
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# π€ AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        # CSS so the inline base64 chart images scale inside the chat window.
        gr.HTML("""
        <style>
        .chatbot-container img {
            max-width: 100%;
            height: auto;
            display: block;
            margin: 10px 0;
        }
        </style>
        """)

        # Chat transcript; "messages" type expects {"role", "content"} dicts.
        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")

        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                # Reveals the hidden voice-recording widget below.
                voice_btn = gr.Button("π€")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")

        # Player for the synthesised speech of the last assistant reply.
        audio_output = gr.Audio(label="Voice Response", type="filepath")

        # Hidden until the microphone button is pressed.
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        # Both the Submit button and pressing Enter run the same handler.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        # Clear wipes the transcript (returns None into the chatbot);
        # Clear Context also resets the server-side file context.
        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])

        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )

        # A finished recording is transcribed straight into the text box.
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )

        # Speaks the most recent assistant message from the transcript.
        tts_btn = gr.Button("π Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )

        with gr.Row():
            upload_button = gr.Button("Process & Index Documents", scale=2)
            flush_db_btn_doc = gr.Button("ποΈ Flush All Databases", variant="stop", scale=1)

        upload_output = gr.Textbox(label="Upload Status")

        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )

        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )

        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")

        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        with gr.Row():
            gr.Markdown("## Database Management")
            flush_db_btn = gr.Button("ποΈ Flush All Databases", variant="stop", scale=1)

        flush_result = gr.Textbox(label="Flush Result")

        flush_db_btn.click(
            flush_databases,
            inputs=[],
            outputs=[flush_result]
        )

        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")

        def save_settings(key):
            """Persist the API key into this process's environment only."""
            try:
                os.environ["GROQ_API_KEY"] = key
                return "Settings saved!"
            except Exception as e:
                return f"Error saving settings: {str(e)}"

        # NOTE(review): the output Textbox is created inline here and never
        # placed in the layout, so the "Status" result is not visible in the
        # UI — confirm whether a pre-declared component was intended.
        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[gr.Textbox(label="Status")]
        )

        gr.Markdown("## Debugging")
        test_viz_btn = gr.Button("Test Visualization")
        test_viz_output = gr.HTML(label="Test Visualization")

        test_viz_btn.click(
            create_test_html_visualization,
            inputs=[],
            outputs=[test_viz_output]
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app when this file is run as a script.
    demo.launch()