import os
from typing import Optional, Tuple

import gradio as gr
import langid
import matplotlib.pyplot as plt
import torch

from model.inference_optimized import OptimizedToxicityClassifier

# Model locations and execution device. Both weight paths can be overridden
# via environment variables, falling back to the bundled defaults.
ONNX_MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "weights/toxic_classifier.onnx")
PYTORCH_MODEL_PATH = os.environ.get(
    "PYTORCH_MODEL_PATH",
    "weights/toxic_classifier_xlm-roberta-large/pytorch_model.bin",
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
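
# For example (hypothetical path, shown for illustration only):
#   ONNX_MODEL_PATH=/opt/models/toxic.onnx python app.py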

# Human-readable names for the languages the app supports; anything else
# detected by langid falls back to English.
SUPPORTED_LANGUAGES = {
    'en': 'English',
    'ru': 'Russian',
    'tr': 'Turkish',
    'es': 'Spanish',
    'fr': 'French',
    'it': 'Italian',
    'pt': 'Portuguese',
}

# Load the classifier once at startup, preferring the ONNX export when it
# exists on disk.
try:
    if os.path.exists(ONNX_MODEL_PATH):
        classifier = OptimizedToxicityClassifier(onnx_path=ONNX_MODEL_PATH, device=DEVICE)
        print(f"Loaded ONNX model from {ONNX_MODEL_PATH}")
    else:
        classifier = OptimizedToxicityClassifier(pytorch_path=PYTORCH_MODEL_PATH, device=DEVICE)
        print(f"Loaded PyTorch model from {PYTORCH_MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    classifier = None
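
# Assumed interface of OptimizedToxicityClassifier, inferred from how this
# file uses it below (not an official spec): predict(texts, langs) returns
# one dict per input with "probabilities" (label -> float), "is_toxic"
# (bool), and "toxic_categories" (list of label strings).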


def detect_language(text: str) -> str:
    """Detect the language of the input text, defaulting to English."""
    try:
        lang, _ = langid.classify(text)
        return lang if lang in SUPPORTED_LANGUAGES else 'en'
    except Exception:
        return 'en'
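
# For instance, detect_language("Eres un completo idiota y nadie te quiere")
# should yield 'es', while text in an unsupported language (e.g. German)
# falls back to 'en'.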


def predict_toxicity(text: str, selected_language: Optional[str] = None) -> Tuple[str, Optional[plt.Figure]]:
    """Predict toxicity of the input text.

    Returns an (html_summary, bar_chart) pair matching the two Gradio output
    components wired up in create_app(); the chart is None on failure.
    """
    if not text or not text.strip():
        return "<div class='error'>Please enter some text to analyze.</div>", None

    if classifier is None:
        return "<div class='error'>Model not loaded. Please check logs.</div>", None

    # Resolve the language: auto-detect, or map the dropdown label back to its code.
    if not selected_language or selected_language == "Auto-detect":
        lang_code = detect_language(text)
        detected = True
    else:
        lang_code = next((code for code, name in SUPPORTED_LANGUAGES.items()
                          if name == selected_language), 'en')
        detected = False

    try:
        results = classifier.predict([text], langs=[lang_code])[0]

        # Rank categories by probability, highest first.
        probs = results["probabilities"]
        sorted_categories = sorted(probs.items(), key=lambda x: x[1], reverse=True)

        # Horizontal bar chart; red marks categories at or above the 50% threshold.
        fig, ax = plt.subplots(figsize=(10, 6))
        labels = [label.replace('_', ' ').title() for label, _ in sorted_categories]
        values = [prob * 100 for _, prob in sorted_categories]
        colors = ['#ff6b6b' if val >= 50 else '#74c0fc' for val in values]

        ax.barh(labels, values, color=colors)
        ax.set_xlim(0, 100)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Toxicity Analysis')
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Annotate each bar with its percentage.
        for i, v in enumerate(values):
            ax.text(v + 1, i, f'{v:.1f}%', va='center')

        # Overall verdict banner.
        lang_display = SUPPORTED_LANGUAGES.get(lang_code, lang_code)
        overall_result = "TOXIC" if results["is_toxic"] else "NON-TOXIC"
        result_color = "#ff6b6b" if results["is_toxic"] else "#66d9e8"

        html_result = f"""
        <div style='margin-bottom: 20px;'>
            <h2>Analysis Result: <span style='color: {result_color};'>{overall_result}</span></h2>
            <h3>Language: {lang_display} {'(detected)' if detected else ''}</h3>
        </div>
        <div style='margin-bottom: 10px;'>
            <table width='100%' style='border-collapse: collapse;'>
                <tr style='background-color: #e9ecef; font-weight: bold;'>
                    <th style='padding: 8px; text-align: left; border: 1px solid #dee2e6;'>Category</th>
                    <th style='padding: 8px; text-align: right; border: 1px solid #dee2e6;'>Probability</th>
                    <th style='padding: 8px; text-align: center; border: 1px solid #dee2e6;'>Status</th>
                </tr>
        """

        # One table row per category, flagging anything at or above 50%.
        for label, prob in sorted_categories:
            formatted_label = label.replace('_', ' ').title()
            status = "DETECTED" if prob >= 0.5 else "Not Detected"
            status_color = "#ff6b6b" if prob >= 0.5 else "#66d9e8"
            prob_percent = f"{prob * 100:.1f}%"

            html_result += f"""
                <tr>
                    <td style='padding: 8px; border: 1px solid #dee2e6;'>{formatted_label}</td>
                    <td style='padding: 8px; text-align: right; border: 1px solid #dee2e6;'>{prob_percent}</td>
                    <td style='padding: 8px; text-align: center; border: 1px solid #dee2e6; color: {status_color}; font-weight: bold;'>{status}</td>
                </tr>
            """

        html_result += "</table></div>"

        # Summarize the toxic categories, if any.
        if results["is_toxic"]:
            toxic_categories = [cat.replace('_', ' ').title() for cat in results["toxic_categories"]]
            categories_list = ", ".join(toxic_categories)
            html_result += f"""
            <div style='margin-top: 10px;'>
                <p><strong>Detected toxic categories:</strong> {categories_list}</p>
            </div>
            """

        return html_result, fig

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"<div class='error'>Error processing text: {str(e)}</div>", None
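
# A minimal sketch of calling the predictor directly, outside the UI
# (assumes the weight files referenced above are present):
#
#   html, fig = predict_toxicity("You are such an idiot", "Auto-detect")
#   fig.savefig("toxicity.png")  # fig is None when prediction fails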


def create_app():
    """Create and configure the Gradio interface."""
    language_options = ["Auto-detect"] + list(SUPPORTED_LANGUAGES.values())

    with gr.Blocks(css="""
        .error { color: #ff6b6b; font-weight: bold; padding: 10px; border: 1px solid #ff6b6b; }
        .container { margin: 0 auto; max-width: 900px; }
        .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
        .example-text { font-style: italic; color: #666; }
    """) as app:
        gr.Markdown("""
        # Multilingual Toxic Comment Classifier
        This app analyzes text for different types of toxicity across multiple languages.
        Enter your text, select a language (or let it auto-detect), and click 'Analyze'.

        Supported languages: English, Russian, Turkish, Spanish, French, Italian, Portuguese
        """)

        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter text to analyze",
                    placeholder="Type or paste text here...",
                    lines=5
                )
                lang_dropdown = gr.Dropdown(
                    choices=language_options,
                    value="Auto-detect",
                    label="Language"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Example texts:")
                with gr.Accordion("English example"):
                    en_example_btn = gr.Button("Use English example")
                with gr.Accordion("Spanish example"):
                    es_example_btn = gr.Button("Use Spanish example")
                with gr.Accordion("French example"):
                    fr_example_btn = gr.Button("Use French example")

        # Canned examples; each button copies its text into the input box.
        en_example_text = "You are such an idiot, nobody likes your stupid content."
        es_example_text = "Eres un completo idiota y nadie te quiere."
        fr_example_text = "Tu es tellement stupide, personne n'aime ton contenu minable."

        en_example_btn.click(lambda: en_example_text, outputs=text_input)
        es_example_btn.click(lambda: es_example_text, outputs=text_input)
        fr_example_btn.click(lambda: fr_example_text, outputs=text_input)

        result_html = gr.HTML(label="Analysis Result")
        plot_output = gr.Plot(label="Toxicity Probabilities")

        # Run the classifier on button click or textbox submit;
        # predict_toxicity returns (html, figure) to match these two outputs.
        analyze_btn.click(
            predict_toxicity,
            inputs=[text_input, lang_dropdown],
            outputs=[result_html, plot_output]
        )
        text_input.submit(
            predict_toxicity,
            inputs=[text_input, lang_dropdown],
            outputs=[result_html, plot_output]
        )

        gr.Markdown("""
        ### About this model
        This model classifies text into six toxicity categories:
        - **Toxic**: General toxicity
        - **Severe Toxic**: Extreme toxicity
        - **Obscene**: Obscene content
        - **Threat**: Threatening content
        - **Insult**: Insulting content
        - **Identity Hate**: Identity-based hate

        Built using XLM-RoBERTa with language-aware fine-tuning.
        """)

    return app


if __name__ == "__main__":
    app = create_app()
    # share=True also requests a temporary public gradio.live link in
    # addition to serving locally on port 7860.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )