| | import streamlit as st |
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import os |
| | import base64 |
| | import io |
| | from groq import Groq |
| | from pydantic import BaseModel, Field |
| | from typing import Dict, List, Optional |
| | from langchain.tools import tool |
| | from langchain.agents import initialize_agent, AgentType |
| | from scipy.stats import ttest_ind, f_oneway |
| | from statsmodels.tsa.seasonal import seasonal_decompose |
| | from statsmodels.tsa.stattools import adfuller |
| | from langchain.prompts import PromptTemplate |
| |
|
| | |
| | client = Groq(api_key=os.environ.get("GROQ_API_KEY")) |
| |
|
| | class ResearchInput(BaseModel): |
| | """Base schema for research tool inputs""" |
| | data_key: str = Field(..., description="Session state key containing DataFrame") |
| | columns: Optional[List[str]] = Field(None, description="List of columns to analyze") |
| |
|
| | class TemporalAnalysisInput(ResearchInput): |
| | """Schema for temporal analysis""" |
| | time_col: str = Field(..., description="Name of timestamp column") |
| | value_col: str = Field(..., description="Name of value column to analyze") |
| |
|
| | class HypothesisInput(ResearchInput): |
| | """Schema for hypothesis testing""" |
| | group_col: str = Field(..., description="Categorical column defining groups") |
| | value_col: str = Field(..., description="Numerical column to compare") |
| |
|
| | class GroqResearcher: |
| | """Advanced AI Research Engine using Groq""" |
| | def __init__(self, model_name="mixtral-8x7b-32768"): |
| | self.model_name = model_name |
| | self.system_template = """You are a senior data scientist at a research institution. |
| | Analyze this dataset with rigorous statistical methods and provide academic-quality insights: |
| | {dataset_info} |
| | |
| | User Question: {query} |
| | |
| | Required Format: |
| | - Executive Summary (1 paragraph) |
| | - Methodology (bullet points) |
| | - Key Findings (numbered list) |
| | - Limitations |
| | - Recommended Next Steps""" |
| |
|
| | def research(self, query: str, data: pd.DataFrame) -> str: |
| | """Conduct academic-level analysis using Groq""" |
| | try: |
| | dataset_info = f""" |
| | Dataset Dimensions: {data.shape} |
| | Variables: {', '.join(data.columns)} |
| | Temporal Coverage: {data.select_dtypes(include='datetime').columns.tolist()} |
| | Missing Values: {data.isnull().sum().to_dict()} |
| | """ |
| | |
| | prompt = PromptTemplate.from_template(self.system_template).format( |
| | dataset_info=dataset_info, |
| | query=query |
| | ) |
| | |
| | completion = client.chat.completions.create( |
| | messages=[ |
| | {"role": "system", "content": "You are a research AI assistant"}, |
| | {"role": "user", "content": prompt} |
| | ], |
| | model=self.model_name, |
| | temperature=0.2, |
| | max_tokens=4096, |
| | stream=False |
| | ) |
| | |
| | return completion.choices[0].message.content |
| | |
| | except Exception as e: |
| | return f"Research Error: {str(e)}" |
| |
|
| | @tool(args_schema=ResearchInput) |
| | def advanced_eda(data_key: str) -> Dict: |
| | """Comprehensive Exploratory Data Analysis with Statistical Profiling""" |
| | try: |
| | data = st.session_state[data_key] |
| | analysis = { |
| | "dimensionality": { |
| | "rows": len(data), |
| | "columns": list(data.columns), |
| | "memory_usage": f"{data.memory_usage().sum() / 1e6:.2f} MB" |
| | }, |
| | "statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(), |
| | "temporal_analysis": { |
| | "date_ranges": { |
| | col: { |
| | "min": data[col].min(), |
| | "max": data[col].max() |
| | } for col in data.select_dtypes(include='datetime').columns |
| | } |
| | }, |
| | "data_quality": { |
| | "missing_values": data.isnull().sum().to_dict(), |
| | "duplicates": data.duplicated().sum(), |
| | "cardinality": { |
| | col: data[col].nunique() for col in data.columns |
| | } |
| | } |
| | } |
| | return analysis |
| | except Exception as e: |
| | return {"error": f"EDA Failed: {str(e)}"} |
| |
|
| | @tool(args_schema=ResearchInput) |
| | def visualize_distributions(data_key: str, columns: List[str]) -> str: |
| | """Generate publication-quality distribution visualizations""" |
| | try: |
| | data = st.session_state[data_key] |
| | plt.figure(figsize=(12, 6)) |
| | for i, col in enumerate(columns, 1): |
| | plt.subplot(1, len(columns), i) |
| | sns.histplot(data[col], kde=True, stat="density") |
| | plt.title(f'Distribution of {col}', fontsize=10) |
| | plt.xticks(fontsize=8) |
| | plt.yticks(fontsize=8) |
| | plt.tight_layout() |
| | |
| | buf = io.BytesIO() |
| | plt.savefig(buf, format='png', dpi=300, bbox_inches='tight') |
| | plt.close() |
| | return base64.b64encode(buf.getvalue()).decode() |
| | except Exception as e: |
| | return f"Visualization Error: {str(e)}" |
| |
|
| | @tool(args_schema=TemporalAnalysisInput) |
| | def temporal_analysis(data_key: str, time_col: str, value_col: str) -> Dict: |
| | """Time Series Decomposition and Trend Analysis""" |
| | try: |
| | data = st.session_state[data_key] |
| | ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col] |
| | |
| | decomposition = seasonal_decompose(ts_data, period=365) |
| | |
| | plt.figure(figsize=(12, 8)) |
| | decomposition.plot() |
| | plt.tight_layout() |
| | |
| | buf = io.BytesIO() |
| | plt.savefig(buf, format='png') |
| | plt.close() |
| | plot_data = base64.b64encode(buf.getvalue()).decode() |
| | |
| | return { |
| | "trend_statistics": { |
| | "stationarity": adfuller(ts_data)[1], |
| | "seasonality_strength": max(decomposition.seasonal) |
| | }, |
| | "visualization": plot_data |
| | } |
| | except Exception as e: |
| | return {"error": f"Temporal Analysis Failed: {str(e)}"} |
| |
|
| | @tool(args_schema=HypothesisInput) |
| | def hypothesis_testing(data_key: str, group_col: str, value_col: str) -> Dict: |
| | """Statistical Hypothesis Testing with Automated Assumption Checking""" |
| | try: |
| | data = st.session_state[data_key] |
| | groups = data[group_col].unique() |
| | |
| | if len(groups) < 2: |
| | return {"error": "Insufficient groups for comparison"} |
| | |
| | if len(groups) == 2: |
| | group_data = [data[data[group_col] == g][value_col] for g in groups] |
| | stat, p = ttest_ind(*group_data) |
| | test_type = "Independent t-test" |
| | else: |
| | group_data = [data[data[group_col] == g][value_col] for g in groups] |
| | stat, p = f_oneway(*group_data) |
| | test_type = "ANOVA" |
| | |
| | return { |
| | "test_type": test_type, |
| | "test_statistic": stat, |
| | "p_value": p, |
| | "effect_size": { |
| | "cohens_d": abs(group_data[0].mean() - group_data[1].mean())/np.sqrt( |
| | (group_data[0].var() + group_data[1].var())/2 |
| | ) if len(groups) == 2 else None |
| | }, |
| | "interpretation": interpret_p_value(p) |
| | } |
| | except Exception as e: |
| | return {"error": f"Hypothesis Testing Failed: {str(e)}"} |
| |
|
| | def interpret_p_value(p: float) -> str: |
| | """Scientific interpretation of p-values""" |
| | if p < 0.001: return "Very strong evidence against H0" |
| | elif p < 0.01: return "Strong evidence against H0" |
| | elif p < 0.05: return "Evidence against H0" |
| | elif p < 0.1: return "Weak evidence against H0" |
| | else: return "No significant evidence against H0" |
| |
|
| | def main(): |
| | st.set_page_config(page_title="AI Research Lab", layout="wide") |
| | st.title("🧪 Advanced AI Research Laboratory") |
| | |
| | |
| | if 'data' not in st.session_state: |
| | st.session_state.data = None |
| | if 'researcher' not in st.session_state: |
| | st.session_state.researcher = GroqResearcher() |
| | |
| | |
| | with st.sidebar: |
| | st.header("🔬 Data Management") |
| | uploaded_file = st.file_uploader("Upload research dataset", type=["csv", "parquet"]) |
| | if uploaded_file: |
| | with st.spinner("Initializing dataset..."): |
| | try: |
| | st.session_state.data = pd.read_csv(uploaded_file) |
| | st.success(f"Loaded {len(st.session_state.data):,} research observations") |
| | except Exception as e: |
| | st.error(f"Error loading dataset: {e}") |
| | |
| | |
| | if st.session_state.data is not None: |
| | col1, col2 = st.columns([1, 3]) |
| | |
| | with col1: |
| | st.subheader("Dataset Metadata") |
| | st.json({ |
| | "Variables": list(st.session_state.data.columns), |
| | "Time Range": { |
| | col: { |
| | "min": st.session_state.data[col].min(), |
| | "max": st.session_state.data[col].max() |
| | } for col in st.session_state.data.select_dtypes(include='datetime').columns |
| | }, |
| | "Size": f"{st.session_state.data.memory_usage().sum() / 1e6:.2f} MB" |
| | }) |
| | |
| | with col2: |
| | analysis_tab, research_tab = st.tabs(["Automated Analysis", "Custom Research"]) |
| | |
| | with analysis_tab: |
| | analysis_type = st.selectbox("Select Analysis Mode", [ |
| | "Exploratory Data Analysis", |
| | "Temporal Pattern Analysis", |
| | "Comparative Statistics", |
| | "Distribution Analysis" |
| | ]) |
| | |
| | if analysis_type == "Exploratory Data Analysis": |
| | eda_result = advanced_eda.invoke({"data_key": "data"}) |
| | st.subheader("Data Quality Report") |
| | st.json(eda_result) |
| | |
| | elif analysis_type == "Temporal Pattern Analysis": |
| | time_col = st.selectbox("Temporal Variable", |
| | st.session_state.data.select_dtypes(include='datetime').columns) |
| | value_col = st.selectbox("Analysis Variable", |
| | st.session_state.data.select_dtypes(include=np.number).columns) |
| | |
| | if time_col and value_col: |
| | result = temporal_analysis.invoke({ |
| | "data_key": "data", |
| | "time_col": time_col, |
| | "value_col": value_col |
| | }) |
| | if "visualization" in result: |
| | st.image(f"data:image/png;base64,{result['visualization']}") |
| | st.json(result) |
| | |
| | elif analysis_type == "Comparative Statistics": |
| | group_col = st.selectbox("Grouping Variable", |
| | st.session_state.data.select_dtypes(include='category').columns) |
| | value_col = st.selectbox("Metric Variable", |
| | st.session_state.data.select_dtypes(include=np.number).columns) |
| | |
| | if group_col and value_col: |
| | result = hypothesis_testing.invoke({ |
| | "data_key": "data", |
| | "group_col": group_col, |
| | "value_col": value_col |
| | }) |
| | st.subheader("Statistical Test Results") |
| | st.json(result) |
| | |
| | elif analysis_type == "Distribution Analysis": |
| | num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist() |
| | selected_cols = st.multiselect("Select Variables", num_cols) |
| | if selected_cols: |
| | img_data = visualize_distributions.invoke({ |
| | "data_key": "data", |
| | "columns": selected_cols |
| | }) |
| | st.image(f"data:image/png;base64,{img_data}") |
| | |
| | with research_tab: |
| | research_query = st.text_area("Enter Research Question:", height=150, |
| | placeholder="E.g., 'What factors are most predictive of X outcome?'") |
| | |
| | if st.button("Execute Research"): |
| | with st.spinner("Conducting rigorous analysis..."): |
| | result = st.session_state.researcher.research( |
| | research_query, st.session_state.data |
| | ) |
| | st.markdown("## Research Findings") |
| | st.markdown(result) |
| |
|
| | if __name__ == "__main__": |
| | main() |