Update app.py
Browse files
app.py
CHANGED
|
@@ -31,8 +31,10 @@ login(token=hf_token)
|
|
| 31 |
# SmolAgent initialization
|
| 32 |
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
| 33 |
|
|
|
|
| 34 |
df_global = None
|
| 35 |
target_column_global = None
|
|
|
|
| 36 |
|
| 37 |
def clean_data(df):
|
| 38 |
df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
|
|
@@ -42,24 +44,44 @@ def clean_data(df):
|
|
| 42 |
df = df.fillna(df.mean(numeric_only=True))
|
| 43 |
return df
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and cache it in ``df_global``.

    Parameters
    ----------
    file : gradio file object or None
        The uploaded file; ``file.name`` is its path on disk.

    Returns
    -------
    tuple
        (preview DataFrame — an error frame when nothing was uploaded,
        ``gr.update`` carrying the column names for the target dropdown).
    """
    global df_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])

    # Lower-case the extension so ".CSV"/".Csv" files are parsed as CSV
    # instead of being mistakenly routed to the Excel reader (bug fix).
    ext = os.path.splitext(file.name)[-1].lower()
    df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
    df = clean_data(df)
    df_global = df
    return df.head(), gr.update(choices=df.columns.tolist())
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
def set_target_column(col_name):
    """Remember *col_name* as the globally selected target column.

    Returns a human-readable confirmation string for the UI.
    """
    global target_column_global
    target_column_global = col_name
    confirmation = f"✅ Target column set to: {col_name}"
    return confirmation
|
| 61 |
|
| 62 |
|
|
|
|
| 63 |
def format_analysis_report(raw_output, visuals):
|
| 64 |
import json
|
| 65 |
|
|
@@ -195,7 +217,7 @@ def extract_json_from_codeagent_output(raw_output):
|
|
| 195 |
|
| 196 |
|
| 197 |
|
| 198 |
-
def analyze_data(csv_file, additional_notes=""):
|
| 199 |
try:
|
| 200 |
start_time = time.time()
|
| 201 |
process = psutil.Process(os.getpid())
|
|
@@ -211,7 +233,7 @@ def analyze_data(csv_file, additional_notes=""):
|
|
| 211 |
run = wandb.init(project="huggingface-data-analysis", config={
|
| 212 |
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
| 213 |
"additional_notes": additional_notes,
|
| 214 |
-
"source_file": csv_file.name if csv_file else
|
| 215 |
})
|
| 216 |
|
| 217 |
# Initialize Code Agent
|
|
@@ -221,8 +243,23 @@ def analyze_data(csv_file, additional_notes=""):
|
|
| 221 |
additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"]
|
| 222 |
)
|
| 223 |
|
| 224 |
-
#
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
You are a helpful data analysis agent. Please follow these very strict instructions and formatting:
|
| 227 |
|
| 228 |
1. Load the data from the provided `source_file`.
|
|
@@ -259,10 +296,11 @@ Be concise and avoid any narrative outside this final dictionary.
|
|
| 259 |
Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allowed)
|
| 260 |
"""
|
| 261 |
|
| 262 |
-
# Run
|
| 263 |
-
analysis_result = agent.run(
|
| 264 |
"additional_notes": additional_notes,
|
| 265 |
-
"source_file": csv_file.name if csv_file else None
|
|
|
|
| 266 |
})
|
| 267 |
|
| 268 |
# Performance metrics
|
|
@@ -281,25 +319,19 @@ Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allo
|
|
| 281 |
if f.endswith(('.png', '.jpg', '.jpeg'))
|
| 282 |
])
|
| 283 |
|
| 284 |
-
# Log visuals to WandB
|
| 285 |
for viz in visuals:
|
| 286 |
wandb.log({os.path.basename(viz): wandb.Image(viz)})
|
| 287 |
|
| 288 |
run.finish()
|
| 289 |
-
|
| 290 |
print("DEBUG - Raw agent output:", analysis_result[:500] + "...")
|
| 291 |
-
print("Columns in data:", df_global.columns.tolist())
|
| 292 |
-
print("Data types:", df_global.dtypes)
|
| 293 |
with open("agent_output.txt", "w") as f:
|
| 294 |
f.write(str(analysis_result))
|
| 295 |
-
# Parse the agent output
|
| 296 |
-
parsed_result = extract_json_from_codeagent_output(analysis_result)
|
| 297 |
-
print(f"DEBUG - Parsed result: {parsed_result}") # Debug output
|
| 298 |
|
|
|
|
| 299 |
if parsed_result:
|
| 300 |
return format_analysis_report(parsed_result, visuals)
|
| 301 |
else:
|
| 302 |
-
# Fallback to showing raw output if parsing fails
|
| 303 |
error_msg = f"Failed to parse agent output. Showing raw response:\n{str(analysis_result)[:2000]}"
|
| 304 |
print(error_msg)
|
| 305 |
return f"<pre>{error_msg}</pre>", visuals
|
|
@@ -309,7 +341,7 @@ Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allo
|
|
| 309 |
print(error_msg)
|
| 310 |
return f"<pre>{error_msg}</pre>", []
|
| 311 |
|
| 312 |
-
|
| 313 |
def compare_models():
|
| 314 |
import seaborn as sns
|
| 315 |
from sklearn.model_selection import cross_val_predict
|
|
|
|
| 31 |
# SmolAgent initialization
|
| 32 |
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
| 33 |
|
| 34 |
+
# Globals
|
| 35 |
df_global = None
|
| 36 |
target_column_global = None
|
| 37 |
+
data_summary_global = None # ⬅️ Added for summarized data
|
| 38 |
|
| 39 |
def clean_data(df):
|
| 40 |
df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
|
|
|
|
| 44 |
df = df.fillna(df.mean(numeric_only=True))
|
| 45 |
return df
|
| 46 |
|
| 47 |
+
def summarize_data(df: pd.DataFrame, max_cols: int = 10, max_rows: int = 5) -> str:
    """Build a compact, human-readable text summary of *df*.

    The summary lists the dataset shape, column dtypes, per-column missing
    value counts, the describe() statistics for the first ``max_rows``
    numeric columns (``describe().T`` puts one column per row), and the top
    ``max_rows`` value counts for up to ``max_cols`` non-numeric columns.

    Returns the summary as a single newline-joined string.
    """
    parts = [f"Dataset shape: {df.shape}", "\nColumn types:\n" + str(df.dtypes)]

    numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
    non_numeric_columns = df.select_dtypes(exclude=[np.number]).columns.tolist()

    parts.append("\nMissing values:\n" + str(df.isnull().sum()))

    if numeric_columns:
        described = df[numeric_columns].describe().T.head(max_rows)
        parts.append("\nNumerical summary:\n" + str(described))
    if non_numeric_columns:
        parts.append("\nCategorical value counts (top categories):")
        for column in non_numeric_columns[:max_cols]:
            top_values = df[column].value_counts().head(max_rows)
            parts.append(f"\nColumn: {column}\n{top_values}")

    return "\n".join(parts)
