Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from transformers import RobertaTokenizer, RobertaForSequenceClassification | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import json | |
| import numpy as np | |
| from functools import lru_cache | |
@lru_cache(maxsize=1)
def load_model():
    """Load the moral-foundations classifier once and reuse it for every request.

    ``process_text`` and ``process_csv`` call this on every invocation. Without
    a cache (the original only had a "Cache the model loading" comment), each
    call re-instantiated the RoBERTa model and tokenizer. ``lru_cache(maxsize=1)``
    memoizes this zero-argument function, so the weights are loaded on the
    first call only.

    Returns:
        tuple: ``(model, tokenizer, label_names)``. ``label_names`` lists the
        ten output labels; ``predict_batch`` pairs logit column ``j`` with
        ``label_names[j]``.

    Note:
        Because the result is cached, all callers share the *same*
        ``label_names`` list object. Callers must not mutate it.
    """
    model_path = "MMADS/MoralFoundationsClassifier"
    model = RobertaForSequenceClassification.from_pretrained(model_path)
    tokenizer = RobertaTokenizer.from_pretrained(model_path)

    # Label order is assumed to match the model's output columns.
    # TODO(review): confirm against the model config's id2label.
    label_names = [
        "care_virtue", "care_vice",
        "fairness_virtue", "fairness_vice",
        "loyalty_virtue", "loyalty_vice",
        "authority_virtue", "authority_vice",
        "sanctity_virtue", "sanctity_vice",
    ]
    return model, tokenizer, label_names
def predict_batch(texts, model, tokenizer, label_names):
    """Score a batch of texts against every moral-foundation label.

    Args:
        texts: sequence of strings to classify.
        model: sequence-classification model whose output ``logits`` has one
            column per entry in ``label_names``.
        tokenizer: tokenizer matching ``model``; called with the whole batch.
        label_names: label for each logit column, in column order.

    Returns:
        A list with one ``{'text': text, 'scores': {label: float}}`` dict per
        input text, in input order. Scores are independent sigmoid
        probabilities (multi-label), so they need not sum to 1. An empty
        ``texts`` returns ``[]`` without calling the tokenizer or model.
    """
    texts = list(texts)
    if not texts:
        # Nothing to score: avoid invoking the tokenizer/model on an empty batch.
        return []

    # Pad to the longest text in the batch and truncate anything over 512 tokens.
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # Sigmoid rather than softmax: each label is scored independently.
    # .cpu() keeps the numpy conversion working even if the model lives on a GPU.
    probabilities = torch.sigmoid(logits).cpu().numpy()

    return [
        {
            'text': text,
            'scores': {label: float(row[j]) for j, label in enumerate(label_names)},
        }
        for text, row in zip(texts, probabilities)
    ]
def create_visualization(results):
    """Draw a grouped bar chart of mean virtue and vice scores per foundation.

    Args:
        results: list of ``{'text': ..., 'scores': {label: float}}`` dicts as
            produced by ``predict_batch``. Every foundation listed below must
            have ``<foundation>_virtue`` and ``<foundation>_vice`` scores.

    Returns:
        A plotly ``Figure`` with a "Virtues" and a "Vices" bar series on a
        fixed 0-1 y-axis, or ``None`` when ``results`` is empty.
    """
    if not results:
        return None

    # Column-wise view: label -> that label's score in every result.
    per_label = {
        label: [entry['scores'][label] for entry in results]
        for label in results[0]['scores']
    }

    foundations = ['care', 'fairness', 'loyalty', 'authority', 'sanctity']
    virtue_means = [np.mean(per_label[f"{name}_virtue"]) for name in foundations]
    vice_means = [np.mean(per_label[f"{name}_vice"]) for name in foundations]

    fig = go.Figure()
    series = (
        ('Virtues', virtue_means, 'lightgreen'),
        ('Vices', vice_means, 'lightcoral'),
    )
    for series_name, means, colour in series:
        fig.add_trace(go.Bar(name=series_name, x=foundations, y=means, marker_color=colour))

    fig.update_layout(
        title="Average Moral Foundation Scores",
        xaxis_title="Moral Foundations",
        yaxis_title="Average Score",
        barmode='group',
        yaxis=dict(range=[0, 1]),
        template="plotly_white",
    )
    return fig
