Spaces:
Sleeping
Sleeping
| import os, io, asyncio, tempfile, traceback | |
| import pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns | |
| from dotenv import load_dotenv | |
| import chainlit as cl | |
| from google.genai import types | |
| from PIL import Image | |
| from io import BytesIO | |
| from google import genai | |
| import matplotlib | |
| matplotlib.use('Agg') # Use a non-GUI backend for matplotlib | |
# Available models: UI display name -> Gemini model identifier sent to the API.
AVAILABLE_MODELS = {
    "Gemini 2.0 Flash Experimental": "gemini-2.0-flash-exp",
    "Gemini 2.5 Pro": "gemini-2.5-pro",
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Gemini 2.0 Image Generation": "gemini-2.0-flash-preview-image-generation",
    "Gemini 2.0 Flash Lite": "gemini-2.0-flash-lite"
}
DEFAULT_MODEL = "gemini-2.0-flash-lite"
# Module-level fallback model id, read by text_analysis()/vision_analysis().
# NOTE(review): the per-user selection is stored in cl.user_session and this
# global is never reassigned after import — confirm the two are meant to diverge.
current_model = DEFAULT_MODEL
GEMINI_AVAILABLE = False
# Load environment variables
load_dotenv()
gemini_api_key = os.environ.get("GEMINI_API_KEY")
if not gemini_api_key:
    # Fail fast at import time: nothing below works without a key.
    raise ValueError("GEMINI_API_KEY not found in environment variables or .env file")
# Initialize Gemini client
client = genai.Client(api_key=gemini_api_key)
GEMINI_AVAILABLE = True
# Generation configuration used for normal chat streaming (temperature 0 = deterministic).
generation_config = types.GenerateContentConfig(
    temperature=0,
    max_output_tokens=8192,
    response_mime_type="text/plain"
)
# Image generation config: model must be asked for both IMAGE and TEXT modalities.
# NOTE(review): response_mime_type normally applies to text output only — confirm
# the API accepts it alongside an IMAGE response modality.
image_generation_config = types.GenerateContentConfig(
    response_modalities=["IMAGE", "TEXT"],
    response_mime_type="text/plain"
)
def savefig(fig):
    """Save a matplotlib figure to a temporary PNG file and return its path.

    Args:
        fig: matplotlib Figure to render.

    Returns:
        Path of the temp file (delete=False — caller owns deletion, see
        cleanup_images).

    The figure is closed in a ``finally`` so a failing ``fig.savefig`` call
    (the original closed it only on success) cannot leak figure objects
    across repeated chart generation.
    """
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmpfile:
            fig.savefig(tmpfile.name, bbox_inches='tight', dpi=150)
            return tmpfile.name
    finally:
        plt.close(fig)
def df_to_string(df, max_rows=10):
    """Build a text report of a DataFrame: schema, head preview, missing values.

    Args:
        df: DataFrame to summarize.
        max_rows: number of rows shown in the head preview (default 10).

    Returns:
        String with three sections: "### Schema:", "### Head:", "### Missing:".
    """
    buf = io.StringIO()
    df.info(buf=buf)
    schema = buf.getvalue()
    try:
        head = df.head(max_rows).to_markdown(index=False)
    except ImportError:
        # to_markdown requires the optional 'tabulate' package; fall back to
        # a plain-text table so the report still renders without it.
        head = df.head(max_rows).to_string(index=False)
    missing = df.isnull().sum()
    missing = missing[missing > 0]
    missing_info = "No missing values" if missing.empty else f"Missing values:\n{missing.to_string()}"
    return f"### Schema:\n{schema}\n\n### Head:\n{head}\n\n### Missing:\n{missing_info}"