|
| 65 |
+
|
| 66 |
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and cache data globally.

    Stores the cleaned frame in ``df_global`` and a text summary (via
    ``summarize_data``) in ``data_summary_global`` for later prompt building.

    Parameters
    ----------
    file : gradio file object or None
        The uploaded file; ``file.name`` is its path on disk.

    Returns
    -------
    tuple
        (preview DataFrame — an error frame when nothing was uploaded,
        ``gr.update`` carrying the column names for the target dropdown).
    """
    global df_global, data_summary_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])

    # Lower-case the extension so ".CSV"/".Csv" files are parsed as CSV
    # instead of being mistakenly routed to the Excel reader (bug fix).
    ext = os.path.splitext(file.name)[-1].lower()
    df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
    df = clean_data(df)
    df_global = df
    data_summary_global = summarize_data(df)  # cache summary for the agent prompt
    return df.head(), gr.update(choices=df.columns.tolist())
|
| 77 |
|
|
|
|
|
|
|
| 78 |
def set_target_column(col_name):
    """Store the chosen target column in the module-level global.

    Returns a confirmation message suitable for display in the UI.
    """
    global target_column_global
    target_column_global = col_name
    return "✅ Target column set to: {}".format(col_name)
|
| 82 |
|
| 83 |
|
| 84 |
+
|
| 85 |
def format_analysis_report(raw_output, visuals):
|
| 86 |
import json
|
| 87 |
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
|
| 220 |
+
def analyze_data(csv_file=None, additional_notes="", use_summary=True):
|
| 221 |
try:
|
| 222 |
start_time = time.time()
|
| 223 |
process = psutil.Process(os.getpid())
|
|
|
|
| 233 |
run = wandb.init(project="huggingface-data-analysis", config={
|
| 234 |
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
| 235 |
"additional_notes": additional_notes,
|
| 236 |
+
"source_file": csv_file.name if csv_file else "summarized_input"
|
| 237 |
})
|
| 238 |
|
| 239 |
# Initialize Code Agent
|
|
|
|
| 243 |
additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"]
|
| 244 |
)
|
| 245 |
|
| 246 |
+
# Choose prompt content
|
| 247 |
+
if use_summary and data_summary_global:
|
| 248 |
+
input_data = data_summary_global
|
| 249 |
+
data_instruction = """
|
| 250 |
+
You are analyzing summarized dataset information from a CSV file. Your job is to:
|
| 251 |
+
|
| 252 |
+
1. Interpret the summary content as if it was produced from a real dataset.
|
| 253 |
+
2. Derive at least 5 high-level insights based on column types, distributions, missing values, etc.
|
| 254 |
+
3. Imagine or mock visualizations and describe what they would show. Use synthetic data simulation with numpy/pandas if needed.
|
| 255 |
+
4. Save plots to './figures/' using matplotlib or seaborn.
|
| 256 |
+
|
| 257 |
+
Always respond in the structured dictionary format below.
|
| 258 |
+
"""
|
| 259 |
+
else:
|
| 260 |
+
# Fall back to full file input
|
| 261 |
+
input_data = None # You load file within the agent
|
| 262 |
+
data_instruction = """
|
| 263 |
You are a helpful data analysis agent. Please follow these very strict instructions and formatting:
|
| 264 |
|
| 265 |
1. Load the data from the provided `source_file`.
|
|
|
|
| 296 |
Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allowed)
|
| 297 |
"""
|
| 298 |
|
| 299 |
+
# Run agent with either summarized content or CSV
|
| 300 |
+
analysis_result = agent.run(data_instruction, additional_args={
|
| 301 |
"additional_notes": additional_notes,
|
| 302 |
+
"source_file": csv_file.name if csv_file and not use_summary else None,
|
| 303 |
+
"data_summary": input_data if use_summary else None
|
| 304 |
})
|
| 305 |
|
| 306 |
# Performance metrics
|
|
|
|
| 319 |
if f.endswith(('.png', '.jpg', '.jpeg'))
|
| 320 |
])
|
| 321 |
|
|
|
|
| 322 |
for viz in visuals:
|
| 323 |
wandb.log({os.path.basename(viz): wandb.Image(viz)})
|
| 324 |
|
| 325 |
run.finish()
|
| 326 |
+
|
| 327 |
print("DEBUG - Raw agent output:", analysis_result[:500] + "...")
|
|
|
|
|
|
|
| 328 |
with open("agent_output.txt", "w") as f:
|
| 329 |
f.write(str(analysis_result))
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
+
parsed_result = extract_json_from_codeagent_output(analysis_result)
|
| 332 |
if parsed_result:
|
| 333 |
return format_analysis_report(parsed_result, visuals)
|
| 334 |
else:
|
|
|
|
| 335 |
error_msg = f"Failed to parse agent output. Showing raw response:\n{str(analysis_result)[:2000]}"
|
| 336 |
print(error_msg)
|
| 337 |
return f"<pre>{error_msg}</pre>", visuals
|
|
|
|
| 341 |
print(error_msg)
|
| 342 |
return f"<pre>{error_msg}</pre>", []
|
| 343 |
|
| 344 |
+
|
| 345 |
def compare_models():
|
| 346 |
import seaborn as sns
|
| 347 |
from sklearn.model_selection import cross_val_predict
|