import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from setfit import SetFitModel
import json
import logging
from typing import List, Dict, Any
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum characters of email text fed to the model. Applies to BOTH the
# single-email and batch paths (previously the batch path silently truncated
# at 1000 chars while the single path allowed 5000 — now unified).
MAX_EMAIL_CHARS = 5000

# Maximum number of emails accepted by one batch API call.
MAX_BATCH_SIZE = 100

# Global model state, populated by load_model() at startup.
model = None
tokenizer = None   # kept for interface compatibility; unused by the SetFit path
classifier = None  # kept for interface compatibility; unused by the SetFit path


def load_model() -> bool:
    """Load the trained SetFit email classifier into the global ``model``.

    Reads an optional ``HF_TOKEN`` environment variable for private-repo
    access. Returns True on success, False on failure; errors are logged,
    never raised, so the app can fall back to the dummy classifier.
    """
    global model
    try:
        model_name = "Tomiwajin/setfit_email_classifier"
        token = os.getenv("HF_TOKEN")
        # NOTE(review): `use_auth_token` is deprecated in recent
        # huggingface_hub releases in favor of `token=`; kept for
        # compatibility with the pinned setfit version. Passing True when no
        # env token is set falls back to the locally cached HF credentials.
        model = SetFitModel.from_pretrained(
            model_name,
            use_auth_token=token if token else True,
        )
        logger.info(f"Model {model_name} loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False


def classify_single_email(email_text: str) -> Dict[str, Any]:
    """Classify one email.

    Returns ``{"label", "score", "success": True}`` on success, or
    ``{"error": ..., "success": False}`` on failure — never raises, so the
    Gradio/API callers always get a JSON-serializable dict.
    """
    if model is None:
        return {"error": "Model not loaded"}

    try:
        # Clean and truncate so arbitrarily long emails can't blow up inference.
        email_text = email_text.strip()[:MAX_EMAIL_CHARS]

        predictions = model.predict([email_text])
        # predict_proba returns one row per input; take the only row.
        probabilities = model.predict_proba([email_text])[0]

        predicted_label = predictions[0]
        # Confidence is the highest class probability.
        confidence = max(probabilities)

        return {
            "label": str(predicted_label),
            "score": round(float(confidence), 4),
            "success": True,
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        return {"error": str(e), "success": False}


def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
    """Classify multiple emails; returns one result dict per input email."""
    if model is None:
        return [{"error": "Model not loaded"}] * len(emails)

    try:
        # Same cleaning/truncation rule as the single-email path
        # (was 1000 chars here vs 5000 for single emails — now unified).
        cleaned_emails = [email.strip()[:MAX_EMAIL_CHARS] for email in emails]

        predictions = model.predict(cleaned_emails)
        probabilities = model.predict_proba(cleaned_emails)

        return [
            {
                "label": str(pred),
                "score": round(float(max(probs)), 4),
                "success": True,
            }
            for pred, probs in zip(predictions, probabilities)
        ]
    except Exception as e:
        logger.error(f"Batch classification error: {e}")
        # Mirror the per-email result shape so callers can zip with inputs.
        return [{"error": str(e), "success": False}] * len(emails)


def gradio_classify(email_text: str) -> str:
    """Gradio UI handler: render a single-email classification as Markdown."""
    if not email_text.strip():
        return "Please enter some email text to classify."

    result = classify_single_email(email_text)

    if result.get("success"):
        return f"""
**Classification Result:**
- **Label:** {result['label']}
- **Confidence:** {result['score']:.2%}
"""
    else:
        return f"**Error:** {result.get('error', 'Unknown error')}"


def api_classify(email_text: str) -> Dict[str, Any]:
    """API endpoint: classify one email, returning the raw result dict."""
    return classify_single_email(email_text)


def api_classify_batch(emails_json: str) -> str:
    """API endpoint: classify a JSON array of emails, returning a JSON string.

    Validates that the payload is a JSON list no longer than MAX_BATCH_SIZE;
    all failure modes are reported as JSON error objects, never exceptions.
    """
    try:
        emails = json.loads(emails_json)

        if not isinstance(emails, list):
            return json.dumps({"error": "Input must be a JSON array of strings"})

        if len(emails) > MAX_BATCH_SIZE:
            return json.dumps({"error": f"Maximum {MAX_BATCH_SIZE} emails per batch"})

        results = classify_batch_emails(emails)
        return json.dumps({"results": results}, indent=2)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})


