# SHAP-plot-interpreter/src/tools/shap_vision_tool.py
# Hugging Face Space file (author: samyakshrestha, commit: "Initial commit", fa2cb8a)
from crewai.tools import BaseTool
from pydantic import BaseModel
from typing import Optional
from google import genai
from google.genai.types import Part, GenerateContentConfig
import os
import textwrap
class ShapVisionToolSchema(BaseModel):
    """Input schema for ShapVisionTool, validated by CrewAI before `_run` is called."""

    # The model's prediction target, e.g. "life expectancy" or "credit risk".
    target_variable: str
    # Filesystem path to the SHAP summary plot image (.png/.jpg/.jpeg).
    image_path: str
    # Optional custom analysis prompt; when None, `_run` builds a detailed default.
    prompt: Optional[str] = None
class ShapVisionTool(BaseTool):
    """CrewAI tool that sends a SHAP summary plot image to Gemini 2.5 Flash and
    returns a written, technically grounded feature-attribution report for a
    user-specified prediction target.
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    # Expected to carry "GEMINI_API_KEY"; populated by the caller when the tool is constructed.
    metadata: dict = {}

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Analyze a SHAP summary plot image with Gemini and return the report text.

        Args:
            target_variable: The model's prediction target (e.g., "life expectancy").
            image_path: Path to a .png/.jpg/.jpeg SHAP summary plot image.
            prompt: Optional custom analysis prompt; a detailed default covering
                structure, feature importance, direction of influence, spread,
                patterns, interpretation, and caveats is used when None.

        Returns:
            The generated explanation text, stripped of surrounding whitespace.

        Raises:
            ValueError: If the API key is missing from metadata, the image
                extension is unsupported, or Gemini returns no text (e.g. the
                response was blocked or empty).
            OSError: If the image file cannot be opened or read.
        """
        target_variable = target_variable.strip()

        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")
        client = genai.Client(api_key=api_key)

        system_prompt = (
            "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
            "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
        )

        if prompt is None:
            prompt = textwrap.dedent(f"""
                You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.
                Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:
                1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
                2. **The most important features** in determining {target_variable}, based on the plot.
                3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
                4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
                5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
                6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
                7. **Caveats and limitations** in interpreting SHAP summary plots.
                Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it. Write the output as if preparing it for a technical report or presentation.
            """)

        # Map supported file extensions to the MIME types accepted by the Gemini API.
        mime_types = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"}
        ext = os.path.splitext(image_path)[-1].lower()
        mime_type = mime_types.get(ext)
        if mime_type is None:
            raise ValueError(f"Unsupported image type: {ext}")

        with open(image_path, "rb") as f:
            image_bytes = f.read()

        parts = [
            Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt,
        ]

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=parts,
            config=GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.4,
                max_output_tokens=2048,
            ),
        )

        # response.text is None when generation is blocked or produced no content;
        # surface that as an explicit error instead of an AttributeError on .strip().
        if not response.text:
            raise ValueError("Gemini returned an empty response for the SHAP plot analysis.")
        return response.text.strip()