def create_heatmap(results, max_label_chars=50):
    """Plot a texts-by-labels heatmap of classifier scores.

    Args:
        results: list of ``{'text': ..., 'scores': {label: float}}`` dicts as
            produced by ``predict_batch``. Column order follows the label
            order of the first result.
        max_label_chars: texts longer than this are truncated, with "...", in
            the y-axis labels. Defaults to 50 (the previous hard-coded value).

    Returns:
        A plotly ``Figure``, or ``None`` when ``results`` is empty.
    """
    if not results:
        return None

    labels = list(results[0]['scores'].keys())

    row_labels = []
    for row_number, result in enumerate(results, start=1):
        # str() guards against non-string cells (e.g. numbers read from a CSV).
        snippet = str(result['text'])
        if len(snippet) > max_label_chars:
            snippet = snippet[:max_label_chars] + "..."
        # Plotly treats y values as categories. Two rows with the same label
        # (duplicate texts or a shared prefix) would be drawn on the same row,
        # hiding one of them. The row-number prefix keeps every label unique.
        row_labels.append(f"{row_number}. {snippet}")

    matrix = [[result['scores'][label] for label in labels] for result in results]

    fig = px.imshow(
        matrix,
        labels=dict(x="Moral Foundations", y="Texts", color="Score"),
        x=labels,
        y=row_labels,
        aspect="auto",
        color_continuous_scale="RdBu_r",
        # Scores are sigmoid probabilities. Fixing the colour range to 0-1
        # keeps colours comparable between uploads (and with the 0-1 bar chart).
        zmin=0,
        zmax=1,
    )
    fig.update_layout(
        title="Moral Foundation Scores Heatmap",
        height=max(400, len(row_labels) * 30),
    )
    return fig
def process_text(text):
    """Score a single text and build its bar chart.

    Args:
        text: the string from the input textbox.

    Returns:
        tuple ``(markdown, figure)``: a Markdown bullet list of every label's
        score and the bar chart from ``create_visualization``. For empty or
        whitespace-only input the model is not run; a prompt message and
        ``None`` are returned instead.
    """
    if not text or not text.strip():
        return "Please enter some text to analyze.", None

    model, tokenizer, label_names = load_model()
    results = predict_batch([text], model, tokenizer, label_names)

    # Markdown merges consecutive lines separated by a single "\n" into one
    # paragraph, so the previous output rendered all scores on one line.
    # A bullet list keeps one score per line.
    score_lines = [
        f"- {label.replace('_', ' ').title()}: {score:.4f}"
        for label, score in results[0]['scores'].items()
    ]
    scores_text = "**Moral Foundation Scores:**\n\n" + "\n".join(score_lines) + "\n"

    bar_chart = create_visualization(results)
    return scores_text, bar_chart
def process_csv(file, progress=gr.Progress()):
    """Score every row of an uploaded CSV and build the batch-tab outputs.

    Args:
        file: the upload from ``gr.File``. Depending on the Gradio version,
            this is a filepath (str / Path) or a tempfile-like object exposing
            the path as ``.name``; both are accepted.
        progress: Gradio progress tracker (supplied by Gradio).

    Returns:
        tuple ``(summary_markdown, bar_chart, heatmap, results_csv_path)``.
        On any failure the first element is a message and the other three are
        ``None``.
    """
    import tempfile
    from pathlib import Path

    if file is None:
        return "Please upload a CSV file", None, None, None

    try:
        # Newer Gradio passes a filepath; older versions pass an object with .name.
        csv_path = file if isinstance(file, (str, Path)) else file.name
        df = pd.read_csv(csv_path)

        if 'text' not in df.columns:
            return "Error: CSV must contain a 'text' column", None, None, None

        # Blank cells are read as NaN (a float), which the tokenizer rejects.
        # Drop them and coerce the remaining cells (e.g. numbers) to str.
        text_column = df['text'].dropna()
        skipped = len(df) - len(text_column)
        texts = text_column.astype(str).tolist()
        if not texts:
            return "Error: the 'text' column contains no texts to analyze", None, None, None

        progress(0, desc="Loading model...")
        model, tokenizer, label_names = load_model()

        batch_size = 32
        total_batches = (len(texts) + batch_size - 1) // batch_size  # ceiling division
        all_results = []
        for batch_num, start in enumerate(range(0, len(texts), batch_size), start=1):
            # Batches fill 0-90% of the bar; the rest is reserved for the
            # visualization step, so progress never moves backwards.
            progress(0.9 * (batch_num - 1) / total_batches,
                     desc=f"Processing batch {batch_num}/{total_batches}")
            batch_texts = texts[start:start + batch_size]
            all_results.extend(predict_batch(batch_texts, model, tokenizer, label_names))

        progress(0.9, desc="Creating visualizations...")

        # One bullet per label: plain "\n"-separated lines would be merged
        # into a single paragraph by the Markdown renderer.
        avg_lines = []
        for label in label_names:
            mean_score = np.mean([r['scores'][label] for r in all_results])
            avg_lines.append(f"- {label.replace('_', ' ').title()}: {mean_score:.4f}")

        summary = f"**Processed {len(texts)} texts**\n\n"
        if skipped:
            summary += f"Skipped {skipped} row(s) with a missing 'text' value.\n\n"
        summary += "**Average Scores Across All Texts:**\n\n"
        summary += "\n".join(avg_lines) + "\n"

        bar_chart = create_visualization(all_results)
        heatmap = create_heatmap(all_results[:20])  # keep the heatmap readable: first 20 texts only

        results_df = pd.DataFrame([{'text': r['text'], **r['scores']} for r in all_results])

        # A fixed "results.csv" in the working directory would be overwritten
        # by concurrent sessions. Write to a fresh private directory instead,
        # keeping the download filename. The directory is not removed here.
        output_path = Path(tempfile.mkdtemp(prefix="moral_foundations_")) / "results.csv"
        results_df.to_csv(output_path, index=False)
        return summary, bar_chart, heatmap, str(output_path)
    except Exception as e:
        # UI boundary: show the failure in the summary panel instead of failing the request.
        return f"Error processing CSV: {str(e)}", None, None, None
