Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from wordcloud import WordCloud | |
| from collections import Counter, defaultdict | |
| import re | |
| import json | |
| import csv | |
| import io | |
| import tempfile | |
| from datetime import datetime | |
| import logging | |
| from functools import lru_cache, wraps | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Optional, Tuple, Any, Callable | |
| from contextlib import contextmanager | |
| import gc | |
# Configuration
class Config:
    """Application-wide constants: processing limits, figure sizes, colour themes and stop words."""

    # Processing limits
    MAX_HISTORY_SIZE: int = 1000       # history entries retained before the oldest are dropped
    BATCH_SIZE_LIMIT: int = 50         # maximum reviews accepted in one batch request
    MAX_TEXT_LENGTH: int = 512         # tokenizer truncation length
    MIN_WORD_LENGTH: int = 2           # shortest word reported by keyword extraction
    CACHE_SIZE: int = 128
    BATCH_PROCESSING_SIZE: int = 8     # texts per model forward pass in batch mode

    # Visualization settings (matplotlib figsize tuples)
    FIGURE_SIZE_SINGLE: Tuple[int, int] = (8, 5)
    FIGURE_SIZE_BATCH: Tuple[int, int] = (12, 8)
    WORDCLOUD_SIZE: Tuple[int, int] = (10, 5)

    # Colour palettes keyed by theme name; each maps 'pos'/'neg' to a hex colour
    THEMES: Dict[str, Dict[str, str]] = dict(
        default={'pos': '#4ecdc4', 'neg': '#ff6b6b'},
        ocean={'pos': '#0077be', 'neg': '#ff6b35'},
        forest={'pos': '#228b22', 'neg': '#dc143c'},
        sunset={'pos': '#ff8c00', 'neg': '#8b0000'},
    )

    # Words ignored when ranking or cleaning keywords
    STOP_WORDS = set(
        "the a an and or but in on at to "
        "for of with by is are was were be "
        "been have has had will would could should".split()
    )


