Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -16,15 +16,29 @@ def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
|
|
| 16 |
|
| 17 |
def query_agent(df: pd.DataFrame, query: str) -> str:
|
| 18 |
query_lower = query.lower()
|
|
|
|
|
|
|
| 19 |
try:
|
| 20 |
if "most common" in query_lower or "most frequent" in query_lower:
|
| 21 |
-
for
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
except Exception as e:
|
| 26 |
print("Direct analysis failed:", e)
|
| 27 |
|
|
|
|
| 28 |
data_text = summarize_dataframe(df)
|
| 29 |
prompt = f"""
|
| 30 |
You are a data analysis assistant with expertise in statistics and data interpretation.
|
|
@@ -40,33 +54,29 @@ Question:
|
|
| 40 |
Answer (with explanation):
|
| 41 |
"""
|
| 42 |
|
| 43 |
-
#
|
| 44 |
client = InferenceClient(
|
|
|
|
| 45 |
provider="hf-inference",
|
| 46 |
-
token=os.environ.get("HUGGINGFACE_API_KEY")
|
| 47 |
)
|
| 48 |
|
| 49 |
try:
|
| 50 |
response = client.text_generation(
|
| 51 |
prompt,
|
| 52 |
-
model="google/gemma-2b-it",
|
| 53 |
max_new_tokens=1024,
|
| 54 |
temperature=0.7,
|
| 55 |
)
|
| 56 |
-
except
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
inputs=prompt,
|
| 60 |
-
parameters={"max_new_tokens": 1024, "temperature": 0.7},
|
| 61 |
-
)
|
| 62 |
|
|
|
|
| 63 |
if isinstance(response, str):
|
| 64 |
-
|
| 65 |
elif isinstance(response, dict) and "generated_text" in response:
|
| 66 |
-
|
| 67 |
elif isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
|
| 68 |
-
|
| 69 |
else:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
return answer
|
|
|
|
| 16 |
|
| 17 |
def query_agent(df: pd.DataFrame, query: str) -> str:
|
| 18 |
query_lower = query.lower()
|
| 19 |
+
|
| 20 |
+
# ----------------- Direct Analysis for Most Common -----------------
|
| 21 |
try:
|
| 22 |
if "most common" in query_lower or "most frequent" in query_lower:
|
| 23 |
+
# Look for multiple columns in query
|
| 24 |
+
cols_in_query = [col for col in df.columns if col.lower() in query_lower]
|
| 25 |
+
|
| 26 |
+
if len(cols_in_query) == 1:
|
| 27 |
+
col = cols_in_query[0]
|
| 28 |
+
value = df[col].mode()[0]
|
| 29 |
+
return f"The most common value in column '{col}' is '{value}'."
|
| 30 |
+
|
| 31 |
+
elif len(cols_in_query) > 1:
|
| 32 |
+
# Compute most common combination of values across the columns
|
| 33 |
+
combo_series = df[cols_in_query].apply(lambda row: tuple(row), axis=1)
|
| 34 |
+
most_common_combo = combo_series.mode()[0]
|
| 35 |
+
combo_str = ", ".join(f"{col}={val}" for col, val in zip(cols_in_query, most_common_combo))
|
| 36 |
+
return f"The most common combination of values is: {combo_str}"
|
| 37 |
+
|
| 38 |
except Exception as e:
|
| 39 |
print("Direct analysis failed:", e)
|
| 40 |
|
| 41 |
+
# ----------------- Use LLM if direct analysis fails -----------------
|
| 42 |
data_text = summarize_dataframe(df)
|
| 43 |
prompt = f"""
|
| 44 |
You are a data analysis assistant with expertise in statistics and data interpretation.
|
|
|
|
| 54 |
Answer (with explanation):
|
| 55 |
"""
|
| 56 |
|
| 57 |
+
# Initialize client with explicit provider
|
| 58 |
client = InferenceClient(
|
| 59 |
+
model="google/gemma-2b-it",
|
| 60 |
provider="hf-inference",
|
| 61 |
+
token=os.environ.get("HUGGINGFACE_API_KEY"),
|
| 62 |
)
|
| 63 |
|
| 64 |
try:
|
| 65 |
response = client.text_generation(
|
| 66 |
prompt,
|
|
|
|
| 67 |
max_new_tokens=1024,
|
| 68 |
temperature=0.7,
|
| 69 |
)
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print("Model call failed:", e)
|
| 72 |
+
return "⚠️ Sorry, the model could not generate an answer. Please try again."
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
# Extract text safely
|
| 75 |
if isinstance(response, str):
|
| 76 |
+
return response
|
| 77 |
elif isinstance(response, dict) and "generated_text" in response:
|
| 78 |
+
return response["generated_text"]
|
| 79 |
elif isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
|
| 80 |
+
return response[0]["generated_text"]
|
| 81 |
else:
|
| 82 |
+
return str(response)
|
|
|
|
|
|