# EDA_Assistant / app.py
# TawsifDipto's picture
# Upload app.py
# 196ceca verified
import os, io, asyncio, tempfile, traceback
import pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns
from dotenv import load_dotenv
import chainlit as cl
from google.genai import types
from PIL import Image
from io import BytesIO
from google import genai
import matplotlib
matplotlib.use('Agg') # Use a non-GUI backend for matplotlib
# Display label -> Gemini model identifier. The labels are what the settings
# selector shows; the identifiers are what is sent to the API.
AVAILABLE_MODELS = dict(
    [
        ("Gemini 2.0 Flash Experimental", "gemini-2.0-flash-exp"),
        ("Gemini 2.5 Pro", "gemini-2.5-pro"),
        ("Gemini 2.5 Flash", "gemini-2.5-flash"),
        ("Gemini 2.0 Image Generation", "gemini-2.0-flash-preview-image-generation"),
        ("Gemini 2.0 Flash Lite", "gemini-2.0-flash-lite"),
    ]
)
# Model identifier used until another one is selected.
DEFAULT_MODEL = AVAILABLE_MODELS["Gemini 2.0 Flash Lite"]
current_model = DEFAULT_MODEL
# Stays False until the API client has been created further down.
GEMINI_AVAILABLE = False
# Read configuration from the environment (a .env file is honoured too).
load_dotenv()
gemini_api_key = os.environ.get("GEMINI_API_KEY")
if gemini_api_key:
    # A key is present: create the shared Gemini client and mark it usable.
    client = genai.Client(api_key=gemini_api_key)
    GEMINI_AVAILABLE = True
else:
    # Without a key nothing in this module can work, so fail at import time.
    raise ValueError("GEMINI_API_KEY not found in environment variables or .env file")