# Shared module-level configuration instance and logger
config = Config()
logger = logging.getLogger(__name__)
# Decorators and Context Managers
def handle_errors(default_return=None):
    """Build a decorator that logs and absorbs any exception raised by the wrapped function.

    Args:
        default_return: Value returned when the wrapped call raises. When left as
            ``None`` the wrapper instead returns the string ``"Error: <message>"``,
            so ``None`` itself cannot be chosen as the fallback value.

    Returns:
        A decorator. The wrapped function keeps its ``__name__``/``__doc__``
        (via ``functools.wraps``), which keeps logs and introspection meaningful.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                logger.error("%s failed: %s", func.__name__, e)
                return default_return if default_return is not None else f"Error: {str(e)}"
        return wrapper
    return decorator
@contextmanager
def managed_figure(*args, **kwargs):
    """Create a matplotlib figure and close it when the ``with`` block exits.

    Arguments are forwarded unchanged to ``plt.figure``. On exit, normal or via
    an exception, the figure is closed in pyplot and a GC pass runs to release
    memory. Without ``@contextmanager`` this generator could not be used in the
    ``with managed_figure(...) as fig`` form its callers rely on.

    NOTE(review): callers return ``fig`` from inside the ``with`` block, so the
    figure they hand back is already closed in pyplot. It must be rendered via
    the Figure object (e.g. ``fig.savefig``), not through pyplot state.
    """
    fig = plt.figure(*args, **kwargs)
    try:
        yield fig
    finally:
        plt.close(fig)
        gc.collect()
class ThemeContext:
    """Resolve a theme name to its positive/negative colour pair.

    Unknown names fall back to the ``'default'`` palette in ``config.THEMES``.

    Attributes:
        theme: The requested theme name, stored as given.
        colors: Dict with ``'pos'`` and ``'neg'`` hex colours.
    """

    def __init__(self, theme: str = 'default'):
        palettes = config.THEMES
        fallback = palettes['default']
        self.theme = theme
        self.colors = palettes[theme] if theme in palettes else fallback
# Lazy Model Manager
class ModelManager:
    """Singleton that lazily loads the sentiment model and tokenizer.

    ``ModelManager()`` always returns the same instance. The model and
    tokenizer are loaded on first access to ``model`` or ``tokenizer``.
    ``device`` is resolved on first access without loading anything. These
    are properties because callers use them as attributes, e.g.
    ``manager.tokenizer(text, ...)`` and ``tensor.to(manager.device)``.
    """

    # Hugging Face Hub identifier used for both tokenizer and model
    MODEL_NAME = "entropy25/sentimentanalysis"

    _instance = None
    _model = None
    _tokenizer = None
    _device = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def model(self):
        """The sequence-classification model, loaded on first access."""
        if self._model is None:
            self._load_model()
        return self._model

    @property
    def tokenizer(self):
        """The BERT tokenizer, loaded on first access."""
        if self._tokenizer is None:
            self._load_model()
        return self._tokenizer

    @property
    def device(self) -> torch.device:
        """CUDA when available, otherwise CPU; resolved once and cached."""
        if self._device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self._device

    def _load_model(self):
        """Load tokenizer and model, move the model to ``device`` and set eval mode.

        Both are published only after every step succeeds, so a failure never
        leaves a tokenizer cached without its model.

        Raises:
            Exception: Whatever ``from_pretrained``/``to`` raised, after logging it.
        """
        try:
            device = self.device
            tokenizer = BertTokenizer.from_pretrained(self.MODEL_NAME)
            model = BertForSequenceClassification.from_pretrained(self.MODEL_NAME)
            model.to(device)
            model.eval()  # inference only: disable dropout explicitly
        except Exception as e:
            logger.error("Model loading failed: %s", e)
            raise
        self._tokenizer = tokenizer
        self._model = model
        logger.info("Model loaded on %s", device)
# Simplified Core Classes
class TextProcessor:
    """Lightweight word-extraction helpers."""

    # Whole-word runs of 3+ word characters (letters, digits or underscore)
    _WORD_RE = re.compile(r'\b\w{3,}\b')

    @staticmethod
    def clean_text(text: str) -> Tuple[str, ...]:
        """Lower-case *text*, extract words of 3+ word characters and drop stop words.

        NOTE(review): the 3-character minimum is hard-coded here and differs
        from ``Config.MIN_WORD_LENGTH`` (2); confirm which is intended.

        Returns:
            The remaining words in order of appearance, as a (hashable) tuple.
        """
        words = TextProcessor._WORD_RE.findall(text.lower())
        return tuple(w for w in words if w not in config.STOP_WORDS)
class HistoryManager:
    """Bounded in-memory record of past analyses; the oldest entries are evicted first."""

    def __init__(self):
        self._history: List[Dict] = []

    def add(self, entry: Dict):
        """Store a shallow copy of *entry* with a local-time ISO-8601 ``timestamp`` added."""
        stamped = dict(entry)
        stamped['timestamp'] = datetime.now().isoformat()
        self._history.append(stamped)
        limit = config.MAX_HISTORY_SIZE
        if len(self._history) > limit:
            # Keep only the newest `limit` entries
            self._history = self._history[-limit:]

    def get_all(self) -> List[Dict]:
        """Return a shallow copy of the stored entries, oldest first."""
        return list(self._history)

    def clear(self) -> int:
        """Remove every entry and return how many were removed."""
        removed = len(self._history)
        self._history.clear()
        return removed

    def size(self) -> int:
        """Return the number of entries currently held."""
        return len(self._history)
# Core Analysis Engine
class SentimentEngine:
    """Binary sentiment classifier with attention-based keyword extraction.

    Keyword scores are the last-layer attention paid by the [CLS] query to each
    token, averaged over heads. This is a heuristic for where the model
    "looked", not a faithful attribution of the prediction.
    NOTE(review): label index 1 is treated as Positive and 0 as Negative;
    confirm against the model's config.
    """

    # Tokenizer special tokens that never count as words
    _SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]'})

    def __init__(self):
        self.model_manager = ModelManager()

    def _encode(self, texts):
        """Tokenize *texts* (one string or a list of strings) onto the model's device."""
        return self.model_manager.tokenizer(
            texts, return_tensors="pt", padding=True,
            truncation=True, max_length=config.MAX_TEXT_LENGTH
        ).to(self.model_manager.device)

    @staticmethod
    def _rank_wordpieces(tokens, scores, top_k: int, min_length: int,
                         stop_words) -> List[Tuple[str, float]]:
        """Merge WordPiece sub-tokens into words and return the *top_k* highest-scoring.

        A ``##``-prefixed token extends the preceding word. A word's score is
        the max over its pieces and, if the word repeats, the max over its
        occurrences. Special tokens are skipped. Words shorter than
        *min_length* or present in *stop_words* are dropped. Scores are
        built-in ``float`` so the result is JSON-serializable (numpy float32
        is not).
        """
        word_scores: Dict[str, float] = {}

        def record(word: str, score: float) -> None:
            if word and len(word) >= min_length:
                key = word.lower()
                word_scores[key] = max(word_scores.get(key, score), score)

        current_word = ""
        current_score = 0.0
        for token, score in zip(tokens, scores):
            if token in SentimentEngine._SPECIAL_TOKENS:
                continue
            if token.startswith('##'):
                current_word += token[2:]
                current_score = max(current_score, float(score))
            else:
                record(current_word, current_score)
                current_word = token
                current_score = float(score)
        record(current_word, current_score)  # flush the final word

        ranked = sorted(
            ((word, score) for word, score in word_scores.items()
             if word not in stop_words and len(word) >= min_length),
            key=lambda item: item[1],
            reverse=True,
        )
        return ranked[:top_k]

    def _key_words_from_outputs(self, inputs, outputs, top_k: int) -> List[Tuple[str, float]]:
        """Rank the words of the first sequence in *inputs* by [CLS] attention.

        *outputs* must come from a forward pass with ``output_attentions=True``.
        """
        # Last layer: [batch_size, num_heads, seq_len, seq_len]; average over heads
        avg_attention = outputs.attentions[-1].mean(dim=1)
        # Query row 0 = attention from [CLS] to every token of the first sequence
        cls_attention = avg_attention[0, 0, :].cpu().numpy()
        tokens = self.model_manager.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        return self._rank_wordpieces(tokens, cls_attention, top_k,
                                     config.MIN_WORD_LENGTH, config.STOP_WORDS)

    def extract_key_words(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Return up to *top_k* ``(word, score)`` pairs for *text*, best first.

        Best-effort: any failure is logged and yields an empty list.
        """
        try:
            inputs = self._encode(text)
            with torch.no_grad():
                outputs = self.model_manager.model(**inputs, output_attentions=True)
            return self._key_words_from_outputs(inputs, outputs, top_k)
        except Exception as e:
            logger.error(f"Key word extraction failed: {e}")
            return []

    def analyze_single(self, text: str) -> Dict:
        """Classify one text and extract its top keywords.

        Raises:
            ValueError: If *text* is empty or whitespace only.

        Returns:
            Dict with ``sentiment``, ``confidence``, ``pos_prob``, ``neg_prob``
            and ``key_words`` (up to 10 ``(word, float)`` pairs).
        """
        if not text.strip():
            raise ValueError("Empty text")

        inputs = self._encode(text)
        with torch.no_grad():
            # One forward pass yields both the logits and the attention maps
            outputs = self.model_manager.model(**inputs, output_attentions=True)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        sentiment = "Positive" if probs[1] > probs[0] else "Negative"

        # Keyword extraction stays best-effort, as in extract_key_words
        try:
            key_words = self._key_words_from_outputs(inputs, outputs, top_k=10)
        except Exception as e:
            logger.error(f"Key word extraction failed: {e}")
            key_words = []

        return {
            'sentiment': sentiment,
            'confidence': float(probs.max()),
            'pos_prob': float(probs[1]),
            'neg_prob': float(probs[0]),
            'key_words': key_words
        }

    def analyze_batch(self, texts: List[str], progress_callback=None) -> List[Dict]:
        """Classify up to ``config.BATCH_SIZE_LIMIT`` texts in mini-batches.

        Args:
            texts: Reviews to classify; entries beyond the limit are silently dropped.
            progress_callback: Optional callable receiving the fraction of texts
                covered so far, invoked before each mini-batch runs.

        Returns:
            One dict per processed text with ``text`` (first 50 chars plus
            ``'...'`` when longer), ``full_text``, ``sentiment``,
            ``confidence``, ``pos_prob``, ``neg_prob`` and up to five
            ``key_words``.
        """
        if len(texts) > config.BATCH_SIZE_LIMIT:
            texts = texts[:config.BATCH_SIZE_LIMIT]

        results = []
        batch_size = config.BATCH_PROCESSING_SIZE
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            if progress_callback:
                progress_callback((start + len(batch)) / len(texts))

            inputs = self._encode(batch)
            with torch.no_grad():
                outputs = self.model_manager.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()

            for text, prob in zip(batch, probs):
                sentiment = "Positive" if prob[1] > prob[0] else "Negative"
                # Keywords are extracted per text, which avoids requesting every
                # layer's attention maps for the whole padded mini-batch.
                key_words = self.extract_key_words(text, top_k=5)
                results.append({
                    'text': text[:50] + '...' if len(text) > 50 else text,
                    'full_text': text,
                    'sentiment': sentiment,
                    'confidence': float(prob.max()),
                    'pos_prob': float(prob[1]),
                    'neg_prob': float(prob[0]),
                    'key_words': key_words
                })
        return results
# Unified Visualization System
class PlotFactory:
    """Builders for the matplotlib figures shown in the UI.

    Every builder creates its figure through ``managed_figure``, which closes
    the figure in pyplot before it is returned. Callers are expected to render
    the returned Figure object directly (e.g. via ``savefig``).
    """

    @staticmethod
    def create_sentiment_bars(probs: np.ndarray, theme: ThemeContext) -> plt.Figure:
        """Bar chart of the two class probabilities.

        Args:
            probs: ``[negative_prob, positive_prob]``; order must match the bar labels.
            theme: Supplies the negative/positive bar colours.
        """
        with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
            ax = fig.add_subplot(111)
            labels = ["Negative", "Positive"]
            colors = [theme.colors['neg'], theme.colors['pos']]
            bars = ax.bar(labels, probs, color=colors, alpha=0.8)
            ax.set_title("Sentiment Probabilities", fontweight='bold')
            ax.set_ylabel("Probability")
            ax.set_ylim(0, 1)
            # Print each probability just above its bar
            for bar, prob in zip(bars, probs):
                ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
                        f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')
            fig.tight_layout()
            return fig

    @staticmethod
    def create_confidence_gauge(confidence: float, sentiment: str, theme: ThemeContext) -> plt.Figure:
        """Half-band gauge (negative left, positive right) with a confidence needle."""
        with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
            ax = fig.add_subplot(111)
            # Background band across x in [0, pi]: left half negative colour, right half positive
            theta = np.linspace(0, np.pi, 100)
            colors = [theme.colors['neg'] if i < 50 else theme.colors['pos'] for i in range(100)]
            for i in range(len(theta) - 1):
                ax.fill_between([theta[i], theta[i + 1]], [0, 0], [0.8, 0.8],
                                color=colors[i], alpha=0.7)
            # Needle moves from the centre (pi/2) toward the predicted side,
            # scaled by confidence
            pos = np.pi * (0.5 + (0.4 if sentiment == 'Positive' else -0.4) * confidence)
            ax.plot([pos, pos], [0, 0.6], 'k-', linewidth=6)
            ax.plot(pos, 0.6, 'ko', markersize=10)
            ax.set_xlim(0, np.pi)
            ax.set_ylim(0, 1)
            ax.set_title(f'{sentiment} - Confidence: {confidence:.3f}', fontweight='bold')
            # NOTE: these tick labels are hidden by axis('off') below
            ax.set_xticks([0, np.pi / 2, np.pi])
            ax.set_xticklabels(['Negative', 'Neutral', 'Positive'])
            ax.axis('off')
            fig.tight_layout()
            return fig

    @staticmethod
    def create_keyword_chart(key_words: List[Tuple[str, float]], sentiment: str,
                             theme: ThemeContext) -> Optional[plt.Figure]:
        """Horizontal bar chart of ``(word, score)`` pairs, or ``None`` if there are none."""
        if not key_words:
            return None
        with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
            ax = fig.add_subplot(111)
            words = [word for word, _ in key_words]
            scores = [score for _, score in key_words]
            color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
            bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
            ax.set_yticks(range(len(words)))
            ax.set_yticklabels(words)
            ax.set_xlabel('Attention Weight')
            ax.set_title(f'Top Contributing Words ({sentiment})', fontweight='bold')
            # Value label just past the end of each bar
            for bar, score in zip(bars, scores):
                ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height() / 2.,
                        f'{score:.3f}', ha='left', va='center', fontsize=9)
            # Input is ranked best-first; invert so the top word is drawn at the top
            ax.invert_yaxis()
            ax.grid(axis='x', alpha=0.3)
            fig.tight_layout()
            return fig

    @staticmethod
    def create_wordcloud(text: str, sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
        """Word cloud of *text*, green for Positive and red otherwise.

        Returns ``None`` when *text* has fewer than three whitespace-separated
        tokens, or when WordCloud finds nothing to draw. *theme* is accepted
        for signature parity with the other builders but is not used.
        """
        if len(text.split()) < 3:
            return None
        colormap = 'Greens' if sentiment == 'Positive' else 'Reds'
        try:
            wc = WordCloud(width=800, height=400, background_color='white',
                           colormap=colormap, max_words=30).generate(text)
        except ValueError:
            # WordCloud raises ValueError when no words remain to plot
            # (e.g. every word was one of its stop words)
            return None
        with managed_figure(figsize=config.WORDCLOUD_SIZE) as fig:
            ax = fig.add_subplot(111)
            ax.imshow(wc, interpolation='bilinear')
            ax.axis('off')
            ax.set_title(f'{sentiment} Word Cloud', fontweight='bold')
            fig.tight_layout()
            return fig

    @staticmethod
    def create_batch_analysis(results: List[Dict], theme: ThemeContext) -> plt.Figure:
        """Three-panel batch summary: sentiment pie, confidence histogram, per-review scatter."""
        with managed_figure(figsize=config.FIGURE_SIZE_BATCH) as fig:
            gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

            # Sentiment distribution. Colours are chosen per label, because
            # Counter keys follow first-appearance order rather than a fixed order.
            ax1 = fig.add_subplot(gs[0, 0])
            sent_counts = Counter(r['sentiment'] for r in results)
            labels = list(sent_counts)
            pie_colors = [theme.colors['pos'] if label == 'Positive' else theme.colors['neg']
                          for label in labels]
            ax1.pie([sent_counts[label] for label in labels], labels=labels,
                    autopct='%1.1f%%', colors=pie_colors)
            ax1.set_title('Sentiment Distribution')

            # Confidence histogram
            ax2 = fig.add_subplot(gs[0, 1])
            confs = [r['confidence'] for r in results]
            ax2.hist(confs, bins=8, alpha=0.7, color='skyblue', edgecolor='black')
            ax2.set_title('Confidence Distribution')
            ax2.set_xlabel('Confidence')

            # Positive probability by input position, spanning the bottom row
            ax3 = fig.add_subplot(gs[1, :])
            pos_probs = [r['pos_prob'] for r in results]
            indices = range(len(results))
            colors_scatter = [theme.colors['pos'] if r['sentiment'] == 'Positive'
                              else theme.colors['neg'] for r in results]
            ax3.scatter(indices, pos_probs, c=colors_scatter, alpha=0.7, s=60)
            ax3.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
            ax3.set_title('Sentiment Progression')
            ax3.set_xlabel('Review Index')
            ax3.set_ylabel('Positive Probability')
            return fig
# Unified Data Handler
class DataHandler:
    """Export analysis history to files and extract review text from uploads.

    Methods are static: callers invoke them through an instance
    (``app.data_handler.export_data(...)``) and they need no instance state.
    """

    _CSV_HEADER = ['Timestamp', 'Text', 'Sentiment', 'Confidence', 'Pos_Prob', 'Neg_Prob', 'Key_Words']

    @staticmethod
    def _json_default(obj):
        """Convert NumPy scalars/arrays (anything with ``tolist``/``item``) to built-ins for ``json``."""
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        if hasattr(obj, 'item'):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def export_data(data: List[Dict], format_type: str) -> Tuple[Optional[str], str]:
        """Write *data* to a new temporary ``.csv`` or ``.json`` file.

        Args:
            data: History entries (dicts with the keys produced by the analyzers).
            format_type: ``'csv'`` or ``'json'``.

        Returns:
            ``(path, message)``. ``path`` is ``None`` when nothing was written.
            The file is created with ``delete=False`` and is not removed here.
        """
        if not data:
            return None, "No data to export"
        if format_type not in ('csv', 'json'):
            return None, f"Unsupported export format: {format_type}"

        with tempfile.NamedTemporaryFile(
            mode='w', delete=False, suffix=f'.{format_type}', encoding='utf-8',
            # csv.writer emits its own line terminators; avoid newline translation
            newline='' if format_type == 'csv' else None,
        ) as temp_file:
            if format_type == 'csv':
                writer = csv.writer(temp_file)
                writer.writerow(DataHandler._CSV_HEADER)
                for entry in data:
                    writer.writerow([
                        entry.get('timestamp', ''),
                        entry.get('text', ''),
                        entry.get('sentiment', ''),
                        f"{entry.get('confidence', 0):.4f}",
                        f"{entry.get('pos_prob', 0):.4f}",
                        f"{entry.get('neg_prob', 0):.4f}",
                        "|".join(f"{word}:{score:.3f}"
                                 for word, score in (entry.get('key_words') or [])),
                    ])
            else:
                # Key-word scores may be NumPy scalars, which json cannot encode natively
                json.dump(data, temp_file, indent=2, ensure_ascii=False,
                          default=DataHandler._json_default)
        return temp_file.name, f"Exported {len(data)} entries"

    @staticmethod
    def process_file(file) -> str:
        """Return the review text contained in an uploaded ``.csv`` or text file.

        Args:
            file: A filesystem path (``str``, as Gradio 4 supplies), or a binary
                file-like object with ``read()`` and ``name``. Falsy yields ``""``.

        Returns:
            For ``.csv``: the first header row is skipped, and one line is
            returned per row, taken from the first column whose header contains
            "review" (else column 0). Other files are returned verbatim. If the
            CSV cannot be parsed, the raw content is returned.
        """
        if not file:
            return ""
        if isinstance(file, str):
            name = file
            with open(file, 'rb') as fh:
                raw = fh.read()
        elif hasattr(file, 'read'):
            name = str(getattr(file, 'name', '') or '')
            raw = file.read()
        else:
            name = str(getattr(file, 'name', ''))
            with open(name, 'rb') as fh:
                raw = fh.read()
        content = raw.decode('utf-8') if isinstance(raw, bytes) else raw

        if not name.lower().endswith('.csv'):
            return content

        try:
            # StringIO lets csv handle \r\n and newlines inside quoted fields
            reader = csv.reader(io.StringIO(content))
            headers = next(reader, None)
            if headers is None:
                return ""
            review_idx = next((i for i, h in enumerate(headers) if 'review' in h.lower()), 0)
            return '\n'.join(row[review_idx] for row in reader if len(row) > review_idx)
        except csv.Error:
            return content
# Main Application
class SentimentApp:
    """Application facade used by the Gradio callbacks.

    Owns the sentiment engine, the analysis history and the data exporter, and
    turns engine results into display text and figures.
    """

    def __init__(self):
        self.engine = SentimentEngine()
        self.history = HistoryManager()
        self.data_handler = DataHandler()
        # Example reviews for gr.Examples (one input value per row)
        self.examples = [
            ["The cinematography was stunning but the plot was predictable and lacked depth."],
            ["A masterpiece! Powerful performances and unforgettable scenes throughout."],
            ["Boring from start to finish with terrible acting and weak plot development."],
            ["Impressive effects but the story was confusing and difficult to follow."],
            ["Absolutely incredible ending - one of the best films in recent years!"]
        ]

    def analyze_single(self, text: str, theme: str = 'default'):
        """Analyze one review, record it in history and build its figures.

        Returns:
            ``(result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot)``.
            The cloud and keyword plots may be ``None``. Blank or ``None`` input
            returns a prompt message with every plot ``None`` and records nothing.
        """
        if not text or not text.strip():
            return "Please enter text", None, None, None, None

        result = self.engine.analyze_single(text)
        self.history.add({
            'text': text[:100],
            'full_text': text,
            **result
        })

        theme_ctx = ThemeContext(theme)
        # [negative, positive] order matches the bar chart's labels
        probs = np.array([result['neg_prob'], result['pos_prob']])
        prob_plot = PlotFactory.create_sentiment_bars(probs, theme_ctx)
        gauge_plot = PlotFactory.create_confidence_gauge(result['confidence'], result['sentiment'], theme_ctx)
        cloud_plot = PlotFactory.create_wordcloud(text, result['sentiment'], theme_ctx)
        keyword_plot = PlotFactory.create_keyword_chart(result['key_words'], result['sentiment'], theme_ctx)

        key_words_str = ", ".join(f"{word}({score:.3f})" for word, score in result['key_words'][:5])
        result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
                       f"Key Words: {key_words_str}")
        return result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot

    def analyze_batch(self, reviews: str, progress=None, theme: str = 'default'):
        """Analyze newline-separated reviews and return the batch summary figure.

        Args:
            reviews: One review per line; blank lines are ignored.
            progress: Optional progress callback forwarded to the engine.
            theme: Colour theme name for the figure (previously fixed to 'default').

        Returns:
            The figure, or ``None`` when fewer than two non-blank reviews are given.
        """
        if not reviews or not reviews.strip():
            return None
        texts = [line.strip() for line in reviews.split('\n') if line.strip()]
        if len(texts) < 2:
            return None

        results = self.engine.analyze_batch(texts, progress)
        for result in results:
            self.history.add(result)
        return PlotFactory.create_batch_analysis(results, ThemeContext(theme))

    def plot_history(self, theme: str = 'default'):
        """Plot positive probability and confidence for every stored analysis.

        Returns:
            ``(figure, status_message)``. The figure is ``None`` when fewer than
            two analyses are stored.
        """
        history = self.history.get_all()
        if len(history) < 2:
            return None, f"Need at least 2 analyses for trends. Current: {len(history)}"

        theme_ctx = ThemeContext(theme)
        with managed_figure(figsize=config.FIGURE_SIZE_BATCH) as fig:
            gs = fig.add_gridspec(2, 1, hspace=0.3)
            indices = list(range(len(history)))
            pos_probs = [item['pos_prob'] for item in history]
            confs = [item['confidence'] for item in history]

            # Sentiment trend: points coloured by which side of 0.5 they fall on
            ax1 = fig.add_subplot(gs[0, 0])
            colors = [theme_ctx.colors['pos'] if p > 0.5 else theme_ctx.colors['neg']
                      for p in pos_probs]
            ax1.scatter(indices, pos_probs, c=colors, alpha=0.7, s=60)
            ax1.plot(indices, pos_probs, alpha=0.5, linewidth=2)
            ax1.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
            ax1.set_title('Sentiment History')
            ax1.set_ylabel('Positive Probability')
            ax1.grid(True, alpha=0.3)

            # Confidence trend
            ax2 = fig.add_subplot(gs[1, 0])
            ax2.bar(indices, confs, alpha=0.7, color='lightblue', edgecolor='navy')
            ax2.set_title('Confidence Over Time')
            ax2.set_xlabel('Analysis Number')
            ax2.set_ylabel('Confidence')
            ax2.grid(True, alpha=0.3)

            fig.tight_layout()
            return fig, f"History: {len(history)} analyses"
# Gradio Interface Setup
def create_interface():
    """Build the Gradio Blocks UI and bind its events to a new SentimentApp.

    Three tabs: single-review analysis, batch analysis, and history/export.
    Components are created in statement order inside the ``with`` contexts,
    and that order defines the on-screen layout.

    Returns:
        The ``gr.Blocks`` object; the caller is responsible for ``launch()``.
    """
    app = SentimentApp()
    with gr.Blocks(theme=gr.themes.Soft(), title="Movie Sentiment Analyzer") as demo:
        gr.Markdown("# 🎬 AI Movie Sentiment Analyzer")
        gr.Markdown("Optimized sentiment analysis with advanced visualizations and key word extraction")

        # --- Tab 1: analyze a single review ---
        with gr.Tab("Single Analysis"):
            with gr.Row():
                # Left column: input, controls and examples
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Movie Review",
                        placeholder="Enter your movie review...",
                        lines=5
                    )
                    with gr.Row():
                        analyze_btn = gr.Button("Analyze", variant="primary")
                        # Theme choice is also read by the History tab's Refresh button
                        theme_selector = gr.Dropdown(
                            choices=list(config.THEMES.keys()),
                            value="default",
                            label="Theme"
                        )
                    gr.Examples(
                        examples=app.examples,
                        inputs=text_input
                    )
                # Right column: result text and the four figures
                with gr.Column():
                    result_output = gr.Textbox(label="Result", lines=3)
                    with gr.Row():
                        prob_plot = gr.Plot(label="Probabilities")
                        gauge_plot = gr.Plot(label="Confidence")
                    with gr.Row():
                        wordcloud_plot = gr.Plot(label="Word Cloud")
                        keyword_plot = gr.Plot(label="Key Contributing Words")

        # --- Tab 2: batch analysis from pasted text or an uploaded file ---
        with gr.Tab("Batch Analysis"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(label="Upload File", file_types=[".csv", ".txt"])
                    batch_input = gr.Textbox(
                        label="Reviews (one per line)",
                        lines=8
                    )
                with gr.Column():
                    load_btn = gr.Button("Load File")
                    batch_btn = gr.Button("Analyze Batch", variant="primary")
            batch_plot = gr.Plot(label="Batch Results")

        # --- Tab 3: history trends, maintenance and export ---
        with gr.Tab("History & Export"):
            with gr.Row():
                refresh_btn = gr.Button("Refresh")
                clear_btn = gr.Button("Clear", variant="stop")
                status_btn = gr.Button("Status")
            with gr.Row():
                csv_btn = gr.Button("Export CSV")
                json_btn = gr.Button("Export JSON")
            history_status = gr.Textbox(label="Status")
            history_plot = gr.Plot(label="History Trends")
            csv_file = gr.File(label="CSV Download", visible=True)
            json_file = gr.File(label="JSON Download", visible=True)

        # Event bindings
        # The output list order must match the 5-tuple returned by app.analyze_single
        analyze_btn.click(
            app.analyze_single,
            inputs=[text_input, theme_selector],
            outputs=[result_output, prob_plot, gauge_plot, wordcloud_plot, keyword_plot]
        )
        # NOTE(review): process_file/export_data are called through the DataHandler
        # instance, so they must not take `self` (i.e. be static) for these bindings to work
        load_btn.click(app.data_handler.process_file, inputs=file_upload, outputs=batch_input)
        # Only the review text is passed; analyze_batch's other parameters keep their defaults
        batch_btn.click(app.analyze_batch, inputs=batch_input, outputs=batch_plot)
        refresh_btn.click(
            lambda theme: app.plot_history(theme),
            inputs=theme_selector,
            outputs=[history_plot, history_status]
        )
        clear_btn.click(
            lambda: f"Cleared {app.history.clear()} entries",
            outputs=history_status
        )
        status_btn.click(
            lambda: f"History: {app.history.size()} entries",
            outputs=history_status
        )
        # export_data returns (path_or_None, message) -> (file component, status box)
        csv_btn.click(
            lambda: app.data_handler.export_data(app.history.get_all(), 'csv'),
            outputs=[csv_file, history_status]
        )
        json_btn.click(
            lambda: app.data_handler.export_data(app.history.get_all(), 'json'),
            outputs=[json_file, history_status]
        )
    return demo
# Application Entry Point
if __name__ == "__main__":
    # Configure root logging at INFO so the module logger's messages are shown
    logging.basicConfig(level=logging.INFO)
    demo = create_interface()
    # share=True asks Gradio for a public share link.
    # NOTE(review): presumably has no effect when hosted on Hugging Face Spaces — confirm.
    demo.launch(share=True)