# Source: Aethero_github / Aethero_App / lime_integration.py
# Last commit 46f737d by xvadur: "Add complete Aethero_App and aethero_protocol directories"
import functools

import lime
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from lime.lime_text import LimeTextExplainer
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Defined at module level (rather than inside LIMEAnalyzer) so callers can
# import it directly and pass it to LIME as a plain callable.
_EMOTION_MODEL_NAME = "bhadresh-savani/distilbert-base-uncased-emotion"


@functools.lru_cache(maxsize=1)
def _load_emotion_model():
    """Load the emotion tokenizer/model once and reuse them on later calls."""
    tokenizer = AutoTokenizer.from_pretrained(_EMOTION_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(_EMOTION_MODEL_NAME)
    # Inference only: make sure training-time layers (e.g. dropout) are off.
    model.eval()
    return tokenizer, model


def classifier_fn(texts, batch_size=64):
    """Return per-class probabilities from the emotion classifier for each text.

    LIME scores many perturbed variants of a single text through this
    function. The model is therefore loaded once (cached) rather than on
    every call, and inputs are scored in fixed-size batches to bound peak
    memory.

    Args:
        texts: Iterable of input strings.
        batch_size (int): Number of texts tokenized and scored per forward
            pass. Must be >= 1.

    Returns:
        numpy.ndarray: 2-D array with one row per input text; each row is a
        softmax distribution over the model's classes.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    tokenizer, model = _load_emotion_model()
    texts = list(texts)
    if not texts:
        # Nothing to score; keep the 2-D contract with zero rows.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            logits = model(**inputs).logits
            batch_probs.append(torch.softmax(logits, dim=-1).numpy())
    return np.concatenate(batch_probs, axis=0)
class LIMEAnalyzer:
    """Produce LIME explanations for text and render them to image files.

    NOTE(review): ``model`` and ``tokenizer`` are stored but not used when
    explaining. Predictions come from the module-level ``classifier_fn`` (an
    emotion classifier), while the explainer's ``class_names`` are
    ``["NEGATIVE", "POSITIVE"]``. Confirm which model and label set are
    intended.
    """

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer handles and build the LIME text explainer.

        Args:
            model: Model handle kept on the instance (see class note).
            tokenizer: Tokenizer handle kept on the instance (see class note).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.explainer = LimeTextExplainer(class_names=["NEGATIVE", "POSITIVE"])

    def explain_prediction(self, text, num_features=10):
        """Generate a LIME explanation for the given text.

        Args:
            text (str): The input text to analyze.
            num_features (int): Maximum number of tokens to include in the
                explanation.

        Returns:
            The explanation object returned by
            ``LimeTextExplainer.explain_instance``.
        """
        return self.explainer.explain_instance(
            text,
            classifier_fn=classifier_fn,
            num_features=num_features,
        )

    def explain_and_visualize(self, yaml_input_path, output_image_path):
        """Explain ``meta_analysis.notes`` from a YAML file and save a bar-chart image.

        Prints a message and returns without writing anything if the YAML
        file has no non-empty ``meta_analysis.notes`` string.

        Args:
            yaml_input_path (str): Path to the input YAML file.
            output_image_path (str): Path to save the visualization image.
        """
        with open(yaml_input_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        text = self._extract_notes(data)
        if not text:
            print("No text found in the YAML file for analysis.")
            return

        explanation = self.explain_prediction(text)
        fig = explanation.as_pyplot_figure()
        try:
            # Operate on the figure handle instead of pyplot's implicit
            # "current figure" state.
            fig.gca().set_title("LIME Explanation")
            fig.savefig(output_image_path)
        finally:
            # Release the figure even if saving fails, so repeated calls
            # don't accumulate open figures.
            plt.close(fig)
        print(f"LIME explanation visualization saved to {output_image_path}")

    @staticmethod
    def _extract_notes(data):
        """Return ``data['meta_analysis']['notes']`` if it is a string, else ``''``.

        ``yaml.safe_load`` returns ``None`` for an empty document, and keys can
        map to ``null`` or non-mapping values. These cases are treated as "no
        text" rather than raising ``AttributeError``.
        """
        if not isinstance(data, dict):
            return ''
        meta = data.get('meta_analysis')
        if not isinstance(meta, dict):
            return ''
        notes = meta.get('notes')
        return notes if isinstance(notes, str) else ''
# Standalone helper for rendering an already-computed LIME explanation.
def save_lime_explanation(explanation, output_path):
    """Save a LIME explanation as a bar-chart image.

    Args:
        explanation: LIME explanation object; must provide
            ``as_pyplot_figure()``.
        output_path (str): Destination path for the image.
    """
    fig = explanation.as_pyplot_figure()
    try:
        # Title and save through the returned figure rather than relying on
        # pyplot's implicit "current figure".
        fig.gca().set_title("LIME Explanation")
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # don't leak open figures.
        plt.close(fig)
    print(f"LIME explanation saved to {output_path}")
# Public API exported by `from <this module> import *`. Explicit imports such
# as `from <this module> import classifier_fn` work regardless of `__all__`.
__all__ = ['classifier_fn', 'LIMEAnalyzer', 'save_lime_explanation']