# Settings for ordinary text chat: temperature 0, up to 8192 output tokens,
# plain-text responses.
generation_config = types.GenerateContentConfig(
    response_mime_type="text/plain",
    max_output_tokens=8192,
    temperature=0,
)
# Settings for image generation: the model may answer with image and text parts.
image_generation_config = types.GenerateContentConfig(
    response_mime_type="text/plain",
    response_modalities=["IMAGE", "TEXT"],
)
def savefig(fig):
    """Render *fig* into a new temporary PNG file and return the file's path.

    The image is written at 150 dpi with a tight bounding box and the figure
    is closed afterwards. The temporary file is created with ``delete=False``,
    so it stays on disk until the caller removes it.
    """
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as png_file:
        png_path = png_file.name
        fig.savefig(png_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    return png_path
def df_to_string(df, max_rows=10):
    """Build a text summary of a DataFrame with schema, head and missing values.

    Args:
        df: The pandas DataFrame to describe.
        max_rows: How many leading rows to show in the "Head" section.

    Returns:
        One string made of three sections: ``### Schema`` (the output of
        ``DataFrame.info``), ``### Head`` (the first ``max_rows`` rows) and
        ``### Missing`` (per-column null counts, or "No missing values").
    """
    buf = io.StringIO()
    df.info(buf=buf)
    schema = buf.getvalue()

    preview = df.head(max_rows)
    try:
        head = preview.to_markdown(index=False)
    except ImportError:
        # to_markdown relies on the optional "tabulate" package; without it,
        # fall back to a plain-text table instead of failing the whole summary.
        head = preview.to_string(index=False)

    missing = df.isnull().sum()
    missing = missing[missing > 0]
    if missing.empty:
        missing_info = "No missing values"
    else:
        missing_info = f"Missing values:\n{missing.to_string()}"

    return f"### Schema:\n{schema}\n\n### Head:\n{head}\n\n### Missing:\n{missing_info}"
async def text_analysis(prompt_type, df_context):
    """Ask Gemini for a text-only analysis of a DataFrame description.

    Args:
        prompt_type: ``"plan"`` for an analysis plan or ``"final"`` for a
            closing summary.
        df_context: Text describing the DataFrame; it is embedded verbatim in
            the prompt.

    Returns:
        The model's reply text, or a human-readable message explaining why no
        reply is available (unknown prompt type, API unavailable, empty or
        blocked response, or an error with its traceback). Failures are
        reported through the return value; nothing is raised.
    """
    prompts = {
        "plan": f"You are a data analyst. Suggest a concise data analysis plan for the following DataFrame:\n{df_context}",
        "final": f"Summarize the analysis results for the following dataset:\n{df_context}",
    }
    prompt = prompts.get(prompt_type)
    if prompt is None:
        # Sending an empty prompt to the API would only produce a confusing error.
        return f"Unknown prompt type: {prompt_type!r}. Expected one of: {', '.join(prompts)}."
    if not GEMINI_AVAILABLE:
        return "Gemini API is not available."
    try:
        contents = [
            genai.types.Content(
                role="user",
                parts=[genai.types.Part.from_text(text=prompt)],
            )
        ]
        # generate_content is a blocking call; run it in a worker thread so the
        # event loop keeps serving other sessions while waiting for the reply.
        res = await asyncio.to_thread(
            client.models.generate_content,
            model=current_model,
            contents=contents,
            config={
                'temperature': 0.0,
                'max_output_tokens': 1024,
            },
        )
        if not res.candidates:
            return "No response generated."
        candidate = res.candidates[0]
        parts = candidate.content.parts if candidate.content else None
        # A part may carry no text; keep only the textual pieces.
        texts = [part.text for part in parts or [] if getattr(part, "text", None)]
        if texts:
            return "".join(texts)
        return "Gemini response blocked or empty."
    except Exception as e:
        return f"Error during text analysis: {str(e)}\n{traceback.format_exc()}"
async def vision_analysis(img_paths):
    """Ask Gemini to comment on each image in ``img_paths``.

    Args:
        img_paths: Iterable of ``(title, path)`` pairs; each path is read from
            disk and sent to the model together with a prompt naming the title.

    Returns:
        A list of ``(title, insight)`` tuples, one per input image, in input
        order. A failure on one image is recorded as an ``"Error: ..."``
        insight and does not stop the remaining images. When the API is not
        available the plain string ``"Gemini API is not available."`` is
        returned instead of a list, so callers should check
        ``GEMINI_AVAILABLE`` before iterating the result.
    """
    if not GEMINI_AVAILABLE:
        return "Gemini API is not available."

    # File-name ending -> MIME type; anything else is sent as JPEG.
    mime_by_suffix = (
        ('.png', "image/png"),
        ('.jpg', "image/jpeg"),
        ('.jpeg', "image/jpeg"),
        ('.webp', "image/webp"),
    )

    result = []
    for title, img_path in img_paths:
        try:
            with open(img_path, "rb") as img_file:
                img_data = img_file.read()

            lowered = img_path.lower()
            mime_type = next(
                (mime for suffix, mime in mime_by_suffix if lowered.endswith(suffix)),
                "image/jpeg",
            )

            contents = [
                genai.types.Content(
                    role="user",
                    parts=[
                        genai.types.Part.from_text(text=f"Analyze the image titled '{title}' and provide insights."),
                        genai.types.Part.from_bytes(data=img_data, mime_type=mime_type),
                    ],
                )
            ]
            # generate_content blocks until the reply arrives; run it in a
            # worker thread so the event loop is not stalled once per image.
            response = await asyncio.to_thread(
                client.models.generate_content,
                model=current_model,
                contents=contents,
                config={
                    'temperature': 0.0,
                    'max_output_tokens': 1024,
                },
            )

            if not response.candidates:
                result.append((title, "No response generated."))
                continue
            candidate = response.candidates[0]
            parts = candidate.content.parts if candidate.content else None
            # A part may carry no text; keep only the textual pieces.
            texts = [part.text for part in parts or [] if getattr(part, "text", None)]
            if texts:
                result.append((title, "".join(texts)))
            else:
                result.append((title, "Gemini response blocked."))
        except Exception as e:
            result.append((title, f"Error: {str(e)}"))
    return result
def generate_visuals(df):
    """Plot a standard set of charts for *df* and return them as Chainlit images.

    Produces, in this order: one histogram per numeric column, a correlation
    heatmap when there is more than one numeric column, and one bar plot (top
    20 values) per object column that has between 2 and 29 distinct values.
    A failure in a single chart is printed and skipped.

    Returns:
        ``(visuals, saved_images)``: the list of ``cl.Image`` elements and the
        list of PNG paths they were saved to.
    """
    visuals = []
    saved_images = []
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = [
        name for name in df.select_dtypes('object')
        if 1 < df[name].nunique() < 30
    ]

    def _register(figure, element_name):
        # Save the figure, record it as a Chainlit image, then release it.
        path = savefig(figure)
        visuals.append(cl.Image(name=element_name, path=path))
        saved_images.append(path)
        plt.close(figure)

    try:
        # One histogram per numeric column.
        for col in numeric_cols:
            try:
                fig, ax = plt.subplots()
                df[col].dropna().hist(ax=ax, bins=30)
                ax.set_title(f"Histogram of {col}")
                ax.set_xlabel(col)
                ax.set_ylabel("Frequency")
                _register(fig, f"Histogram of {col}")
            except Exception as e:
                print(f"Error generating histogram for {col}: {e}")
                plt.close()

        # A correlation matrix needs at least two numeric columns.
        if len(numeric_cols) > 1:
            try:
                corr = df[numeric_cols].corr().round(2)
                fig, ax = plt.subplots(figsize=(10, 8))
                sns.heatmap(corr, annot=True, fmt=".2f", cmap='coolwarm', ax=ax)
                ax.set_title("Correlation Heatmap")
                _register(fig, "Correlation Heatmap")
            except Exception as e:
                print(f"Error generating correlation heatmap: {e}")
                plt.close()

        # One bar plot per low-cardinality object column.
        for col in categorical_cols:
            try:
                fig, ax = plt.subplots()
                counts = df[col].fillna("Missing").value_counts().head(20)
                counts.plot(kind='bar', ax=ax)
                ax.set_title(f"Bar Plot of {col} (Top 20 Categories)")
                ax.set_xlabel(col)
                ax.set_ylabel("Count")
                _register(fig, f"Bar Plot of {col}")
            except Exception as e:
                print(f"Error generating bar plot for {col}: {e}")
                plt.close()
    except Exception as e:
        print(f"Unexpected error generating visuals: {e}")
        plt.close('all')

    return visuals, saved_images
async def cleanup_images(saved_images):
    """Delete temporary image files on a best-effort basis.

    Args:
        saved_images: Iterable of file paths to remove. Empty or ``None``
            entries are skipped.

    A file that is already gone is ignored silently. Any other failure is
    printed and the remaining files are still processed; nothing is raised.
    """
    for img_path in saved_images:
        if not img_path:
            continue
        try:
            os.remove(img_path)
        except FileNotFoundError:
            # Already removed, which is the state we wanted anyway.
            continue
        except (OSError, TypeError, ValueError) as e:
            # Keep going, but make the leftover file visible in the logs.
            print(f"Could not remove temporary image {img_path!r}: {e}")
async def process_csv_file(file_path):
    """Run the automated EDA flow on an uploaded CSV file.

    Steps: load the CSV, store the DataFrame in the user session under
    ``"df"``, send a text summary, optionally ask Gemini for an analysis plan,
    send the generated charts in batches, optionally ask Gemini for per-chart
    insights and a final summary, then report completion.

    Args:
        file_path: Path of the CSV file on disk. It is decoded as UTF-8 with
            undecodable bytes replaced.

    Errors are reported to the user through the progress message and printed
    with a traceback; nothing is raised. Temporary chart files are removed
    whether or not the flow succeeds.
    """
    processing_msg = cl.Message(content="Processing your CSV file, please wait...")
    await processing_msg.send()
    saved_images = []
    try:
        with open(file_path, "r", encoding="utf-8", errors="replace") as f:
            content = f.read()
        try:
            df = pd.read_csv(io.StringIO(content))
        except pd.errors.EmptyDataError:
            # A file with no columns at all cannot be parsed; treat it like
            # any other empty upload instead of showing a parser error.
            df = pd.DataFrame()
        if df.empty:
            processing_msg.content = "The uploaded file is empty or invalid."
            await processing_msg.update()
            return

        cl.user_session.set("df", df)
        info = df_to_string(df)
        await cl.Message(content=info).send()

        if GEMINI_AVAILABLE:
            plan = await text_analysis("plan", info)
            await cl.Message(content=f"### Analysis Plan: \n{plan}").send()

        visuals, saved_images = generate_visuals(df)

        # Send the charts a few at a time rather than in one huge message.
        batch_size = 7
        for batch_number, start in enumerate(range(0, len(visuals), batch_size), start=1):
            await cl.Message(
                content=f"**Generated Visualizations (batch {batch_number}):**",
                elements=visuals[start:start + batch_size],
            ).send()

        if GEMINI_AVAILABLE:
            titled_paths = [(img.name, img.path) for img in visuals]
            insights = await vision_analysis(titled_paths)
            for title, insight in insights:
                await cl.Message(content=f"**Insights for {title}:**\n{insight}").send()
            final = await text_analysis("final", info)
            await cl.Message(content=f"### Final Summary:\n{final}").send()

        processing_msg.content = "CSV analysis complete! You can now continue chatting or upload another file."
        await processing_msg.update()
    except Exception as e:
        processing_msg.content = f"An error occurred during CSV processing: {str(e)}"
        await processing_msg.update()
        print(f"Error: {e}\n{traceback.format_exc()}")
    finally:
        # Runs on success, failure and early return, so chart files never
        # pile up in the temp directory.
        await cleanup_images(saved_images)
@cl.on_chat_start
async def start_chat():
    """Initialise a new chat session.

    Stores the default model and generation config in the user session, sends
    the settings panel with the model selector, and posts the welcome message.
    """
    cl.user_session.set("current_model", DEFAULT_MODEL)
    cl.user_session.set("generation_config", generation_config)

    # Pre-select the label of the default model; fall back to the first label
    # if the default identifier is ever missing from the table.
    default_label = next(
        (label for label, model_id in AVAILABLE_MODELS.items() if model_id == DEFAULT_MODEL),
        next(iter(AVAILABLE_MODELS)),
    )
    await cl.ChatSettings([
        cl.input_widget.Select(
            id="model_selector",
            label="Select AI Model",
            values=list(AVAILABLE_MODELS.keys()),
            initial_value=default_label,
        )
    ]).send()

    # The text below is left-aligned on purpose: it is sent as markdown.
    welcome = """
# Gemini EDA Assistant
Welcome to the **Gemini EDA Assistant** with Dataframe analysis and image generation support!
## Getting Started
You can start chatting immediately! The assistant is ready to help with various tasks.
### Available Models
- **Gemini 2.0 Flash Experimental**: Lightweight and experimental
- **Gemini 2.5 Pro**: Advanced reasoning capabilities
- **Gemini 2.5 Flash**: Balanced performance
- **Gemini 2.0 Image Generation**: Create images from text prompts
- **Gemini 2.0 Flash Lite**: Default model, used until you select another one
### Features
- **Normal Chat**: Ask questions, get help with coding, writing, analysis, etc.
- **Image Generation**: Start your prompt with "/image" or "generate an image of"
- **CSV Analysis**: Upload a CSV file anytime during our conversation for automated EDA
### Commands
- `/upload` - Upload a CSV file for analysis
- `/image [description]` - Generate an image
---
*Ready to chat! Feel free to ask questions or upload a CSV file for analysis.*
"""
    await cl.Message(content=welcome.strip()).send()
@cl.on_settings_update
async def setup_chat_settings(settings):
    """Store the model chosen in the settings panel and confirm it to the user.

    Args:
        settings: Mapping of widget ids to values; ``"model_selector"`` holds
            the chosen model's display label.
    """
    chosen_label = settings["model_selector"]
    cl.user_session.set("current_model", AVAILABLE_MODELS[chosen_label])
    cl.user_session.set("generation_config", generation_config)

    confirmation = cl.Message(
        content=f"**Settings Updated** Now using: `{chosen_label}` model."
    )
    await confirmation.send()
async def handle_image_generation(prompt: str):
    """Generate an image for *prompt* and post it to the chat.

    Args:
        prompt: The user's request. A leading ``/image`` command word is
            removed before the text is sent to the model.

    The first image part received is posted as a new message and the function
    returns. Text parts that arrive before any image are streamed into the
    progress message. Errors are appended to the progress message and printed;
    nothing is raised.
    """
    msg = cl.Message(author="Gemini Image Generator", content="Generating your image...")
    await msg.send()

    # Only the description should reach the model, not the command word.
    description = prompt.strip()
    if description.lower().startswith("/image"):
        description = description[len("/image"):].strip()
    if not description:
        msg.content = "Please describe the image you want, e.g. `/image a red fox in the snow`."
        await msg.update()
        return

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=description)]
        )
    ]
    try:
        # The async client keeps the event loop responsive while chunks arrive.
        stream = await client.aio.models.generate_content_stream(
            model="gemini-2.0-flash-preview-image-generation",
            contents=contents,
            config=image_generation_config
        )
        async for chunk in stream:
            candidates = chunk.candidates or []
            if not candidates or not candidates[0].content or not candidates[0].content.parts:
                continue
            for part in candidates[0].content.parts:
                inline = getattr(part, "inline_data", None)
                if inline and inline.data:
                    image_data = inline.data
                    # Opening the bytes with Pillow rejects payloads it cannot
                    # identify as an image before they are shown to the user.
                    with Image.open(BytesIO(image_data)):
                        pass
                    image_element = cl.Image(
                        name="generated-image",
                        display="inline",
                        size="large",
                        content=image_data
                    )
                    await msg.remove()
                    await cl.Message(
                        author="Gemini Image Generator",
                        content="Here's your generated image:",
                        elements=[image_element]
                    ).send()
                    return
                text = getattr(part, "text", None)
                if text:
                    await msg.stream_token(text)
        await msg.update()
    except Exception as e:
        error_msg = f"\n**Error**: Unable to generate image. Details: {str(e)}"
        await msg.stream_token(error_msg)
        # Finalise the message so the error text is kept.
        await msg.update()
        print(f"Error: {e}")
