pavanmutha commited on
Commit
dd7e543
·
verified ·
1 Parent(s): f4d8cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -180
app.py CHANGED
@@ -7,38 +7,26 @@ import wandb
7
  import time
8
  import psutil
9
  import optuna
 
10
 
11
- from huggingface_hub import login
12
-
13
- # Add this before model initialization
14
  hf_token = os.getenv("HF_TOKEN")
15
  login(token=hf_token, add_to_git_credential=True)
16
 
17
- # Then create model with explicit token
18
- model = HfApiModel(
19
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
20
- token=hf_token # Pass token explicitly
21
- )
22
 
23
- # Add this formatting function
24
  def format_analysis_report(raw_output, visuals):
25
  try:
26
- # Check if raw_output is already a dictionary
27
- if isinstance(raw_output, dict):
28
- analysis_dict = raw_output
29
- else:
30
- # Attempt to convert string output to dictionary
31
- analysis_dict = ast.literal_eval(str(raw_output))
32
 
33
  report = f"""
34
  <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
35
  <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
36
-
37
  <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
38
  <h2 style="color: #2B547E;">🔍 Key Observations</h2>
39
  {format_observations(analysis_dict.get('observations', {}))}
40
  </div>
41
-
42
  <div style="margin-top: 30px;">
43
  <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
44
  {format_insights(analysis_dict.get('insights', {}), visuals)}
@@ -50,212 +38,92 @@ def format_analysis_report(raw_output, visuals):
50
  return raw_output, visuals
51
 
52
  def format_observations(observations):
53
- items = []
54
- for key, value in observations.items():
55
- if 'proportions' in key:
56
- items.append(f"""
57
- <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
58
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
59
- <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
60
- </div>
61
- """)
62
- return '\n'.join(items)
63
 
64
  def format_insights(insights, visuals):
65
- items = []
66
- for idx, (key, insight) in enumerate(insights.items()):
67
- img_tag = ""
68
- if idx < len(visuals):
69
- img_tag = f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">'
70
-
71
- items.append(f"""
72
  <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
73
  <div style="display: flex; align-items: center; gap: 10px;">
74
  <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
75
  <p style="margin: 0; font-size: 16px;">{insight}</p>
76
  </div>
77
- {img_tag}
78
  </div>
79
- """)
80
- return '\n'.join(items)
81
 
82
  def analyze_data(csv_file, additional_notes=""):
83
-
84
-
85
- # Start timing
86
  start_time = time.time()
87
-
88
- # Get initial memory usage
89
  process = psutil.Process(os.getpid())
90
- initial_memory = process.memory_info().rss / 1024 ** 2 # Convert to MB
91
 
92
- # Clear previous figures
93
  if os.path.exists('./figures'):
94
  shutil.rmtree('./figures')
95
  os.makedirs('./figures', exist_ok=True)
96
-
97
- # 🚨 Initialize W&B run
98
  wandb.login(key=os.environ.get('WANDB_API_KEY'))
99
- run = wandb.init(
100
- project="huggingface-data-analysis",
101
- config={
102
- "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
103
- "additional_notes": additional_notes,
104
- "source_file": csv_file.name if csv_file else None
105
- }
106
- )
107
-
108
- # Initialize model and agent
109
- model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1") # Best overall
110
- agent = CodeAgent(
111
- tools=[],
112
- model=model,
113
- additional_authorized_imports=[
114
- "numpy",
115
- "pandas",
116
- "matplotlib.pyplot",
117
- "seaborn"
118
- ],
119
- )
120
- # Run analysis
121
- analysis_result = agent.run(
122
- """You are an expert data analyst. Perform comprehensive analysis including:
123
  1. Basic statistics and data quality checks
124
  2. 3 insightful analytical questions about relationships in the data
125
  3. Visualization of key patterns and correlations
126
  4. Actionable real-world insights derived from findings
127
-
128
  Generate publication-quality visualizations and save to './figures/'
