Spaces:
Sleeping
Sleeping
| """ | |
| WASH CFM Topic Classification Gradio Application | |
| This application provides a user interface for classifying WASH (Water, Sanitation, | |
| and Hygiene) feedback using a fine-tuned ModernBERT model. | |
| This is a Gradio implementation with identical functionality to wash_cfm_app.py. | |
| """ | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| from huggingface_hub import snapshot_download, hf_hub_download | |
| import functools | |
| import os | |
| import tempfile | |
# ================================
# CONFIGURATION SECTION
# ================================
# The defaults below are the values that used to be hard-coded here. Each one
# can be overridden per deployment through an environment variable, so pointing
# the app at another repository or cache location needs no code change.
# `or` (rather than a .get() default) also falls back when the variable is set
# but empty.
HF_REPO_ID = (
    os.environ.get("WASH_CFM_REPO_ID") or "ibagur/wash_cfm_classifier"
)  # Hugging Face repository that holds the classifier
HF_MODEL_CACHE_DIR = (
    os.environ.get("WASH_CFM_MODEL_CACHE_DIR") or "/tmp/model_cache"
)  # Cache directory (using /tmp for better Space compatibility)
# ================================
@functools.lru_cache(maxsize=1)
def load_model():
    """
    Load the WASH CFM classifier from Hugging Face Hub and wrap it in a pipeline.

    The repository snapshot is downloaded into ``HF_MODEL_CACHE_DIR`` at runtime
    when it is not cached there yet. The finished pipeline is memoized with
    ``functools.lru_cache``, so the download check, weight loading and device
    placement run once per process instead of on every interaction. A call that
    raises is not memoized (only return values are cached), so a later call
    retries the download.

    Returns:
        pipeline: Hugging Face transformers pipeline for text classification

    Raises:
        Exception: Any error raised while downloading or loading the model is
            re-raised unchanged after troubleshooting hints are printed.
    """
    print(f"Downloading model from Hugging Face Hub: {HF_REPO_ID}")
    print("This may take a few minutes on first run...")
    try:
        # Fetch the whole repository in one call rather than file by file.
        # `resume_download` is deliberately not passed: huggingface_hub has
        # deprecated it and resumes interrupted downloads on its own.
        model_path = snapshot_download(
            repo_id=HF_REPO_ID,
            cache_dir=HF_MODEL_CACHE_DIR,
            local_files_only=False,  # Allow network access when not cached
        )
        print(f"Model downloaded successfully to: {model_path}")

        # Load tokenizer and model from the downloaded snapshot
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)

        # Inference only: switch off dropout and other training-time behavior
        model.eval()

        # Pick the device: Apple Silicon (MPS) first, then NVIDIA (CUDA), else CPU
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        print(f"Using device: {device}")
        model.to(device)

        # The pipeline bundles tokenization, the forward pass and score
        # post-processing behind a single call
        return pipeline(
            'text-classification',
            model=model,
            tokenizer=tokenizer,
            device=device,
        )
    except Exception as e:
        print(f"Error downloading model: {str(e)}")
        print("\nTroubleshooting steps:")
        print("1. Check that your repository ID is correct")
        print("2. Ensure the repository is public or you have proper access")
        print("3. Check your internet connection")
        print("4. Verify the repository exists on Hugging Face Hub")
        raise
def predict_topics(text, classifier, top_k=2):
    """
    Predict the top-k most probable topics for the given text using the pipeline.

    Args:
        text (str): Input feedback text
        classifier: Hugging Face transformers text-classification pipeline
        top_k (int): Number of top predictions to return

    Returns:
        list: List of tuples (topic_name, probability), in the order the
            pipeline returns them
    """
    # truncation=True is forwarded to the tokenizer so that feedback longer
    # than the model's maximum sequence length is cut to fit instead of
    # failing inside the forward pass.
    predictions = classifier(text, top_k=top_k, truncation=True)

    # Reduce each {'label': ..., 'score': ...} dict to a (label, score) tuple
    return [(pred['label'], pred['score']) for pred in predictions]