# Sample inputs offered in the "Single Text Analysis" tab via gr.Examples.
# Each sentence appears to target one foundation, in the order the UI lists
# them: care, fairness, loyalty, authority, sanctity.
example_texts = [
    "We must protect the vulnerable and care for those who cannot care for themselves.",  # care
    "Everyone deserves equal treatment under the law, regardless of their background.",  # fairness
    "Betraying your country is one of the worst things a person can do.",  # loyalty
    "We should respect our elders and follow traditional values.",  # authority
    "Some things are sacred and should not be violated or mocked.",  # sanctity
]
# Gradio UI: a header, then two tabs (single text / CSV batch) wired to
# process_text and process_csv.
with gr.Blocks(title="Moral Foundations Classifier") as demo:
    gr.Markdown("""
    # Moral Foundations Classifier
    This app analyzes text for moral foundations based on Moral Foundations Theory.
    It identifies five moral foundations (each with virtue and vice dimensions):
    - **Care/Harm**: Compassion and protection vs. harm
    - **Fairness/Cheating**: Justice and equality vs. cheating
    - **Loyalty/Betrayal**: Group loyalty vs. betrayal
    - **Authority/Subversion**: Respect for authority vs. subversion
    - **Sanctity/Degradation**: Purity and sanctity vs. degradation
    """)

    # --- Tab 1: score one piece of text -------------------------------------
    with gr.Tab("Single Text Analysis"):
        text_input = gr.Textbox(
            lines=5,
            label="Enter text to analyze",
            placeholder="Type or paste your text here...",
        )
        gr.Examples(examples=example_texts, inputs=text_input, label="Example Texts")
        analyze_btn = gr.Button("Analyze Text", variant="primary")

        with gr.Row():
            scores_output = gr.Markdown(label="Scores")
            chart_output = gr.Plot(label="Visualization")

        analyze_btn.click(process_text, inputs=[text_input], outputs=[scores_output, chart_output])

    # --- Tab 2: score every row of an uploaded CSV --------------------------
    with gr.Tab("Batch Analysis (CSV)"):
        gr.Markdown("""
        Upload a CSV file with a 'text' column containing the texts to analyze.
        The app will process all texts and provide aggregate visualizations.
        A sample CSV file is available for download <a href="https://huggingface.co/spaces/MMADS/MoralFoundationsClassifier-app/tree/main/examples" target="_blank" rel="noopener noreferrer">here</a>.
        """)
        csv_input = gr.File(file_types=[".csv"], label="Upload CSV file")
        process_btn = gr.Button("Process CSV", variant="primary")
        summary_output = gr.Markdown(label="Summary")

        with gr.Row():
            bar_output = gr.Plot(label="Average Scores")
            heatmap_output = gr.Plot(label="Scores Heatmap (First 20 texts)")

        # File component that receives the path of the per-row results CSV.
        download_output = gr.File(visible=True, label="Download Results")

        process_btn.click(
            process_csv,
            inputs=[csv_input],
            outputs=[summary_output, bar_output, heatmap_output, download_output],
        )

    gr.Markdown("""
    ---
    Based on the [MoralFoundationsClassifier](https://huggingface.co/MMADS/MoralFoundationsClassifier) by M. Murat Ardag
    """)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()