async def text_analysis(prompt_type, df_context):
    """Ask Gemini for a text analysis of the dataset context.

    Args:
        prompt_type: "plan" for an analysis plan, "final" for a summary;
            any other value sends an empty prompt.
        df_context: textual DataFrame description (see df_to_string).

    Returns:
        The model's text, or a human-readable status/error string.
    """
    if not GEMINI_AVAILABLE:
        return "Gemini API is not available."
    prompt_templates = {
        "plan": f"You are a data analyst. Suggest a concise data analysis plan for the following DataFrame:\n{df_context}",
        "final": f"Summarize the analysis results for the following dataset:\n{df_context}",
    }
    prompt_text = prompt_templates.get(prompt_type, "")
    try:
        request_contents = [
            genai.types.Content(
                role="user",
                parts=[genai.types.Part.from_text(text=prompt_text)],
            )
        ]
        response = client.models.generate_content(
            model=current_model,
            contents=request_contents,
            config={
                'temperature': 0.0,
                'max_output_tokens': 1024,
            },
        )
        if not (response.candidates and len(response.candidates) > 0):
            return "No response generated."
        first_candidate = response.candidates[0]
        if first_candidate.content and first_candidate.content.parts:
            return first_candidate.content.parts[0].text
        return "Gemini response blocked or empty."
    except Exception as e:
        return f"Error during text analysis: {str(e)}\n{traceback.format_exc()}"
async def vision_analysis(img_paths):
    """Run Gemini vision analysis over saved chart images.

    Args:
        img_paths: iterable of (title, image_path) pairs.

    Returns:
        List of (title, insight_or_error_text) pairs, one per input image;
        or a status string when the Gemini client is unavailable.
    """
    if not GEMINI_AVAILABLE:
        return "Gemini API is not available."
    insights = []
    for title, img_path in img_paths:
        try:
            # Read image file
            with open(img_path, "rb") as image_file:
                raw_bytes = image_file.read()
            # Detect image MIME type based on file extension
            lowered = img_path.lower()
            if lowered.endswith('.png'):
                mime_type = "image/png"
            elif lowered.endswith(('.jpg', '.jpeg')):
                mime_type = "image/jpeg"
            elif lowered.endswith('.webp'):
                mime_type = "image/webp"
            else:
                mime_type = "image/jpeg"  # default
            # Build the multimodal request: instruction text + image bytes.
            request_contents = [
                genai.types.Content(
                    role="user",
                    parts=[
                        genai.types.Part.from_text(text=f"Analyze the image titled '{title}' and provide insights."),
                        genai.types.Part.from_bytes(data=raw_bytes, mime_type=mime_type),
                    ],
                )
            ]
            # Generate content using non-streaming API
            response = client.models.generate_content(
                model=current_model,
                contents=request_contents,
                config={
                    'temperature': 0.0,
                    'max_output_tokens': 1024,
                },
            )
            # Extract text from response, keeping pairing with the title.
            if response.candidates and len(response.candidates) > 0:
                first_candidate = response.candidates[0]
                if first_candidate.content and first_candidate.content.parts:
                    insights.append((title, first_candidate.content.parts[0].text))
                else:
                    insights.append((title, "Gemini response blocked."))
            else:
                insights.append((title, "No response generated."))
        except Exception as e:
            # One bad image must not abort analysis of the rest.
            insights.append((title, f"Error: {str(e)}"))
    return insights
