pavanmutha committed on
Commit
af78f21
·
verified ·
1 Parent(s): 4b4c2f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -82
app.py CHANGED
@@ -43,88 +43,109 @@ def upload_file(file):
43
  df_global = df
44
  return df.head()
45
 
46
- import textwrap
47
-
48
- additional_notes = "Please note: Perform a comprehensive analysis including visualizations and insights."
49
-
50
-
51
- # Initialize the agent
52
- agent = CodeAgent(
53
- tools=[],
54
- model=model,
55
- additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "os", "json"]
56
- )
57
-
58
-
59
-
60
- # Gradio Gallery and visualization output
61
- visual_output = gr.Gallery(label="Generated Visualizations", columns=3, height=600, object_fit="contain")
62
-
63
- def run_agent(_):
64
- import os
65
- from PIL import Image
66
-
67
- os.makedirs("figures", exist_ok=True) # Add this just before loading images
68
-
69
-
70
- if df_global is None:
71
- return "Please upload a file first.", []
72
-
73
- # Save the dataset temporarily
74
- from tempfile import NamedTemporaryFile
75
- temp_file = NamedTemporaryFile(delete=False, suffix=".csv")
76
- df_global.to_csv(temp_file.name, index=False)
77
- temp_file.close()
78
-
79
- # Prompt for the agent
80
- prompt = """
81
- You are an expert data analyst.
82
- 1. Load the provided dataset using: df = pd.read_csv(source_file)
83
- 2. Automatically detect numeric and categorical columns.
84
- 3. Perform:
85
- - Basic statistics
86
- - Null/duplicate checks
87
- - Correlation analysis
88
- - 3+ visualizations
89
- 4. Extract 3+ bullet-point insights.
90
- 5. Before saving visualizations, run:
91
- import os; os.makedirs("figures", exist_ok=True)
92
- Then save all figures using plt.savefig("./figures/...")
93
- 6. Return a JSON with:
94
- - 'insights': list of insights
95
- - 'figures': list of figure file paths
96
- """
97
-
98
- result = agent.run(prompt, additional_args={"source_file": temp_file.name})
99
-
100
- # Parse and process output
101
- insights = "No insights returned."
102
- images = []
103
-
104
- if isinstance(result, str):
105
- try:
106
- result = json.loads(result)
107
- except Exception:
108
- return "Agent returned invalid JSON.", []
109
-
110
- if isinstance(result, dict):
111
- raw_insights = result.get("insights", [])
112
- insights = "\n".join(raw_insights) if isinstance(raw_insights, list) else str(raw_insights)
113
-
114
- image_paths = result.get("figures", [])
115
- print("🔍 Image paths received:", image_paths)
116
-
117
- for path in image_paths:
118
- if os.path.exists(path):
119
- try:
120
- images.append(Image.open(path))
121
- except Exception as e:
122
- print(f"⚠️ Error loading {path}: {e}")
123
- else:
124
- print(f"❌ File not found: {path}")
125
-
126
- return insights, images
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
 
 
43
  df_global = df
44
  return df.head()
45
 
46
def _coerce_analysis_dict(raw_output):
    """Return the agent output as a dict, or None when it cannot be parsed.

    A dict is returned unchanged. Anything else is converted with ``str()``
    and parsed as a Python literal first, then as JSON (agents frequently
    answer with JSON such as ``true``/``null`` that ``ast.literal_eval``
    rejects). A parsed value that is not a dict is treated as unparseable.
    """
    import json  # local import: this block cannot rely on the file's import list

    if isinstance(raw_output, dict):
        return raw_output

    text = str(raw_output)
    try:
        parsed = ast.literal_eval(text)
    except (SyntaxError, ValueError) as e:
        try:
            parsed = json.loads(text)
        except ValueError:
            print(f"Error parsing CodeAgent output: {e}")
            return None

    if not isinstance(parsed, dict):
        print(
            "Error parsing CodeAgent output: "
            f"expected a dict, got {type(parsed).__name__}"
        )
        return None
    return parsed


def format_analysis_report(raw_output, visuals):
    """Build the HTML analysis report from the agent's output.

    Args:
        raw_output: The value returned by the agent; either a dict or
            something whose ``str()`` is a dict literal / JSON object with
            optional ``'observations'`` and ``'insights'`` entries.
        visuals: List of figure file paths, passed through to
            ``format_insights`` and returned unchanged.

    Returns:
        A ``(report, visuals)`` tuple. ``report`` is the HTML string, or
        ``str(raw_output)`` when the output cannot be parsed or rendering
        fails for any reason.
    """
    try:
        analysis_dict = _coerce_analysis_dict(raw_output)
        if analysis_dict is None:
            return str(raw_output), visuals  # Return raw output as string

        report = f"""
        <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
            <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">🔍 Key Observations</h2>
                {format_observations(analysis_dict.get('observations', {}))}
            </div>
            <div style="margin-top: 30px;">
                <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
                {format_insights(analysis_dict.get('insights', {}), visuals)}
            </div>
        </div>
        """
        return report, visuals
    except Exception as e:
        # Deliberate catch-all at the UI boundary: a malformed agent answer
        # must degrade to showing the raw text instead of crashing the app.
        print(f"Error in format_analysis_report: {e}")
        return str(raw_output), visuals
