samyakshrestha committed on
Commit
fa2cb8a
·
1 Parent(s): 7344bbc

Initial commit

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from src.pipeline import generate_report
4
+ from src.tools_loader import get_tools
5
+
6
# Warm-up: build the tool registry at import time so the first user
# request does not pay the model/tool cold-start cost. The result is
# bound (and kept alive) at module level.
_ = get_tools()
8
+
9
def process_inputs(target_variable: str, image_path: str):
    """Gradio callback: validate the inputs, run the pipeline, format the report.

    Returns markdown text — either a validation hint (missing image or
    target) or the generated SHAP report with its elapsed wall-clock time.
    """
    # Guard clauses: each missing input produces its own user-facing hint.
    if not image_path:
        return "**Please upload a SHAP summary plot image to begin.**"
    target = target_variable.strip()
    if not target:
        return "**Please enter a target variable (e.g., life expectancy).**"

    started_at = time.time()
    report = generate_report(target, image_path)
    seconds_taken = time.time() - started_at

    return f"""### SHAP Explanation Report for **{target}**

{report}

---
*Generated in {seconds_taken:.1f} seconds*
"""
27
+
28
# Gradio App Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SHAP Summary Plot Explainer",
    css="""
    .input-section { max-width: 600px; margin: 0 auto; }
    .report-output { margin-top: 30px; }
    """
) as demo:

    # Page header with one-line usage instructions.
    gr.Markdown("# SHAP Summary Plot Explainer\n\nUpload a SHAP plot and specify your prediction target to get a detailed explanation.")

    # Input widgets, centered via the .input-section CSS class.
    with gr.Column(elem_classes=["input-section"]):
        target_box = gr.Textbox(
            label="Target Variable",
            placeholder="e.g., life expectancy, credit score, disease risk..."
        )
        plot_upload = gr.Image(
            type="filepath",
            label="Upload SHAP Summary Plot Image",
            height=350
        )
        run_btn = gr.Button("Generate Explanation", variant="primary")

    # Output area, styled via the .report-output CSS class.
    with gr.Column(elem_classes=["report-output"]):
        report_md = gr.Markdown("**Awaiting input...**")

    # Wire the button to the report-generation callback.
    run_btn.click(
        fn=process_inputs,
        inputs=[target_box, plot_upload],
        outputs=report_md,
        show_progress="full"
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ crewai
2
+ gradio>=4.27
3
+ google-genai
4
+ pydantic
src/.DS_Store ADDED
Binary file (8.2 kB). View file
 
src/__init__.py ADDED
File without changes
src/config/agents.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ shap_agent:
2
+ role: "SHAP Explanation Agent"
3
+ goal: >
4
+ Use the shap_vision_tool to analyze a SHAP summary plot image for a given target variable
5
+ and return back the exact output. DO NOT add, interpret, or speculate beyond what the shap_vision_tool outputs.
6
+ backstory: >
7
+ A world-class Explainable AI (XAI) researcher AI trained in interpreting SHAP summary plots
8
+ with precision and clarity. You specialize in transforming dense SHAP visualizations into
9
+ insightful, structured natural language explanations without introducing any hallucinations
10
+ or adding unverified inferences.
src/config/tasks.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ shap_task:
2
+ description: >
3
+ Analyze the SHAP summary plot image at '{image_path}' for the target variable '{target_variable}'.
4
+ Use the shap_vision_tool and return the exact output it generates.
5
+ DO NOT add or modify anything beyond what the tool returns.
6
+ expected_output: >
7
+ The verbatim output generated by the shap_vision_tool based on the image and target variable.
8
+ agent: shap_agent
src/crew.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from crewai import Agent, Crew, Process, Task, LLM
2
+ from crewai.project import CrewBase, agent, crew, task
3
+ from .tools_loader import get_tools
4
+
5
@CrewBase
class ShapCrew:
    """SHAP explainer crew: one vision-tool agent plus its single task.

    Agent/task definitions live in config/agents.yaml and config/tasks.yaml
    (exposed by @CrewBase as ``agents_config`` / ``tasks_config``).
    """

    def __init__(self):
        # Load the shared tool registry once per crew instance.
        self.tools = get_tools()

        # Low temperature keeps the agent a faithful relay of the tool output.
        self.llm = LLM(model="groq/meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0.3)

    @agent
    def shap_agent(self) -> Agent:
        """Agent that forwards the SHAP plot to the vision tool verbatim."""
        return Agent(
            config=self.agents_config['shap_agent'],
            # BUG FIX: get_tools() registers the tool under "vision_tool";
            # the original "shap_tool" key raised KeyError at construction.
            tools=[self.tools["vision_tool"]],
            llm=self.llm,
            allow_delegation=False,
            verbose=False
        )

    @task
    def shap_task(self) -> Task:
        """Task defined as 'shap_task' in config/tasks.yaml."""
        return Task(config=self.tasks_config['shap_task'])

    @crew
    def crew(self) -> Crew:
        """Assemble the crew from the decorated agents and tasks."""
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            verbose=False
        )
src/pipeline.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .crew import ShapCrew
2
+
3
def generate_report(target_variable: str, image_path: str) -> str:
    """Run a fresh ShapCrew and return its SHAP report as stripped text.

    Args:
        target_variable: What the analyzed model predicts (e.g. life expectancy).
        image_path: Filesystem path to the SHAP summary plot image.
    """
    inputs = {"target_variable": target_variable, "image_path": image_path}
    result = ShapCrew().crew().kickoff(inputs=inputs)
    return str(result).strip()
src/tools/__init__.py ADDED
File without changes
src/tools/shap_vision_tool.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from crewai.tools import BaseTool
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ from google import genai
5
+ from google.genai.types import Part, GenerateContentConfig
6
+ import os
7
+ import textwrap
8
+
9
class ShapVisionToolSchema(BaseModel):
    """Argument schema for ShapVisionTool._run."""
    target_variable: str          # what the explained model predicts (e.g. "life expectancy")
    image_path: str               # path to the SHAP summary plot; .png/.jpg/.jpeg accepted by _run
    prompt: Optional[str] = None  # custom analysis prompt; _run builds a default when None
13
+
14
class ShapVisionTool(BaseTool):
    """CrewAI tool that sends a SHAP summary plot image to Gemini and
    returns a structured natural-language explanation.

    Requires a "GEMINI_API_KEY" entry in ``metadata`` (see tools_loader).
    """

    name: str = "shap_vision_tool"
    description: str = (
        "Generates a detailed feature attribution explanation from a SHAP summary plot image, "
        "based on a user-defined prediction target (e.g., life expectancy, credit risk, etc.) "
        "using Gemini 2.5 Flash."
    )
    args_schema: type = ShapVisionToolSchema
    metadata: dict = {}  # expected to carry {"GEMINI_API_KEY": "..."}

    def _run(self, target_variable: str, image_path: str, prompt: Optional[str] = None) -> str:
        """Analyze the SHAP plot at *image_path* for *target_variable*.

        Args:
            target_variable: What the explained model predicts; whitespace is stripped.
            image_path: Path to a .png/.jpg/.jpeg SHAP summary plot.
            prompt: Optional override for the built-in analysis prompt.

        Returns:
            The model's explanation text, stripped of surrounding whitespace.

        Raises:
            ValueError: missing API key, unsupported image extension, or an
                empty Gemini response.
        """
        target_variable = target_variable.strip()

        api_key = self.metadata.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found in metadata.")
        client = genai.Client(api_key=api_key)

        system_prompt = (
            "You are a world-class explainable AI (XAI) researcher with deep expertise in interpreting SHAP summary plots. "
            "You are known for your ability to translate dense SHAP visualizations into clear, insightful, and technically grounded explanations."
        )

        # Build the default seven-point analysis prompt unless the caller
        # supplied one.
        if prompt is None:
            prompt = textwrap.dedent(f"""
                You have been given a SHAP summary plot image that visualizes how different features impact the predictions of a model trained to estimate **{target_variable}**.

                Your task is to analyze this SHAP summary plot and produce a detailed written report that includes:

                1. **What the SHAP summary plot represents**, including color meaning, axis explanation, and general structure.
                2. **The most important features** in determining {target_variable}, based on the plot.
                3. **Direction of influence** for each top feature (e.g., high values of poverty_rate decrease predicted {target_variable}).
                4. **Shape, spread, and variability** of SHAP distributions for top features (e.g., stable effect vs. heterogeneous impact).
                5. **Interesting patterns** (e.g., non-linear effects, counterintuitive findings, wide spreads, or sharp clusters).
                6. **Interpretation of results** in real-world terms (socioeconomic, environmental, or demographic implications).
                7. **Caveats and limitations** in interpreting SHAP summary plots.

                Be analytical, structured, and use your expertise to interpret the image intelligently, not just describe it. Write the output as if preparing it for a technical report or presentation.
            """)

        # Infer the MIME type from the file extension; only PNG/JPEG are
        # accepted.
        ext = os.path.splitext(image_path)[-1].lower()
        if ext == ".png":
            mime_type = "image/png"
        elif ext in [".jpg", ".jpeg"]:
            mime_type = "image/jpeg"
        else:
            raise ValueError(f"Unsupported image type: {ext}")

        with open(image_path, "rb") as f:
            image_bytes = f.read()

        parts = [
            Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt
        ]

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=parts,
            config=GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.4,
                max_output_tokens=2048
            )
        )

        # BUG FIX: response.text is None when Gemini returns no candidates
        # (e.g. the request was safety-blocked); fail with a clear error
        # instead of an opaque AttributeError on .strip().
        if not response.text:
            raise ValueError("Gemini returned an empty response for the SHAP summary plot.")
        return response.text.strip()
src/tools_loader.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from .tools.shap_vision_tool import ShapVisionTool # Import SHAP vision tool
4
+
5
def get_tools():
    """Create and return all configured tools, keyed by name.

    The SHAP vision tool is registered under BOTH "vision_tool" and
    "shap_tool": crew.py looks it up as "shap_tool" while the original
    loader only exposed "vision_tool", which made agent construction fail
    with a KeyError. Exposing both keys keeps every caller working.
    """
    # Only the Gemini key is consumed here, by the vision tool itself.
    # (The Groq key is not used in this module; presumably crewai's LLM
    # wrapper reads GROQ_API_KEY from the environment — verify.)
    gemini_key = os.getenv("GEMINI_API_KEY")

    # Create the single vision tool instance shared by both registry keys.
    shap_tool = ShapVisionTool(metadata={"GEMINI_API_KEY": gemini_key})

    return {"vision_tool": shap_tool, "shap_tool": shap_tool}