def generate_visuals(df):
    """Generate standard EDA visualizations for a DataFrame.

    Produces: one histogram per numeric column, a correlation heatmap when
    there are 2+ numeric columns, and one bar plot per low-cardinality
    (2-29 unique values) object column.

    Args:
        df: DataFrame to visualize.

    Returns:
        (visuals, saved_images): chainlit Image elements, and the temp-file
        paths behind them (pass saved_images to cleanup_images when done).

    Fix vs. original: the redundant ``plt.close(fig)`` after each
    ``savefig(fig)`` call is removed — ``savefig`` already closes the figure,
    so the second close operated on an already-closed figure.
    """
    visuals = []
    saved_images = []
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    # Object columns with a manageable number of distinct values (2-29).
    categorical_cols = [col for col in df.select_dtypes('object') if 1 < df[col].nunique() < 30]
    try:
        if numeric_cols:
            # Histograms for numeric columns
            for col in numeric_cols:
                try:
                    fig, ax = plt.subplots()
                    df[col].dropna().hist(ax=ax, bins=30)
                    ax.set_title(f"Histogram of {col}")
                    ax.set_xlabel(col)
                    ax.set_ylabel("Frequency")
                    img_path = savefig(fig)  # savefig() also closes the figure
                    visuals.append(cl.Image(name=f"Histogram of {col}", path=img_path))
                    saved_images.append(img_path)
                except Exception as e:
                    print(f"Error generating histogram for {col}: {e}")
                    plt.close()
            # Correlation heatmap
            if len(numeric_cols) > 1:
                try:
                    corr = df[numeric_cols].corr().round(2)
                    fig, ax = plt.subplots(figsize=(10, 8))
                    sns.heatmap(corr, annot=True, fmt=".2f", cmap='coolwarm', ax=ax)
                    ax.set_title("Correlation Heatmap")
                    img_path = savefig(fig)
                    visuals.append(cl.Image(name="Correlation Heatmap", path=img_path))
                    saved_images.append(img_path)
                except Exception as e:
                    print(f"Error generating correlation heatmap: {e}")
                    plt.close()
        if categorical_cols:
            # Bar plots for categorical columns
            for col in categorical_cols:
                try:
                    fig, ax = plt.subplots()
                    df[col].fillna("Missing").value_counts().head(20).plot(kind='bar', ax=ax)
                    ax.set_title(f"Bar Plot of {col} (Top 20 Categories)")
                    ax.set_xlabel(col)
                    ax.set_ylabel("Count")
                    img_path = savefig(fig)
                    visuals.append(cl.Image(name=f"Bar Plot of {col}", path=img_path))
                    saved_images.append(img_path)
                except Exception as e:
                    print(f"Error generating bar plot for {col}: {e}")
                    plt.close()
    except Exception as e:
        print(f"Unexpected error generating visuals: {e}")
        plt.close('all')
    return visuals, saved_images
async def cleanup_images(saved_images):
    """Best-effort removal of temporary image files.

    Args:
        saved_images: iterable of file paths to delete.

    Filesystem problems (already deleted, permissions) are ignored so that
    cleanup never interrupts the chat flow, but the except is narrowed from
    the original catch-all to ``OSError`` so programming errors (e.g. a
    non-string entry raising TypeError) are no longer silently swallowed.
    """
    for img_path in saved_images:
        try:
            os.remove(img_path)
        except OSError:
            pass  # deliberate best-effort: file may already be gone
async def process_csv_file(file_path):
    """Process uploaded CSV file and perform EDA.

    Pipeline: load CSV -> post schema/head/missing summary -> (if Gemini is
    up) analysis plan -> render charts in batches -> (if Gemini is up)
    per-chart insights and a final summary -> delete the temp chart files.
    Progress is reported by mutating and updating a single status message.

    Args:
        file_path: path to the uploaded CSV file on disk.
    """
    processing_msg = cl.Message(content="Processing your CSV file, please wait...")
    await processing_msg.send()
    try:
        # errors="replace" keeps odd encodings from crashing the parser.
        with open(file_path, "r", encoding="utf-8", errors="replace") as f:
            content = f.read()
        df = pd.read_csv(io.StringIO(content))
        if df.empty:
            processing_msg.content="The uploaded file is empty or invalid."
            await processing_msg.update()
            return
        # Stash the DataFrame so later turns in this session can reuse it.
        cl.user_session.set("df", df)
        info = df_to_string(df)
        await cl.Message(content=info).send()
        if GEMINI_AVAILABLE:
            plan = await text_analysis("plan", info)
            await cl.Message(content=f"### Analysis Plan: \n{plan}").send()
        visuals, saved_images = generate_visuals(df)
        # Send charts in batches of 7 to keep individual messages small.
        batch_size = 7
        for i in range(0, len(visuals), batch_size):
            batch = visuals[i:i+batch_size]
            if batch: # Only send if batch is not empty
                await cl.Message(
                    content=f"**Generated Visualizations (batch {i//batch_size+1}):**",
                    elements=batch
                ).send()
        # Rebind visuals as (title, path) pairs for vision analysis and cleanup.
        visuals = [(img.name, img.path) for img in visuals]
        if GEMINI_AVAILABLE:
            insights = await vision_analysis(visuals)
            for title, insight in insights:
                await cl.Message(content=f"**Insights for {title}:**\n{insight}").send()
            final = await text_analysis("final", info)
            await cl.Message(content=f"### Final Summary:\n{final}").send()
        processing_msg.content="CSV analysis complete! You can now continue chatting or upload another file."
        await processing_msg.update()
        await cleanup_images([path for _, path in visuals])
    except Exception as e:
        # Surface the failure in-chat; full traceback goes to the server log.
        processing_msg.content=f"An error occurred during CSV processing: {str(e)}"
        await processing_msg.update()
        print(f"Error: {e}\n{traceback.format_exc()}")