129
- """,
130
- additional_args={
131
- "additional_notes": additional_notes,
132
- "source_file": csv_file
133
- }
134
- )
 
 
 
 
 
 
 
135
 
136
- analysis_result = agent.run(
137
- f"""Perform comprehensive analysis with:
138
- - Learning Rate: {learning_rate}
139
- - Batch Size: {batch_size}
140
- - Epochs: {num_epochs}
141
- """,
142
- additional_args={}
143
- )
144
  def objective(trial):
145
  learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 5e-3)
146
  batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
147
  num_epochs = trial.suggest_int("num_epochs", 1, 5)
148
-
 
149
  def tune_hyperparameters(n_trials: int):
150
  study = optuna.create_study(direction="minimize")
151
  study.optimize(objective, n_trials=n_trials)
152
-
153
  return f"Best Hyperparameters: {study.best_params}"
154
 
155
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
156
  gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
157
-
158
- with gr.Row():
159
  with gr.Column():
160
  file_input = gr.File(label="Upload CSV Dataset", type="filepath")
161
  notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
162
  analyze_btn = gr.Button("Analyze", variant="primary")
163
  optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
164
  tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
165
-
166
  with gr.Column():
167
  analysis_output = gr.Markdown("### Analysis results will appear here...")
168
  optuna_output = gr.Textbox(label="Best Hyperparameters")
169
  gallery = gr.Gallery(label="Data Visualizations", columns=2)
170
-
171
- analyze_btn.click(
172
- fn=analyze_data,
173
- inputs=[file_input, notes_input],
174
- outputs=[analysis_output, gallery]
175
- )
176
-
177
- tune_btn.click(
178
- fn=tune_hyperparameters,
179
- inputs=[optuna_trials],
180
- outputs=[optuna_output]
181
- )
182
-
183
- demo.launch(debug=True)
184
-
185
-
186
- # Assume we minimize some loss value (replace with actual metric)
187
- loss = analysis_result.get("loss", 0.1) # Mock value
188
- return loss # Optuna minimizes the loss
189
- # Measure execution time and memory usage
190
- execution_time = time.time() - start_time
191
- final_memory = process.memory_info().rss / 1024 ** 2 # Convert to MB
192
- memory_usage = final_memory - initial_memory # Calculate memory consumed
193
-
194
- # 🚨 Log Performance Metrics
195
- wandb.log({
196
- "execution_time_sec": execution_time,
197
- "memory_usage_mb": memory_usage
198
- })
199
-
200
- # Log analysis results to W&B
201
- if isinstance(analysis_result, dict):
202
- wandb.log({
203
- "observations": analysis_result.get('observations', {}),
204
- "insights": analysis_result.get('insights', {}),
205
- "num_visuals": len(os.listdir('./figures')) # Ensure visuals are counted
206
- })
207
-
208
- # Log generated visualizations
209
- visuals = [os.path.join('./figures', f) for f in os.listdir('./figures')
210
- if f.endswith(('.png', '.jpg', '.jpeg'))]
211
- for viz in visuals:
212
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
213
-
214
-
215
- # 🚨 Log visualizations to W&B
216
- if visuals:
217
- try:
218
- for viz in visuals:
219
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
220
- except Exception as e:
221
- print(f"Error logging visuals: {e}")
222
-
223
- # 🚨 Log CSV as artifact
224
- if csv_file:
225
- try:
226
- artifact = wandb.Artifact(
227
- name="source_data",
228
- type="dataset",
229
- description="Uploaded CSV for analysis"
230
- )
231
- artifact.add_file(csv_file.name)
232
- wandb.log_artifact(artifact)
233
- except Exception as e:
234
- print(f"Error logging CSV: {e}")
235
-
236
- # 🚨 Finish W&B run
237
- run.finish()
238
 
239
- return format_analysis_report(analysis_result, visuals)
 
240
 
241
- # Create Gradio interface
242
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
243
- gr.Markdown("## 📊 AI Data Analysis Agent")
244
-
245
- with gr.Row():
246
- with gr.Column():
247
- file_input = gr.File(label="Upload CSV Dataset", type="filepath")
248
- notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
249
- analyze_btn = gr.Button("Analyze", variant="primary")
250
-
251
- with gr.Column():
252
- analysis_output = gr.Markdown("### Analysis results will appear here...")
253
- gallery = gr.Gallery(label="Data Visualizations", columns=2)
254
-
255
- analyze_btn.click(
256
- fn=analyze_data,
257
- inputs=[file_input, notes_input],
258
- outputs=[analysis_output, gallery]
259
- )
260
-
261
- demo.launch(debug=True)
 
7
  import time
8
  import psutil
9
  import optuna
10
+ import ast
11
 
12
+ # Authenticate Hugging Face
 
 
13
  hf_token = os.getenv("HF_TOKEN")
14
  login(token=hf_token, add_to_git_credential=True)
15
 
16
+ # Initialize Model
17
+ model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
 
 
 
18
 
 
19
  def format_analysis_report(raw_output, visuals):
20
  try:
21
+ analysis_dict = raw_output if isinstance(raw_output, dict) else ast.literal_eval(str(raw_output))
 
 
 
 
 
22
 
23
  report = f"""
