File size: 2,981 Bytes
8ebabee
c23a150
 
 
 
 
 
 
 
 
 
 
 
 
 
808f1c2
c23a150
 
f057e48
e89bea7
c23a150
 
e89bea7
f057e48
e89bea7
f057e48
 
 
 
e89bea7
f057e48
e89bea7
f057e48
 
 
 
 
c23a150
 
 
e89bea7
c23a150
 
37a54d8
 
 
c23a150
37a54d8
 
c23a150
37a54d8
 
c23a150
37a54d8
 
c23a150
e89bea7
7a5fdf7
f057e48
7a5fdf7
68dc9ea
7a5fdf7
9478767
808f1c2
 
 
 
 
 
f057e48
 
 
da87d84
f057e48
808f1c2
f057e48
808f1c2
f057e48
808f1c2
f057e48
808f1c2
f057e48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import re

import pandas as pd
from huggingface_hub import InferenceClient

def summarize_dataframe(df: pd.DataFrame, max_rows: int = 30) -> str:
    """Render a compact text summary of *df* for inclusion in a prompt.

    The summary lists the column labels, followed by either the whole frame
    (when it has at most ``max_rows`` rows) or a random sample of
    ``max_rows`` rows.

    Args:
        df: The frame to summarize.
        max_rows: Maximum number of rows to include in the rendered table.

    Returns:
        A multi-line string: a ``Columns: ...`` line, a blank line, a line
        stating whether a sample or all rows are shown, and the table text
        produced by ``DataFrame.to_string(index=False)``.
    """
    # Column labels are not guaranteed to be strings (e.g. integer labels),
    # and str.join rejects non-strings, so convert each label explicitly.
    column_names = ", ".join(map(str, df.columns))
    summary = f"Columns: {column_names}\n\n"
    if len(df) > max_rows:
        # Fixed seed so the same frame always yields the same sample.
        sample = df.sample(max_rows, random_state=42)
        summary += "Showing a random sample of rows:\n"
    else:
        sample = df
        summary += "Showing all rows:\n"
    summary += sample.to_string(index=False)
    return summary


def _columns_mentioned(df: pd.DataFrame, query_lower: str) -> list:
    """Return the column labels of *df* that appear as whole words in the query."""
    mentioned = []
    for col in df.columns:
        # Labels may be non-strings (e.g. ints); str() keeps matching from
        # raising and aborting the direct analysis for the whole frame.
        name = str(col).lower()
        if not name:
            # An empty label would trivially match every query.
            continue
        # Require a non-word character (or the string edge) on both sides so
        # that a column such as "a" or "age" is not matched inside "average".
        if re.search(rf"(?<!\w){re.escape(name)}(?!\w)", query_lower):
            mentioned.append(col)
    return mentioned


def _extract_generated_text(response) -> str:
    """Pull the generated text out of the response shapes handled here."""
    if isinstance(response, str):
        return response
    if isinstance(response, dict) and "generated_text" in response:
        return response["generated_text"]
    if isinstance(response, list) and response:
        first = response[0]
        # Only index by key when the first element really is a mapping.
        if isinstance(first, dict) and "generated_text" in first:
            return first["generated_text"]
    return str(response)


def query_agent(df: pd.DataFrame, query: str) -> str:
    """Answer a natural-language question about *df*.

    "Most common" / "most frequent" questions that name one or more columns
    are answered directly with pandas. Every other question, and any direct
    analysis that fails or matches no column, is sent to a hosted text
    generation model together with a text summary of the frame.

    Args:
        df: The dataset the question is about.
        query: The user's question.

    Returns:
        The answer text. If the model call raises, a fixed apology message is
        returned instead of propagating the exception.
    """
    query_lower = query.lower()

    # ----------------- Direct Analysis for Most Common -----------------
    if "most common" in query_lower or "most frequent" in query_lower:
        try:
            cols_in_query = _columns_mentioned(df, query_lower)

            if len(cols_in_query) == 1:
                col = cols_in_query[0]
                value = df[col].mode()[0]
                return f"The most common value in column '{col}' is '{value}'."

            if len(cols_in_query) > 1:
                # Compute most common combination of values across the columns
                combo_series = df[cols_in_query].apply(lambda row: tuple(row), axis=1)
                most_common_combo = combo_series.mode()[0]
                combo_str = ", ".join(
                    f"{col}={val}" for col, val in zip(cols_in_query, most_common_combo)
                )
                return f"The most common combination of values is: {combo_str}"
        except Exception as e:
            # Best effort: any failure here (e.g. a column with no mode)
            # falls through to the model instead of surfacing to the caller.
            print("Direct analysis failed:", e)

    # ----------------- Use LLM if direct analysis fails -----------------
    data_text = summarize_dataframe(df)
    prompt = f"""
You are a data analysis assistant with expertise in statistics and data interpretation.
Analyze the dataset sample below and answer the user's question in a clear, detailed, and well-explained way.
Include both the direct answer and a short explanation or reasoning behind it.

Dataset Summary:
{data_text}

Question:
{query}

Answer (with explanation):
"""

    # Initialize client with explicit provider
    client = InferenceClient(
        model="google/gemma-2b-it",
        provider="hf-inference",
        token=os.environ.get("HUGGINGFACE_API_TOKEN"),
    )

    try:
        response = client.text_generation(
            prompt,
            max_new_tokens=1024,
            temperature=0.7,
        )
    except Exception as e:
        print("Model call failed:", e)
        return "⚠️ Sorry, the model could not generate an answer. Please try again."

    # Extract text safely
    return _extract_generated_text(response)