async def start_chat():
    """Initialize the session (default model + config), show the model
    selector, and post the welcome message."""
    cl.user_session.set("current_model", DEFAULT_MODEL)
    cl.user_session.set("generation_config", generation_config)
    # Map the default model id back to its display name for the selector.
    default_display_name = next(
        name for name, model_id in AVAILABLE_MODELS.items() if model_id == DEFAULT_MODEL
    )
    await cl.ChatSettings([
        cl.input_widget.Select(
            id="model_selector",
            label="Select AI Model",
            values=list(AVAILABLE_MODELS.keys()),
            initial_value=default_display_name
        )
    ]).send()
    welcome = """
# Gemini EDA Assistant
Welcome to the **Gemini EDA Assistant** with Dataframe analysis and image generation support!
## Getting Started
You can start chatting immediately! The assistant is ready to help with various tasks.
### Available Models
- **Gemini 2.0 Flash Experimental**: Lightweight and experimental
- **Gemini 2.5 Pro**: Advanced reasoning capabilities
- **Gemini 2.5 Flash**: Balanced performance
- **Gemini 2.0 Image Generation**: Create images from text prompts
### Features
- **Normal Chat**: Ask questions, get help with coding, writing, analysis, etc.
- **Image Generation**: Start your prompt with "/image" or "generate an image of"
- **CSV Analysis**: Upload a CSV file anytime during our conversation for automated EDA
### Commands
- `/upload` - Upload a CSV file for analysis
- `/image [description]` - Generate an image
---
*Ready to chat! Feel free to ask questions or upload a CSV file for analysis.*
"""
    await cl.Message(content=welcome.strip()).send()
async def setup_chat_settings(settings):
    """Apply the model chosen in the settings panel to the current session.

    Args:
        settings: chainlit settings dict containing the "model_selector"
            display name chosen by the user.
    """
    chosen_name = settings["model_selector"]
    chosen_model_id = AVAILABLE_MODELS[chosen_name]
    cl.user_session.set("current_model", chosen_model_id)
    cl.user_session.set("generation_config", generation_config)
    confirmation = f"**Settings Updated** Now using: `{chosen_name}` model."
    await cl.Message(content=confirmation).send()