74
+
75
def format_observations(observations):
    """Render observation entries as HTML cards.

    Args:
        observations: Mapping of observation key to value.

    Returns:
        The cards joined by newlines, or an empty string when nothing
        matches. Keys have underscores replaced by spaces and are
        title-cased for the heading; values are shown in a ``<pre>`` block.
        Both are HTML-escaped because they originate from agent output.
    """
    import html  # local import: this block cannot rely on the file's import list

    cards = []
    for key, value in observations.items():
        label = str(key)  # tolerate non-string keys instead of raising
        # NOTE(review): only keys containing 'proportions' are rendered, so
        # every other observation is dropped — confirm this filter is intended.
        if 'proportions' not in label:
            continue
        title = html.escape(label.replace('_', ' ').title())
        body = html.escape(str(value))
        cards.append(f"""
        <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <h3 style="margin: 0 0 10px 0; color: #4A708B;">{title}</h3>
            <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{body}</pre>
        </div>
        """)
    return '\n'.join(cards)
84
+
85
def format_insights(insights, visuals):
    """Render insights as numbered HTML cards, each with an optional figure.

    Args:
        insights: Mapping of insight key to insight text (the keys are not
            displayed). A list/tuple of insight texts, a single string, or
            ``None`` is also accepted.
        visuals: List of figure file paths; the i-th insight is paired with
            the i-th path when one exists. ``None`` is treated as empty.

    Returns:
        The cards joined by newlines, or an empty string when there are no
        insights. Insight text and image paths are HTML-escaped because they
        originate from agent output and file names.
    """
    import html  # local import: this block cannot rely on the file's import list

    if insights is None:
        texts = []
    elif isinstance(insights, dict):
        texts = list(insights.values())
    elif isinstance(insights, str):
        texts = [insights]
    else:
        texts = list(insights)
    figures = visuals or []

    cards = []
    for idx, insight in enumerate(texts):
        image_tag = ''
        if idx < len(figures):
            src = html.escape(str(figures[idx]), quote=True)
            image_tag = (
                f'<img src="/file={src}" style="max-width: 100%; height: auto; '
                'margin-top: 10px; border-radius: 6px; '
                'box-shadow: 0 2px 4px rgba(0,0,0,0.1);">'
            )
        cards.append(f"""
        <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <div style="display: flex; align-items: center; gap: 10px;">
                <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
                <p style="margin: 0; font-size: 16px;">{html.escape(str(insight))}</p>
            </div>
            {image_tag}
        </div>
        """)
    return '\n'.join(cards)
97
+
98
def analyze_data(csv_file, additional_notes=""):
    """Run the CodeAgent analysis on an uploaded file and build the report.

    Clears and recreates ``./figures``, starts a Weights & Biases run in the
    ``huggingface-data-analysis`` project, asks the agent for an analysis
    dict, logs execution time, memory delta and every generated figure to
    the run, and formats the result with ``format_analysis_report``.

    Args:
        csv_file: The uploaded file; either an object exposing ``.name``
            (its path) or a plain path string. May be ``None``.
        additional_notes: Free-text notes forwarded to the agent and stored
            in the run config.

    Returns:
        Whatever ``format_analysis_report`` returns for the agent output and
        the list of figure paths found in ``./figures``.
    """
    start_time = time.time()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 ** 2

    # Start from an empty directory so figures from earlier runs are not reported.
    if os.path.exists('./figures'):
        shutil.rmtree('./figures')
    os.makedirs('./figures', exist_ok=True)

    # Accept both upload objects (with .name) and plain path strings, and hand
    # the agent the path rather than the wrapper object.
    source_path = getattr(csv_file, "name", csv_file) if csv_file else None

    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-data-analysis", config={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "additional_notes": additional_notes,
        "source_file": source_path
    })

    try:
        agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn"])
        analysis_result = agent.run("""
    You are an expert data analyst. Perform comprehensive analysis including:
    1. Basic statistics and data quality checks
    2. 3 insightful analytical questions about relationships in the data
    3. Visualization of key patterns and correlations
    4. Actionable real-world insights derived from findings.
    Generate publication-quality visualizations and save to './figures/'.
    Return the analysis results as a python dictionary that can be parsed by ast.literal_eval().
    The dictionary should have the following structure:
    {
    'observations': {
    'observation_1_key': 'observation_1_value',
    'observation_2_key': 'observation_2_value',
    ...
    },
    'insights': {
    'insight_1_key': 'insight_1_value',
    'insight_2_key': 'insight_2_value',
    ...
    }
    }
    """, additional_args={"additional_notes": additional_notes, "source_file": source_path})

        execution_time = time.time() - start_time
        final_memory = process.memory_info().rss / 1024 ** 2
        memory_usage = final_memory - initial_memory
        wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})

        # Sorted so the insight-to-figure pairing is stable; os.listdir order is arbitrary.
        visuals = sorted(
            os.path.join('./figures', f)
            for f in os.listdir('./figures')
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        for viz in visuals:
            wandb.log({os.path.basename(viz): wandb.Image(viz)})
    finally:
        # Always close the run, even when the agent or logging raises.
        run.finish()

    return format_analysis_report(analysis_result, visuals)
149
 
150
 
151