@cl.on_message
async def main(message: cl.Message):
    """Route an incoming chat message.

    In order of precedence:
      1. An upload command (``/upload``, ``upload csv``, ...) asks for a CSV
         file and analyses it.
      2. An attached CSV file is analysed directly.
      3. A message starting with ``/image`` or ``generate an image of`` is
         handed to the image generator.
      4. Anything else is sent to the selected model and the reply is
         streamed back, followed by a footer naming the model.

    Errors while streaming are appended to the reply and printed; nothing is
    raised.
    """
    current_model = cl.user_session.get("current_model", DEFAULT_MODEL)
    config = cl.user_session.get("generation_config", generation_config)
    # Fall back to the raw identifier if it has no display label.
    model_display_name = next(
        (label for label, model_id in AVAILABLE_MODELS.items() if model_id == current_model),
        current_model,
    )
    command = (message.content or "").strip().lower()

    # Explicit request to upload a CSV file.
    if command in ("/upload", "upload csv", "upload a csv", "analyze csv"):
        files = await cl.AskFileMessage(
            content="Please upload a CSV file for analysis.",
            accept=["text/csv"],
            max_files=1,
            max_size_mb=50
        ).send()
        if files:
            await process_csv_file(files[0].path)
        else:
            await cl.Message(content="No file uploaded. You can try again anytime by typing `/upload`.").send()
        return

    # CSV file attached to the message. An element needs a path to be read;
    # the extension is checked on the path and on the element name.
    # NOTE(review): assumes the stored path's suffix can differ from the
    # uploaded file name; confirm against Chainlit's file handling.
    csv_files = [
        element for element in (message.elements or [])
        if getattr(element, "path", None) and (
            element.path.lower().endswith('.csv')
            or (getattr(element, "name", None) or "").lower().endswith('.csv')
        )
    ]
    if csv_files:
        await process_csv_file(csv_files[0].path)
        return

    # Image generation request.
    if command.startswith(("/image", "generate an image of")):
        await handle_image_generation(message.content)
        return

    # Normal chat handling.
    msg = cl.Message(author=model_display_name, content="")
    await msg.send()
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=message.content)]
        )
    ]
    try:
        # The async client keeps the event loop responsive while chunks arrive.
        stream = await client.aio.models.generate_content_stream(
            model=current_model,
            contents=contents,
            config=config
        )
        async for chunk in stream:
            text = getattr(chunk, "text", None)
            if text:
                await msg.stream_token(text)
                continue
            # No aggregated text on the chunk: look at the parts of every candidate.
            for candidate in getattr(chunk, "candidates", None) or []:
                parts = getattr(candidate.content, "parts", None) or []
                for part in parts:
                    part_text = getattr(part, "text", None)
                    if part_text:
                        await msg.stream_token(part_text)
    except Exception as e:
        error_msg = f"\n**Error**: Unable to process request with {model_display_name}. Details: {str(e)}"
        await msg.stream_token(error_msg)
        print(f"Error: {e}")
    await msg.stream_token(f"\n\n---\n**Model**: {model_display_name}")
    await msg.update()