import os
import textwrap
from typing import Optional

from crewai.tools import BaseTool
from google import genai
from google.genai.types import GenerateContentConfig, Part
from pydantic import BaseModel

# System instruction sent with every request.
_SYSTEM_PROMPT = (
    "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
    "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
)

# Default user prompt. It is dedented once here, *before* the target variable
# is substituted, so that a multi-line target value cannot disturb the
# indentation stripping.
_DEFAULT_PROMPT_TEMPLATE = textwrap.dedent(
    """
    You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.

    Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:

    1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
    2. **The most important features** in determining {target_variable}, based on the plot.
    3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
    4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
    5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
    6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
    7. **Caveats and limitations** in interpreting SHAP summary plots.

    Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it.
    Write the output as if preparing it for a technical report or presentation.
    """
)

# Lower-cased file extension -> MIME type passed along with the image bytes.
_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
}


class ShapVisionToolSchema(BaseModel):
    """Arguments accepted by ``ShapVisionTool``."""

    # Name of the predicted quantity; interpolated into the default prompt.
    target_variable: str
    # Path of the SHAP summary plot image (.png, .jpg or .jpeg).
    image_path: str
    # Optional replacement for the built-in analysis prompt.
    prompt: Optional[str] = None


class ShapVisionTool(BaseTool):
    """Tool that asks Gemini to explain a SHAP summary plot image.

    The Gemini API key is read from ``metadata["GEMINI_API_KEY"]`` each time
    the tool runs.
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    metadata: dict = {}

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Send the image and a prompt to Gemini and return its explanation.

        Args:
            target_variable: Name of the quantity the model predicts. Leading
                and trailing whitespace is removed before use.
            image_path: Path of the SHAP summary plot. The extension selects
                the MIME type and must be ``.png``, ``.jpg`` or ``.jpeg``
                (case-insensitive).
            prompt: Custom user prompt. When ``None`` or blank, the built-in
                report prompt for ``target_variable`` is used instead.

        Returns:
            The model's text response with surrounding whitespace stripped.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is missing from ``metadata`` or
                the image extension is not supported.
            OSError: If the image file cannot be opened or read.
            RuntimeError: If the model response contains no text.
        """
        target_variable = target_variable.strip()

        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")
        client = genai.Client(api_key=api_key)

        # A blank prompt would leave the model with no instructions at all, so
        # treat it the same as an omitted one.
        if prompt is None or not prompt.strip():
            prompt = _DEFAULT_PROMPT_TEMPLATE.format(target_variable=target_variable)

        ext = os.path.splitext(image_path)[-1].lower()
        mime_type = _MIME_TYPES.get(ext)
        if mime_type is None:
            raise ValueError(f"Unsupported image type: {ext}")

        with open(image_path, "rb") as f:
            image_bytes = f.read()

        parts = [
            Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt,
        ]

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=parts,
            config=GenerateContentConfig(
                system_instruction=_SYSTEM_PROMPT,
                temperature=0.4,
                max_output_tokens=2048,
            ),
        )

        # ``response.text`` is None when the response carries no text part;
        # fail with a clear message instead of an AttributeError on ``.strip()``.
        text = response.text
        if text is None:
            raise RuntimeError(
                "Gemini returned no text for the SHAP summary plot; "
                "the response may have been blocked or truncated."
            )
        return text.strip()