import os
import textwrap
from typing import Optional

from crewai.tools import BaseTool
from google import genai
from google.genai.types import Part, GenerateContentConfig
from pydantic import BaseModel, Field
class ShapVisionToolSchema(BaseModel):
    """Input arguments accepted by ``ShapVisionTool``.

    Each field carries a description so that the agent calling the tool can
    see what every argument is for when it inspects the argument schema.
    """

    target_variable: str = Field(
        ...,
        description=(
            "Name of the quantity the model predicts "
            "(e.g. life expectancy, credit risk)."
        ),
    )
    image_path: str = Field(
        ...,
        description=(
            "Path to the SHAP summary plot image file "
            "(.png, .jpg or .jpeg)."
        ),
    )
    prompt: Optional[str] = Field(
        default=None,
        description=(
            "Optional custom instruction for the analysis; when omitted, "
            "a built-in SHAP analysis prompt is used."
        ),
    )
class ShapVisionTool(BaseTool):
    """CrewAI tool that asks Gemini 2.5 Flash to explain a SHAP summary plot.

    The tool reads an image file from disk, sends it together with a text
    prompt to the ``gemini-2.5-flash`` model and returns the generated text.
    The API key is read from ``metadata["GEMINI_API_KEY"]``.
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    # Free-form configuration; ``_run`` looks up the "GEMINI_API_KEY" entry here.
    metadata: dict = {}

    @staticmethod
    def _default_prompt(target_variable: str) -> str:
        """Build the built-in analysis prompt for *target_variable*."""
        # Dedent the template BEFORE substituting the value: a multi-line
        # target_variable would otherwise remove the common margin and leave
        # the whole prompt indented.
        template = textwrap.dedent("""
            You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.
            Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:
            1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
            2. **The most important features** in determining {target_variable}, based on the plot.
            3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
            4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
            5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
            6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
            7. **Caveats and limitations** in interpreting SHAP summary plots.
            Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it. Write the output as if preparing it for a technical report or presentation.
        """)
        return template.format(target_variable=target_variable)

    @staticmethod
    def _mime_type_for(image_path: str) -> str:
        """Map the file extension of *image_path* to an image MIME type.

        Raises:
            ValueError: If the extension is not .png, .jpg or .jpeg.
        """
        ext = os.path.splitext(image_path)[-1].lower()
        if ext == ".png":
            return "image/png"
        if ext in (".jpg", ".jpeg"):
            return "image/jpeg"
        raise ValueError(f"Unsupported image type: {ext}")

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Send the SHAP plot and prompt to Gemini and return its explanation.

        Args:
            target_variable: Name of the quantity the model predicts;
                surrounding whitespace is ignored.
            image_path: Path to a .png/.jpg/.jpeg image of the SHAP summary
                plot; surrounding whitespace is ignored.
            prompt: Optional custom instruction. When ``None`` or blank, the
                built-in analysis prompt is used.

        Returns:
            The model's text response with surrounding whitespace removed.

        Raises:
            ValueError: If the API key is missing from ``metadata``, the
                image type is unsupported, or the model returns no text.
            OSError: If the image file cannot be opened or read.
        """
        target_variable = target_variable.strip()
        # Strip the path as well, for consistency with target_variable:
        # both values are supplied by the caller of the tool.
        image_path = image_path.strip()

        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")
        client = genai.Client(api_key=api_key)

        system_prompt = (
            "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
            "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
        )

        # Treat a blank prompt like a missing one so the model is never sent
        # an empty instruction alongside the image.
        if prompt is None or not prompt.strip():
            prompt = self._default_prompt(target_variable)

        mime_type = self._mime_type_for(image_path)
        with open(image_path, "rb") as f:
            image_bytes = f.read()

        parts = [
            Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt,
        ]
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=parts,
            config=GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.4,
                max_output_tokens=2048,
            ),
        )

        # ``response.text`` is None when the response carries no text parts;
        # fail with a clear message instead of an AttributeError on ``.strip()``.
        text = response.text
        if text is None:
            raise ValueError(
                "Gemini returned no text for the SHAP summary plot "
                "(the response may have been blocked or truncated)."
            )
        return text.strip()