async def handle_image_generation(prompt: str):
    """Handle image generation requests.

    Streams from the image-generation model: the first inline-image part is
    rendered as a chainlit Image and ends the handler; text parts are
    streamed as tokens into the status message.

    Args:
        prompt: the raw user prompt (including any "/image" prefix).

    Fix vs. original: genai Parts always expose a ``.text`` attribute (often
    ``None``), so the old ``hasattr(part, "text")`` branch was always taken
    and could pass ``None`` to ``stream_token``. We now require non-empty text.
    """
    msg = cl.Message(author="Gemini Image Generator", content="Generating your image...")
    await msg.send()
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=prompt)]
        )
    ]
    try:
        stream = client.models.generate_content_stream(
            model="gemini-2.0-flash-preview-image-generation",
            contents=contents,
            config=image_generation_config
        )
        for chunk in stream:
            if (chunk.candidates and
                chunk.candidates[0].content and
                chunk.candidates[0].content.parts):
                for part in chunk.candidates[0].content.parts:
                    if hasattr(part, "inline_data") and part.inline_data:
                        # Handle image data
                        image_data = part.inline_data.data
                        # Opened to validate that the bytes decode as an image.
                        image = Image.open(BytesIO(image_data))
                        # Create Chainlit image element
                        image_element = cl.Image(
                            name="generated-image",
                            display="inline",
                            size="large",
                            content=image_data
                        )
                        await msg.remove()
                        await cl.Message(
                            author="Gemini Image Generator",
                            content="Here's your generated image:",
                            elements=[image_element]
                        ).send()
                        return
                    elif getattr(part, "text", None):
                        await msg.stream_token(part.text)
        await msg.update()
    except Exception as e:
        error_msg = f"\n**Error**: Unable to generate image. Details: {str(e)}"
        await msg.stream_token(error_msg)
        print(f"Error: {e}")
async def main(message: cl.Message):
    """Route an incoming chat message.

    Dispatch order: explicit `/upload` command -> CSV file attachment ->
    `/image` / "generate an image of" prefix -> normal streamed chat with
    the session-selected Gemini model.

    Fixes vs. original:
    - display-name lookup used ``[...][0]`` and raised IndexError for a model
      id not in AVAILABLE_MODELS; now falls back to the raw id.
    - ``hasattr(part, "text")`` is always true on genai Parts (``.text`` may
      be None), so ``full_response += part.text`` could raise TypeError; we
      now require non-empty text. ``candidate.content.parts`` may also be
      None, guarded with ``or []``.
    """
    current_model = cl.user_session.get("current_model", DEFAULT_MODEL)
    config = cl.user_session.get("generation_config", generation_config)
    model_display_name = next(
        (k for k, v in AVAILABLE_MODELS.items() if v == current_model),
        current_model  # fall back to the raw model id for unknown values
    )
    # Check if user wants to upload a CSV file
    if message.content.lower().strip() in ["/upload", "upload csv", "upload a csv", "analyze csv"]:
        files = await cl.AskFileMessage(
            content="Please upload a CSV file for analysis.",
            accept=["text/csv"],
            max_files=1,
            max_size_mb=50
        ).send()
        if files and len(files) > 0:
            await process_csv_file(files[0].path)
        else:
            await cl.Message(content="No file uploaded. You can try again anytime by typing `/upload`.").send()
        return
    # Handle file attachments (CSV files)
    if message.elements:
        csv_files = [file for file in message.elements if hasattr(file, 'path') and file.path.lower().endswith('.csv')]
        if csv_files:
            await process_csv_file(csv_files[0].path)
            return
    # Check if this is an image generation request
    if message.content.lower().startswith(("/image", "generate an image of")):
        await handle_image_generation(message.content)
        return
    # Normal chat handling
    msg = cl.Message(author=model_display_name, content="")
    await msg.send()
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=message.content)]
        )
    ]
    full_response = ""
    try:
        stream = client.models.generate_content_stream(
            model=current_model,
            contents=contents,
            config=config
        )
        for chunk in stream:
            text = getattr(chunk, "text", None)
            if text:
                full_response += text
                await msg.stream_token(text)
            elif getattr(chunk, "candidates", None):
                for candidate in chunk.candidates:
                    parts = getattr(candidate.content, "parts", None) or []
                    for part in parts:
                        # Part.text can be None; only stream real text.
                        if getattr(part, "text", None):
                            full_response += part.text
                            await msg.stream_token(part.text)
    except Exception as e:
        error_msg = f"\n**Error**: Unable to process request with {model_display_name}. Details: {str(e)}"
        await msg.stream_token(error_msg)
        print(f"Error: {e}")
    await msg.stream_token(f"\n\n---\n**Model**: {model_display_name}")
    await msg.update()