File size: 3,620 Bytes
fa2cb8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from crewai.tools import BaseTool
from pydantic import BaseModel
from typing import Optional
from google import genai
from google.genai.types import Part, GenerateContentConfig
import os
import textwrap

# Argument schema for ShapVisionTool (assigned to its ``args_schema`` attribute).
# The three fields mirror the parameters of ``ShapVisionTool._run``.
class ShapVisionToolSchema(BaseModel):
    # Name of the quantity the explained model predicts; it is interpolated
    # into the tool's default prompt.
    target_variable: str
    # Path of the SHAP summary plot image on disk. The tool accepts only
    # .png, .jpg and .jpeg extensions and raises ValueError otherwise.
    image_path: str
    # Optional custom prompt; when left as None the tool builds its own
    # default analysis prompt.
    prompt: Optional[str] = None

# Default analysis prompt. It is dedented once at import time and filled in
# with str.format, so the substituted value can never influence the dedent
# step (a multi-line value interpolated *before* dedent would defeat it).
_DEFAULT_PROMPT_TEMPLATE = textwrap.dedent("""
    You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.

    Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:

    1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
    2. **The most important features** in determining {target_variable}, based on the plot.
    3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
    4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
    5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
    6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
    7. **Caveats and limitations** in interpreting SHAP summary plots.

    Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it. Write the output as if preparing it for a technical report or presentation.
""")

# System instruction sent with every request.
_SYSTEM_PROMPT = (
    "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
    "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
)

# Lower-cased file extension -> MIME type passed to Part.from_bytes.
_MIME_TYPE_BY_EXTENSION = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
}


class ShapVisionTool(BaseTool):
    """Tool that asks Gemini to explain a SHAP summary plot image.

    The Gemini API key is read from ``metadata["GEMINI_API_KEY"]``. The image
    is read from disk and sent, together with a prompt, to the
    ``gemini-2.5-flash`` model; the model's text answer is returned.
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    metadata: dict = {}

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Send the SHAP plot and a prompt to Gemini and return its explanation.

        Args:
            target_variable: Name of the predicted quantity; surrounding
                whitespace is stripped before it is placed in the default
                prompt.
            image_path: Path to the plot image (.png, .jpg or .jpeg).
            prompt: Custom prompt. When ``None`` or blank, the default
                analysis prompt for ``target_variable`` is used.

        Returns:
            The model's text response with surrounding whitespace removed.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is missing from ``metadata`` or
                the image extension is not supported.
            OSError: If the image file cannot be opened or read.
            RuntimeError: If the model response contains no text.
        """
        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")

        # A blank prompt (e.g. "" supplied for the optional argument) would
        # reach the model as an empty instruction, so treat it like None.
        if prompt is None or not prompt.strip():
            prompt = _DEFAULT_PROMPT_TEMPLATE.format(
                target_variable=target_variable.strip()
            )

        ext = os.path.splitext(image_path)[-1].lower()
        mime_type = _MIME_TYPE_BY_EXTENSION.get(ext)
        if mime_type is None:
            raise ValueError(f"Unsupported image type: {ext}")

        with open(image_path, "rb") as f:
            image_bytes = f.read()

        # Build the client only after the local inputs have been validated.
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[
                Part.from_bytes(data=image_bytes, mime_type=mime_type),
                prompt,
            ],
            config=GenerateContentConfig(
                system_instruction=_SYSTEM_PROMPT,
                temperature=0.4,
                # NOTE(review): with a thinking model the output budget may be
                # partly consumed by reasoning tokens - confirm 2048 is enough.
                max_output_tokens=2048,
            ),
        )

        text = response.text
        if text is None:
            # No text parts came back (for instance a blocked or truncated
            # generation). Report the finish reason instead of failing with
            # an AttributeError on None.
            candidates = getattr(response, "candidates", None) or []
            finish_reason = (
                getattr(candidates[0], "finish_reason", None) if candidates else None
            )
            raise RuntimeError(
                "Gemini returned no text for the SHAP summary plot "
                f"(finish_reason={finish_reason})."
            )

        return text.strip()