def classify_feedback(text):
    """
    Main classification handler for Gradio interface.

    Blank input yields a warning box without touching the model. Otherwise the
    classifier is loaded, the top two topics are predicted, and one result card
    is rendered per prediction. Failures never propagate to the caller: they
    are rendered as an error box instead.

    Args:
        text (str): Input WASH feedback text

    Returns:
        str: HTML formatted prediction results, warning, or error message
    """
    # Imported locally so this handler does not depend on a file-level import.
    from html import escape

    # Validate input: None, empty and whitespace-only text are all rejected
    if not text or not text.strip():
        return """
        <div style="
        background-color: #fff3cd;
        color: #856404;
        padding: 15px;
        border-radius: 8px;
        border-left: 4px solid #ffc107;
        font-weight: 500;
        ">
        β οΈ Please enter some feedback text.
        </div>
        """

    try:
        # Load classifier pipeline
        classifier = load_model()

        # Get predictions
        predictions = predict_topics(text, classifier, top_k=2)

        # Collect the fragments in a list and join once at the end instead of
        # growing a string with repeated +=
        fragments = [
            """
            <div style="margin-top: 10px;">
            <h3 style="color: #333; margin-bottom: 15px;">π Predicted Topics</h3>
            """
        ]
        for rank, (topic, probability) in enumerate(predictions, 1):
            # The label comes from the model configuration; escape it so that
            # characters such as < or & cannot break the surrounding markup.
            safe_topic = escape(str(topic))
            fragments.append(
                f"""
                <div style="
                background-color: #009999 !important;
                color: #ffffff !important;
                padding: 15px;
                border-radius: 8px;
                margin-bottom: 10px;
                font-weight: 500;
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                ">
                <div style="
                font-size: 16px;
                margin-bottom: 5px;
                color: #ffffff !important;
                font-weight: 600;
                ">
                {rank}. {safe_topic}
                </div>
                <div style="
                font-size: 14px;
                opacity: 0.9;
                color: #ffffff !important;
                ">
                Confidence: {probability:.1%}
                </div>
                </div>
                """
            )
        fragments.append("</div>")
        return "".join(fragments)
    except FileNotFoundError:
        return """
        <div style="
        background-color: #f8d7da;
        color: #721c24;
        padding: 15px;
        border-radius: 8px;
        border-left: 4px solid #dc3545;
        ">
        <strong>β Error loading model</strong><br>
        Could not download or access the model from Hugging Face Hub.<br>
        Please check your internet connection and repository configuration.
        </div>
        """
    except Exception as e:
        # Exception text can contain arbitrary characters (including echoes of
        # user input), so it is escaped before being embedded in the HTML.
        return f"""
        <div style="
        background-color: #f8d7da;
        color: #721c24;
        padding: 15px;
        border-radius: 8px;
        border-left: 4px solid #dc3545;
        ">
        <strong>β Error during prediction:</strong><br>
        {escape(str(e))}
        </div>
        """
def clear_inputs():
    """
    Reset the feedback textbox and the results area.

    Returns:
        tuple: Two empty strings (textbox value, output value)
    """
    blank = ""
    return blank, blank
def create_interface():
    """
    Build the Gradio Blocks layout and wire up its events.

    The layout is, top to bottom: a Markdown header, the feedback textbox, a
    row with the Submit and Clear buttons, the HTML results area and a
    Markdown footer. Clicking Submit and submitting the textbox both run
    ``classify_feedback``; Clear runs ``clear_inputs`` on the textbox and the
    results area.

    Returns:
        gr.Blocks: Configured Gradio interface
    """
    header_markdown = """
    # π§ WASH CFM Topic Classifier
    This application classifies WASH (Water, Sanitation, and Hygiene) feedback
    into relevant topic categories using a fine-tuned ModernBERT model.
    **Enter your feedback below and click Submit.**
    """
    footer_markdown = """
    ---
    <div style="text-align: center; color: #666; font-size: 12px;">
    Powered by ModernBERT-large | UNICEF WASH Cluster CFM System
    </div>
    """

    with gr.Blocks(
        title="WASH CFM Topic Classifier",
        theme=gr.themes.Soft()
    ) as demo:
        # Components are created in display order; do not reorder them.
        gr.Markdown(header_markdown)

        feedback_box = gr.Textbox(
            label="Enter WASH feedback:",
            placeholder="Example: The water pump in our area has been broken for 3 days...",
            lines=6,
            interactive=True
        )

        with gr.Row():
            submit_button = gr.Button("β Submit", variant="primary", scale=2)
            clear_button = gr.Button("ποΈ Clear", scale=1)

        results_panel = gr.HTML(label="Results")

        gr.Markdown(footer_markdown)

        # Both the Submit button and the textbox's own submit event run the
        # same classification call, registered in the original order.
        for register in (submit_button.click, feedback_box.submit):
            register(
                fn=classify_feedback,
                inputs=feedback_box,
                outputs=results_panel
            )

        # Clear takes no inputs and blanks both the textbox and the results
        clear_button.click(
            fn=clear_inputs,
            inputs=None,
            outputs=[feedback_box, results_panel]
        )

    return demo
def main():
    """
    Build the interface and start the Gradio server.

    The server listens on all network interfaces on port 7860 and does not
    create a public share link.
    """
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    create_interface().launch(**launch_options)


if __name__ == "__main__":
    main()