yonkoyonks committed on
Commit
f057e48
·
verified ·
1 Parent(s): 7a5fdf7

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +29 -19
utils.py CHANGED
@@ -16,15 +16,29 @@ def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
16
 
17
  def query_agent(df: pd.DataFrame, query: str) -> str:
18
  query_lower = query.lower()
 
 
19
  try:
20
  if "most common" in query_lower or "most frequent" in query_lower:
21
- for col in df.columns:
22
- if col.lower() in query_lower:
23
- value = df[col].mode()[0]
24
- return f"The most common value in column '{col}' is '{value}'."
 
 
 
 
 
 
 
 
 
 
 
25
  except Exception as e:
26
  print("Direct analysis failed:", e)
27
 
 
28
  data_text = summarize_dataframe(df)
29
  prompt = f"""
30
  You are a data analysis assistant with expertise in statistics and data interpretation.
@@ -40,33 +54,29 @@ Question:
40
  Answer (with explanation):
41
  """
42
 
43
- # Explicitly specify provider (avoids StopIteration)
44
  client = InferenceClient(
 
45
  provider="hf-inference",
46
- token=os.environ.get("HUGGINGFACE_API_KEY")
47
  )
48
 
49
  try:
50
  response = client.text_generation(
51
  prompt,
52
- model="google/gemma-2b-it",
53
  max_new_tokens=1024,
54
  temperature=0.7,
55
  )
56
- except TypeError:
57
- response = client.text_generation(
58
- model="google/gemma-2b-it",
59
- inputs=prompt,
60
- parameters={"max_new_tokens": 1024, "temperature": 0.7},
61
- )
62
 
 
63
  if isinstance(response, str):
64
- answer = response
65
  elif isinstance(response, dict) and "generated_text" in response:
66
- answer = response["generated_text"]
67
  elif isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
68
- answer = response[0]["generated_text"]
69
  else:
70
- answer = str(response)
71
-
72
- return answer
 
16
 
17
  def query_agent(df: pd.DataFrame, query: str) -> str:
18
  query_lower = query.lower()
19
+
20
+ # ----------------- Direct Analysis for Most Common -----------------
21
  try:
22
  if "most common" in query_lower or "most frequent" in query_lower:
23
+ # Look for multiple columns in query
24
+ cols_in_query = [col for col in df.columns if col.lower() in query_lower]
25
+
26
+ if len(cols_in_query) == 1:
27
+ col = cols_in_query[0]
28
+ value = df[col].mode()[0]
29
+ return f"The most common value in column '{col}' is '{value}'."
30
+
31
+ elif len(cols_in_query) > 1:
32
+ # Compute most common combination of values across the columns
33
+ combo_series = df[cols_in_query].apply(lambda row: tuple(row), axis=1)
34
+ most_common_combo = combo_series.mode()[0]
35
+ combo_str = ", ".join(f"{col}={val}" for col, val in zip(cols_in_query, most_common_combo))
36
+ return f"The most common combination of values is: {combo_str}"
37
+
38
  except Exception as e:
39
  print("Direct analysis failed:", e)
40
 
41
+ # ----------------- Use LLM if direct analysis fails -----------------
42
  data_text = summarize_dataframe(df)
43
  prompt = f"""
44
  You are a data analysis assistant with expertise in statistics and data interpretation.
 
54
  Answer (with explanation):
55
  """
56
 
57
+ # Initialize client with explicit provider
58
  client = InferenceClient(
59
+ model="google/gemma-2b-it",
60
  provider="hf-inference",
61
+ token=os.environ.get("HUGGINGFACE_API_KEY"),
62
  )
63
 
64
  try:
65
  response = client.text_generation(
66
  prompt,
 
67
  max_new_tokens=1024,
68
  temperature=0.7,
69
  )
70
+ except Exception as e:
71
+ print("Model call failed:", e)
72
+ return "⚠️ Sorry, the model could not generate an answer. Please try again."
 
 
 
73
 
74
+ # Extract text safely
75
  if isinstance(response, str):
76
+ return response
77
  elif isinstance(response, dict) and "generated_text" in response:
78
+ return response["generated_text"]
79
  elif isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
80
+ return response[0]["generated_text"]
81
  else:
82
+ return str(response)