# iad-explainable-hf / core / explain.py
# Parikshit Rathode
# initial commit
# c5732cc
"""
Explainability module using Gemini VLM.
This module provides functions to generate human-readable explanations
for detected anomalies using Google's Gemini Vision Language Model.
"""
from PIL import Image
import cv2
from google import genai
import numpy as np
# Model configuration: model name passed as `model=` to
# client.models.generate_content() in get_explanation().
GEMINI_MODEL = "gemini-flash-lite-latest"
def _has_boxes(bboxes) -> bool:
    """Return True when *bboxes* contains at least one bounding box."""
    if bboxes is None:
        return False
    try:
        # len() works for lists, tuples and NumPy arrays alike; plain
        # truthiness is ambiguous (or raises) for NumPy arrays.
        return len(bboxes) > 0
    except TypeError:
        # Unsized objects keep the plain truthiness check.
        return bool(bboxes)


def get_explanation(
    original_image: np.ndarray,
    bboxes: list,
    score: float,
    category: str,
    client
) -> str:
    """
    Generate an explanation for the detected anomaly using Gemini VLM.

    The bounding boxes are rescaled from the 256x256 detection space to the
    size of ``original_image``, drawn in red on a copy of the image, and the
    annotated image is sent to the model together with a fixed-format prompt.

    Args:
        original_image: Original input image in RGB format
        bboxes: Sequence of bounding boxes [x1, y1, x2, y2] in 256x256 scale
        score: Anomaly score (accepted for interface compatibility; it is
            currently not included in the prompt)
        category: MVTec AD category
        client: Initialized Gemini API client

    Returns:
        Explanation text from the model. Returns "No anomaly detected." when
        there are no bounding boxes, and a message starting with
        "Failed to generate explanation:" when the API call raises or the
        response carries no text.
    """
    if not _has_boxes(bboxes):
        return "No anomaly detected."

    # Scale bounding boxes from 256x256 to original image size
    h_orig, w_orig = original_image.shape[:2]
    scale_x = w_orig / 256.0
    scale_y = h_orig / 256.0

    # Dynamic thickness based on image size; it does not depend on the
    # individual box, so compute it once outside the loop.
    thickness = max(2, int(max(h_orig, w_orig) * 0.005))

    # Draw red bounding boxes on a copy so the caller's image is untouched
    annotated_img = original_image.copy()
    for (x1, y1, x2, y2) in bboxes:
        top_left = (int(x1 * scale_x), int(y1 * scale_y))
        bottom_right = (int(x2 * scale_x), int(y2 * scale_y))
        # (255, 0, 0) is red because the image is in RGB channel order
        cv2.rectangle(annotated_img, top_left, bottom_right, (255, 0, 0), thickness)

    # Convert to PIL Image
    annotated_pil = Image.fromarray(annotated_img)

    # Construct prompt
    prompt = f"""
You are an expert industrial quality control inspector.
We are inspecting a: {category}
An anomaly detection model has flagged a potential defect, highlighted by the RED BOUNDING BOX in the provided image.
Your task is to classify the defect inside the red box and assess its severity.
Common defects for {category} include: scratches, cuts, cracks, holes, structural damage, or severe discoloration.
Analyze the highlighted region carefully in the context of the whole object.
Only Provide your final assessment strictly in this format:
Defect: <Short name, e.g., Deep Scratch, Surface Cut, Crack, Contamination, Colouration>
Location: <Where is it on the object?>
Severity: <Low/Medium/High>
"""

    # Generate response from Gemini; any failure is reported as text so the
    # caller always receives a string.
    try:
        response = client.models.generate_content(
            model=GEMINI_MODEL,
            contents=[prompt, annotated_pil]
        )
        text = response.text
    except Exception as e:
        return f"Failed to generate explanation: {str(e)}"

    # The response may carry no text at all; keep the declared ``str``
    # return type instead of leaking None to the caller.
    if text is None:
        return "Failed to generate explanation: model returned no text."
    return text
def init_gemini_client(api_key: str):
    """
    Build a Gemini API client for the given key.

    Args:
        api_key: Gemini API key forwarded to ``genai.Client``

    Returns:
        The ``genai.Client`` instance created with that key
    """
    gemini_client = genai.Client(api_key=api_key)
    return gemini_client