# Load model on startup
logger.info("Loading model...")
model_loaded = load_model()

if not model_loaded:
    logger.warning("Model failed to load - using dummy responses")

    # Deliberate redefinition: shadow the real classifier with a stub so the
    # UI stays usable when the model can't be downloaded. (The batch path is
    # intentionally NOT stubbed and will keep returning "Model not loaded".)
    def classify_single_email(email_text: str):  # noqa: F811
        """Fallback stub used when the real model failed to load."""
        return {
            "label": "applied",
            "score": 0.95,
            "success": True,
            "note": "Using dummy classifier",
        }


# Create Gradio interface
with gr.Blocks(title="Email Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📧 Email Classification API")
    gr.Markdown("Classify emails as job-related or other categories using a trained SetFit model.")

    with gr.Tab("Single Email Classification"):
        with gr.Row():
            with gr.Column():
                email_input = gr.Textbox(
                    label="Email Content",
                    placeholder="Paste your email content here (subject + body)...",
                    lines=8,
                    max_lines=20,
                )
                classify_btn = gr.Button("Classify Email", variant="primary")
            with gr.Column():
                result_output = gr.Markdown(label="Classification Result")

        classify_btn.click(
            fn=gradio_classify,
            inputs=email_input,
            outputs=result_output,
        )

    with gr.Tab("API Endpoints"):
        gr.Markdown("""
## API Usage

### Single Email Classification
**POST** `/api/classify`
```json
{
  "email_text": "Your email content here..."
}
```

### Batch Email Classification
**POST** `/api/classify_batch`
```json
["Email 1 content...", "Email 2 content...", "Email 3 content..."]
```

### Example Response
```json
{
  "label": "job",
  "score": 0.9234,
  "success": true
}
```
""")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Test Single API")
                api_input = gr.Textbox(label="Email Text", lines=4)
                api_btn = gr.Button("Test API")
                api_output = gr.JSON(label="API Response")

                api_btn.click(
                    fn=api_classify,
                    inputs=api_input,
                    outputs=api_output,
                )

            with gr.Column():
                gr.Markdown("### Test Batch API")
                batch_input = gr.Textbox(
                    label="JSON Array of Emails",
                    lines=6,
                    placeholder='["Email 1 content", "Email 2 content"]',
                )
                batch_btn = gr.Button("Test Batch API")
                batch_output = gr.Code(label="Batch API Response", language="json")

                batch_btn.click(
                    fn=api_classify_batch,
                    inputs=batch_input,
                    outputs=batch_output,
                )

    with gr.Tab("Model Info"):
        # f-string: JS braces in the snippet below are escaped as {{ }}.
        # BUGFIX: the batch-endpoint example URL previously read
        # "https:https://..." (doubled scheme) — corrected to a valid URL.
        gr.Markdown(f"""
### Model Information
- **Status:** {'✅ Loaded' if model_loaded else '❌ Failed to load'}
- **Model Type:** SetFit Email Classifier
- **Categories:** Job-related emails, Other emails
- **API Base URL:** `https://tomiwajin-email-classifier.hf.space`

### Integration with Next.js
```javascript
// Single email classification
const response = await fetch('https://tomiwajin-email-classifier.hf.space/api/classify', {{
  method: 'POST',
  headers: {{ 'Content-Type': 'application/json' }},
  body: JSON.stringify({{ email_text: emailContent }})
}});
const result = await response.json();

// Batch classification
const batchResponse = await fetch('https://tomiwajin-email-classifier.hf.space/api/classify_batch', {{
  method: 'POST',
  headers: {{ 'Content-Type': 'application/json' }},
  body: JSON.stringify(emailArray)
}});
const batchResults = await batchResponse.json();
```
""")

# Launch the app with API endpoints
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
        share=False,
    )