yonkoyonks committed on
Commit
6b78fee
·
verified ·
1 Parent(s): f057e48

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +38 -7
utils.py CHANGED
@@ -2,7 +2,29 @@ from huggingface_hub import InferenceClient
2
  import os
3
  import pandas as pd
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
 
 
 
6
  summary = f"Columns: {', '.join(df.columns)}\n\n"
7
  if len(df) > max_rows:
8
  sample = df.sample(max_rows, random_state=42)
@@ -13,23 +35,32 @@ def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
13
  summary += sample.to_string(index=False)
14
  return summary
15
 
 
16
 
17
  def query_agent(df: pd.DataFrame, query: str) -> str:
 
 
 
 
 
 
 
 
18
  query_lower = query.lower()
19
 
20
- # ----------------- Direct Analysis for Most Common -----------------
21
  try:
22
  if "most common" in query_lower or "most frequent" in query_lower:
23
- # Look for multiple columns in query
24
  cols_in_query = [col for col in df.columns if col.lower() in query_lower]
25
-
26
  if len(cols_in_query) == 1:
27
  col = cols_in_query[0]
28
  value = df[col].mode()[0]
29
  return f"The most common value in column '{col}' is '{value}'."
30
-
31
  elif len(cols_in_query) > 1:
32
- # Compute most common combination of values across the columns
33
  combo_series = df[cols_in_query].apply(lambda row: tuple(row), axis=1)
34
  most_common_combo = combo_series.mode()[0]
35
  combo_str = ", ".join(f"{col}={val}" for col, val in zip(cols_in_query, most_common_combo))
@@ -38,7 +69,7 @@ def query_agent(df: pd.DataFrame, query: str) -> str:
38
  except Exception as e:
39
  print("Direct analysis failed:", e)
40
 
41
- # ----------------- Use LLM if direct analysis fails -----------------
42
  data_text = summarize_dataframe(df)
43
  prompt = f"""
44
  You are a data analysis assistant with expertise in statistics and data interpretation.
@@ -54,7 +85,7 @@ Question:
54
  Answer (with explanation):
55
  """
56
 
57
- # Initialize client with explicit provider
58
  client = InferenceClient(
59
  model="google/gemma-2b-it",
60
  provider="hf-inference",
 
2
  import os
3
  import pandas as pd
4
 
5
+ # ------------------- Helper Functions -------------------
6
+
7
def split_multi_value_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Split columns holding comma-separated strings into separate columns.

    A column is split when any of its values is a string containing a
    comma. The original column is removed and replaced by ``<col>_1``,
    ``<col>_2``, ... appended at the end of the frame. Each piece has its
    surrounding whitespace stripped, so ``"a, b"`` yields ``"a"`` and
    ``"b"``. Rows with fewer pieces than the widest row, and non-string
    values, end up as missing values in the generated columns.

    Args:
        df: Input frame. It is not modified.

    Returns:
        A new DataFrame with the same rows and index as ``df``.

    Raises:
        ValueError: If a generated column name already exists in the frame.
    """
    new_df = df.copy()
    for col in df.columns:
        column = df[col]
        if not isinstance(column, pd.Series):
            # Duplicate column labels make df[col] a DataFrame; leave as-is.
            continue

        # Look at every value, not only the first one: a column whose first
        # entry happens to hold a single item must still be split.
        if not any(isinstance(value, str) and "," in value for value in column):
            continue

        pieces = column.str.split(",", expand=True)
        new_names = [f"{col}_{i + 1}" for i in range(pieces.shape[1])]

        remaining = new_df.columns.drop(col)
        clashes = [name for name in new_names if name in remaining]
        if clashes:
            raise ValueError(
                f"Cannot split column {col!r}: generated column name(s) "
                f"{clashes} already exist in the dataframe."
            )

        new_df = new_df.drop(columns=[col])
        for position, name in enumerate(new_names):
            # Assign column-by-column instead of DataFrame.join: join aligns
            # on index labels and multiplies rows when the index contains
            # duplicates, whereas the pieces share df's index exactly.
            new_df[name] = pieces[position].str.strip()
    return new_df
22
+
23
+
24
  def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
25
+ """
26
+ Returns a text summary of the dataframe for LLM prompts.
27
+ """
28
  summary = f"Columns: {', '.join(df.columns)}\n\n"
29
  if len(df) > max_rows:
30
  sample = df.sample(max_rows, random_state=42)
 
35
  summary += sample.to_string(index=False)
36
  return summary
37
 
38
+ # ------------------- Main Query Function -------------------
39
 
40
  def query_agent(df: pd.DataFrame, query: str) -> str:
41
+ """
42
+ Analyzes a dataframe to answer queries. Supports:
43
+ - Direct analysis of most common values (single or multiple columns)
44
+ - Fallback to LLM using google/gemma-2b-it
45
+ """
46
+ # Automatically split multi-value columns
47
+ df = split_multi_value_columns(df)
48
+
49
  query_lower = query.lower()
50
 
51
+ # -------- Direct Analysis for Most Common Values --------
52
  try:
53
  if "most common" in query_lower or "most frequent" in query_lower:
54
+ # Find all columns mentioned in the query
55
  cols_in_query = [col for col in df.columns if col.lower() in query_lower]
56
+
57
  if len(cols_in_query) == 1:
58
  col = cols_in_query[0]
59
  value = df[col].mode()[0]
60
  return f"The most common value in column '{col}' is '{value}'."
61
+
62
  elif len(cols_in_query) > 1:
63
+ # Most common combination across multiple columns
64
  combo_series = df[cols_in_query].apply(lambda row: tuple(row), axis=1)
65
  most_common_combo = combo_series.mode()[0]
66
  combo_str = ", ".join(f"{col}={val}" for col, val in zip(cols_in_query, most_common_combo))
 
69
  except Exception as e:
70
  print("Direct analysis failed:", e)
71
 
72
+ # -------- Use LLM Fallback if Direct Analysis Fails --------
73
  data_text = summarize_dataframe(df)
74
  prompt = f"""
75
  You are a data analysis assistant with expertise in statistics and data interpretation.
 
85
  Answer (with explanation):
86
  """
87
 
88
+ # Initialize Hugging Face InferenceClient with explicit provider
89
  client = InferenceClient(
90
  model="google/gemma-2b-it",
91
  provider="hf-inference",