Spaces:
Configuration error
Configuration error
| import lime | |
| from lime.lime_text import LimeTextExplainer | |
| from transformers import pipeline | |
| import yaml | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# `classifier_fn` lives at module level so that it can be imported directly and
# handed to LIME as the prediction callback.
_EMOTION_MODEL_NAME = "bhadresh-savani/distilbert-base-uncased-emotion"

# model name -> (tokenizer, model). Filled lazily so importing this module stays
# cheap and the weights are loaded at most once per process.
_EMOTION_MODEL_CACHE = {}


def _load_emotion_model(model_name=_EMOTION_MODEL_NAME):
    """Return the cached ``(tokenizer, model)`` pair for *model_name*, loading it on first use."""
    if model_name not in _EMOTION_MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: make sure dropout and similar layers are disabled.
        model.eval()
        _EMOTION_MODEL_CACHE[model_name] = (tokenizer, model)
    return _EMOTION_MODEL_CACHE[model_name]


def classifier_fn(texts, batch_size=32):
    """
    Return class probabilities for each input text.

    Args:
        texts: A sequence of strings (a single string is treated as a
            one-element sequence).
        batch_size (int): Number of texts sent through the model per forward
            pass. LIME passes all of its perturbed samples in one call, so the
            input is processed in chunks to keep memory use bounded.

    Returns:
        numpy.ndarray: 2D array with one row per input text and one column per
        class of the model; each row is a softmax distribution.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # Normalize the input to a plain list so it can be sliced into batches.
    texts = [texts] if isinstance(texts, str) else list(texts)

    tokenizer, model = _load_emotion_model()

    if not texts:
        # Nothing to classify: return an empty array with the right width.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            logits = model(**inputs).logits
            batch_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(batch_probs, axis=0)
class LIMEAnalyzer:
    """
    Generate and visualize LIME explanations for text.

    Predictions are obtained from the module-level ``classifier_fn``; the
    ``model`` and ``tokenizer`` passed to the constructor are stored on the
    instance but are not used to compute explanations.
    """

    # Column order of the probabilities returned by ``classifier_fn``.
    # NOTE(review): taken from the model card of
    # bhadresh-savani/distilbert-base-uncased-emotion; confirm against the
    # model's ``config.id2label`` if the model is ever swapped.
    EMOTION_LABELS = ("sadness", "joy", "love", "anger", "fear", "surprise")

    def __init__(self, model, tokenizer):
        """
        Args:
            model: Model object or identifier kept on the instance.
            tokenizer: Tokenizer object or identifier kept on the instance.
        """
        self.model = model
        self.tokenizer = tokenizer
        # The class names must line up with the probability columns produced by
        # ``classifier_fn``; otherwise explanations are reported under the
        # wrong label.
        self.explainer = LimeTextExplainer(class_names=list(self.EMOTION_LABELS))

    def explain_prediction(self, text, num_features=10):
        """
        Generate a LIME explanation for the given text.

        Args:
            text (str): The input text to analyze.
            num_features (int): Maximum number of features (words) to include
                in the explanation.

        Returns:
            The explanation object returned by
            ``LimeTextExplainer.explain_instance``. The call uses LIME's
            default ``labels`` argument.
        """
        return self.explainer.explain_instance(
            text,
            classifier_fn=classifier_fn,
            num_features=num_features,
        )

    def explain_and_visualize(self, yaml_input_path, output_image_path):
        """
        Generate a LIME explanation for text stored in a YAML file and save it as an image.

        The text is read from ``meta_analysis.notes`` in the YAML document. If
        that value is missing, empty, or not a string, a message is printed and
        nothing is written.

        Args:
            yaml_input_path (str): Path to the input YAML file.
            output_image_path (str): Path to save the visualization image.
        """
        with open(yaml_input_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)

        # An empty file loads as None and a key may be present with a null
        # value, so check each level before calling ``.get`` on it.
        meta = data.get('meta_analysis') if isinstance(data, dict) else None
        text = meta.get('notes') if isinstance(meta, dict) else None
        if not isinstance(text, str) or not text.strip():
            print("No text found in the YAML file for analysis.")
            return

        explanation = self.explain_prediction(text)

        # Work on the returned figure explicitly and always close it, so a
        # failed save does not leave an open figure behind.
        fig = explanation.as_pyplot_figure()
        try:
            fig.gca().set_title("LIME Explanation")
            fig.savefig(output_image_path)
        finally:
            plt.close(fig)
        print(f"LIME explanation visualization saved to {output_image_path}")
# Standalone helper that renders a LIME explanation to an image file.
def save_lime_explanation(explanation, output_path):
    """
    Save the LIME explanation as a bar chart image.

    Args:
        explanation: LIME explanation object; its ``as_pyplot_figure()`` method
            is used to build the chart.
        output_path (str): Path to save the image.
    """
    fig = explanation.as_pyplot_figure()
    try:
        # Operate on the returned figure rather than on pyplot's implicit
        # "current figure", and close it even if saving fails.
        fig.gca().set_title("LIME Explanation")
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"LIME explanation saved to {output_path}")
# Public API of this module: the names exported by a star-import.
# (`__all__` only restricts star-imports; explicit imports work regardless.)
__all__ = ['classifier_fn', 'LIMEAnalyzer', 'save_lime_explanation']