24
  <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
25
  <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
 
26
  <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
27
  <h2 style="color: #2B547E;">🔍 Key Observations</h2>
28
  {format_observations(analysis_dict.get('observations', {}))}
29
  </div>
 
30
  <div style="margin-top: 30px;">
31
  <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
32
  {format_insights(analysis_dict.get('insights', {}), visuals)}
 
38
  return raw_output, visuals
39
 
40
  def format_observations(observations):
41
+ return '\n'.join([
42
+ f"""
43
+ <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
44
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
45
+ <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
46
+ </div>
47
+ """ for key, value in observations.items() if 'proportions' in key
48
+ ])
 
 
49
 
50
  def format_insights(insights, visuals):
51
+ return '\n'.join([
52
+ f"""
 
 
 
 
 
53
  <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
54
  <div style="display: flex; align-items: center; gap: 10px;">
55
  <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
56
  <p style="margin: 0; font-size: 16px;">{insight}</p>
57
  </div>
58
+ {f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if idx < len(visuals) else ''}
59
  </div>
60
+ """ for idx, (key, insight) in enumerate(insights.items())
61
+ ])
62
 
63
  def analyze_data(csv_file, additional_notes=""):
 
 
 
64
  start_time = time.time()
 
 
65
  process = psutil.Process(os.getpid())
66
+ initial_memory = process.memory_info().rss / 1024 ** 2
67
 
 
68
  if os.path.exists('./figures'):
69
  shutil.rmtree('./figures')
70
  os.makedirs('./figures', exist_ok=True)
71
+
 
72
  wandb.login(key=os.environ.get('WANDB_API_KEY'))
73
+ run = wandb.init(project="huggingface-data-analysis", config={
74
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
75
+ "additional_notes": additional_notes,
76
+ "source_file": csv_file.name if csv_file else None
77
+ })
78
+
79
+ agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
80
+ analysis_result = agent.run("""
81
+ You are an expert data analyst. Perform comprehensive analysis including:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  1. Basic statistics and data quality checks
83
  2. 3 insightful analytical questions about relationships in the data
84
  3. Visualization of key patterns and correlations
85
  4. Actionable real-world insights derived from findings
 
86
  Generate publication-quality visualizations and save to './figures/'
87
+ """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
88
+
89
+ execution_time = time.time() - start_time
90
+ final_memory = process.memory_info().rss / 1024 ** 2
91
+ memory_usage = final_memory - initial_memory
92
+ wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
93
+
94
+ visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
95
+ for viz in visuals:
96
+ wandb.log({os.path.basename(viz): wandb.Image(viz)})
97
+
98
+ run.finish()
99
+ return format_analysis_report(analysis_result, visuals)
100
 
 
 
 
 
 
 
 
 
101
  def objective(trial):
102
  learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 5e-3)
103
  batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
104
  num_epochs = trial.suggest_int("num_epochs", 1, 5)
105
+ return learning_rate * batch_size * num_epochs
106
+
107
  def tune_hyperparameters(n_trials: int):
108
  study = optuna.create_study(direction="minimize")
109
  study.optimize(objective, n_trials=n_trials)
 
110
  return f"Best Hyperparameters: {study.best_params}"
111
 
112
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
  gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
114
+ with gr.Row():
 
115
  with gr.Column():
116
  file_input = gr.File(label="Upload CSV Dataset", type="filepath")
117
  notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
118
  analyze_btn = gr.Button("Analyze", variant="primary")
119
  optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
120
  tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
 
121
  with gr.Column():
122
  analysis_output = gr.Markdown("### Analysis results will appear here...")
123
  optuna_output = gr.Textbox(label="Best Hyperparameters")
124
  gallery = gr.Gallery(label="Data Visualizations", columns=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ analyze_btn.click(fn=analyze_data, inputs=[file_input, notes_input], outputs=[analysis_output, gallery])
127
+ tune_btn.click(fn=tune_hyperparameters, inputs=[optuna_trials], outputs=[optuna_output])
128
 
129
+ demo.launch(debug=True)