File size: 3,620 Bytes
fa2cb8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import os
import textwrap
from typing import Optional

from crewai.tools import BaseTool
from google import genai
from google.genai.types import Part, GenerateContentConfig
from pydantic import BaseModel, Field
class ShapVisionToolSchema(BaseModel):
    """Argument schema for ``ShapVisionTool``.

    Each field carries a description so that the schema exposed to the
    calling agent explains what every argument is for.
    """

    target_variable: str = Field(
        ...,
        description=(
            "Name of the quantity the model predicts "
            "(e.g. life expectancy, credit risk)."
        ),
    )
    image_path: str = Field(
        ...,
        description=(
            "Path to the SHAP summary plot image on disk "
            "(.png, .jpg or .jpeg)."
        ),
    )
    prompt: Optional[str] = Field(
        default=None,
        description=(
            "Optional custom instruction sent with the image. "
            "When omitted, a built-in SHAP analysis prompt is used."
        ),
    )
class ShapVisionTool(BaseTool):
    """CrewAI tool that asks Gemini to explain a SHAP summary plot image.

    The tool reads a PNG or JPEG file from disk, sends it together with a
    text prompt to the ``gemini-2.5-flash`` model and returns the model's
    written explanation. The API key is read from
    ``metadata["GEMINI_API_KEY"]``.
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    # `_run` looks up the "GEMINI_API_KEY" entry of this dict.
    metadata: dict = {}

    def _build_default_prompt(self, target_variable: str) -> str:
        """Return the built-in report prompt for the given prediction target."""
        return textwrap.dedent(f"""
            You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.
            Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:
            1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
            2. **The most important features** in determining {target_variable}, based on the plot.
            3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
            4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
            5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
            6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
            7. **Caveats and limitations** in interpreting SHAP summary plots.
            Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it. Write the output as if preparing it for a technical report or presentation.
        """)

    def _resolve_mime_type(self, image_path: str) -> str:
        """Map the file extension of ``image_path`` to an image MIME type.

        Raises:
            ValueError: If the extension is not .png, .jpg or .jpeg.
        """
        mime_by_extension = {
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
        }
        ext = os.path.splitext(image_path)[-1].lower()
        try:
            return mime_by_extension[ext]
        except KeyError:
            raise ValueError(f"Unsupported image type: {ext}") from None

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Send the SHAP plot and a prompt to Gemini and return its explanation.

        Args:
            target_variable: Name of the predicted quantity; surrounding
                whitespace is stripped before it is inserted into the
                built-in prompt.
            image_path: Path to a .png, .jpg or .jpeg image file.
            prompt: Custom instruction to send with the image. ``None`` or a
                blank string selects the built-in analysis prompt.

        Returns:
            The model's text response with surrounding whitespace removed.

        Raises:
            ValueError: If the API key is missing from ``metadata`` or the
                image extension is unsupported.
            RuntimeError: If the model response contains no text.
        """
        target_variable = target_variable.strip()

        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")
        client = genai.Client(api_key=api_key)

        system_prompt = (
            "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
            "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
        )

        # A blank prompt would otherwise be sent as an empty text part, so it
        # is treated the same as "not provided".
        if prompt is None or not prompt.strip():
            prompt = self._build_default_prompt(target_variable)

        mime_type = self._resolve_mime_type(image_path)
        with open(image_path, "rb") as f:
            image_bytes = f.read()

        parts = [
            Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt,
        ]
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=parts,
            config=GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.4,
                max_output_tokens=2048,
            ),
        )

        # `response.text` can be None when the response has no text part;
        # calling .strip() on it would raise an opaque AttributeError.
        # NOTE(review): presumably this happens on blocked or truncated
        # responses - confirm against the SDK documentation.
        text = response.text
        if text is None:
            raise RuntimeError(
                "Gemini returned no text for the SHAP plot analysis; "
                "the response may have been blocked or truncated."
            )
        return text.strip()