# NOTE: the lines "Spaces: / Sleeping / Sleeping" below the original header were
# Hugging Face Spaces status-page residue captured with the file, not source code.
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from setfit import SetFitModel
import json
import logging
from typing import List, Dict, Any
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model variable
# Populated by load_model() at startup; stays None if loading fails.
model = None
# NOTE(review): `tokenizer` and `classifier` are declared but never assigned
# anywhere in this file — they (and the torch/transformers imports) look like
# leftovers from an earlier transformers-pipeline approach; confirm before removing.
tokenizer = None
classifier = None
def load_model():
    """Load the trained SetFit email classifier into the module-global `model`.

    Reads an optional Hugging Face access token from the HF_TOKEN environment
    variable. Returns True on success, False on failure; errors are logged,
    never raised, so startup can fall back to a dummy classifier.
    """
    # Only `model` is assigned here; the original also declared `classifier`
    # global but never set it.
    global model
    try:
        model_name = "Tomiwajin/setfit_email_classifier"
        token = os.getenv("HF_TOKEN")
        # Pass the token as-is: None lets public models load anonymously.
        # (The original passed True when HF_TOKEN was unset, which forces
        # authentication and fails when no cached login exists.)
        model = SetFitModel.from_pretrained(
            model_name,
            use_auth_token=token
        )
        logger.info(f"Model {model_name} loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
def classify_single_email(email_text: str) -> Dict[str, Any]:
    """Classify one email; return its predicted label and confidence.

    Returns {"label": ..., "score": ..., "success": True} on success, or
    {"error": ..., "success": False} when prediction fails. When the model
    is not loaded, returns {"error": "Model not loaded"}.
    """
    if not model:
        return {"error": "Model not loaded"}
    try:
        # Trim whitespace and cap the input at 5000 characters.
        text = email_text.strip()[:5000]
        # SetFit predicts on a batch; wrap the single email in a list.
        label = model.predict([text])[0]
        probs = model.predict_proba([text])[0]
        # Confidence is the highest class probability.
        return {
            "label": str(label),
            "score": round(float(max(probs)), 4),
            "success": True,
        }
    except Exception as exc:
        logger.error(f"Classification error: {exc}")
        return {"error": str(exc), "success": False}
def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
    """Classify multiple emails in a single model call.

    Returns one result dict per input email, in order. On any failure the
    same error dict is repeated for every input (best-effort API contract).
    """
    if not model:
        return [{"error": "Model not loaded"}] * len(emails)
    try:
        # NOTE(review): the batch path truncates at 1000 characters while the
        # single-email path allows 5000 — confirm whether that is intentional.
        cleaned_emails = [email.strip()[:1000] for email in emails]
        predictions = model.predict(cleaned_emails)
        probabilities = model.predict_proba(cleaned_emails)
        # Pair each prediction with its probability row; confidence is the
        # max class probability. (Original used enumerate with an unused index.)
        return [
            {
                "label": str(pred),
                "score": round(float(max(probs)), 4),
                "success": True,
            }
            for pred, probs in zip(predictions, probabilities)
        ]
    except Exception as e:
        logger.error(f"Batch classification error: {e}")
        return [{"error": str(e), "success": False}] * len(emails)
def gradio_classify(email_text: str) -> str:
    """Render a single-email classification result as Markdown for the UI."""
    # Guard: nothing to classify.
    if not email_text.strip():
        return "Please enter some email text to classify."

    result = classify_single_email(email_text)

    # Error path first, success path last.
    if not result.get("success"):
        return f"**Error:** {result.get('error', 'Unknown error')}"

    return f"""
**Classification Result:**
- **Label:** {result['label']}
- **Confidence:** {result['score']:.2%}
"""
def api_classify(email_text: str) -> Dict[str, Any]:
    """Single-email API endpoint: thin delegate to the shared classifier."""
    result = classify_single_email(email_text)
    return result
def api_classify_batch(emails_json: str) -> str:
    """Batch API endpoint.

    Parameters:
        emails_json: JSON-encoded array of email strings (max 100).

    Returns:
        A JSON string: {"results": [...]} on success, {"error": "..."} on
        any validation or parsing failure. Never raises.
    """
    try:
        emails = json.loads(emails_json)
        if not isinstance(emails, list):
            return json.dumps({"error": "Input must be a JSON array of strings"})
        # Reject non-string items up front so the classifier never receives
        # numbers/nulls (the original only checked the outer container).
        if not all(isinstance(email, str) for email in emails):
            return json.dumps({"error": "Input must be a JSON array of strings"})
        if len(emails) > 100:  # Limit batch size
            return json.dumps({"error": "Maximum 100 emails per batch"})
        results = classify_batch_emails(emails)
        return json.dumps({"results": results}, indent=2)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})
# Load model on startup
# This runs at import time, before the Gradio UI below is constructed.
logger.info("Loading model...")
model_loaded = load_model()

if not model_loaded:
    logger.warning("Model failed to load - using dummy responses")

    # Shadow the real single-email classifier with a stub so the UI stays
    # usable; gradio_classify/api_classify resolve the name at call time,
    # so they pick up this replacement.
    # NOTE(review): only the single-email path gets a dummy — the batch path
    # (classify_batch_emails) will still report "Model not loaded". Confirm
    # this asymmetry is intended.
    def classify_single_email(email_text: str):
        return {"label": "applied", "score": 0.95, "success": True, "note": "Using dummy classifier"}
# Create Gradio interface
# Three tabs: interactive single classification, API test harness, model info.
with gr.Blocks(title="Email Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📧 Email Classification API")
    gr.Markdown("Classify emails as job-related or other categories using a trained SetFit model.")

    with gr.Tab("Single Email Classification"):
        with gr.Row():
            with gr.Column():
                email_input = gr.Textbox(
                    label="Email Content",
                    placeholder="Paste your email content here (subject + body)...",
                    lines=8,
                    max_lines=20
                )
                classify_btn = gr.Button("Classify Email", variant="primary")
            with gr.Column():
                result_output = gr.Markdown(label="Classification Result")
        classify_btn.click(
            fn=gradio_classify,
            inputs=email_input,
            outputs=result_output
        )

    with gr.Tab("API Endpoints"):
        gr.Markdown("""
## API Usage
### Single Email Classification
**POST** `/api/classify`
```json
{
"email_text": "Your email content here..."
}
```
### Batch Email Classification
**POST** `/api/classify_batch`
```json
["Email 1 content...", "Email 2 content...", "Email 3 content..."]
```
### Example Response
```json
{
"label": "job",
"score": 0.9234,
"success": true
}
```
""")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Test Single API")
                api_input = gr.Textbox(label="Email Text", lines=4)
                api_btn = gr.Button("Test API")
                api_output = gr.JSON(label="API Response")
                api_btn.click(
                    fn=api_classify,
                    inputs=api_input,
                    outputs=api_output
                )
            with gr.Column():
                gr.Markdown("### Test Batch API")
                batch_input = gr.Textbox(
                    label="JSON Array of Emails",
                    lines=6,
                    placeholder='["Email 1 content", "Email 2 content"]'
                )
                batch_btn = gr.Button("Test Batch API")
                batch_output = gr.Code(label="Batch API Response", language="json")
                batch_btn.click(
                    fn=api_classify_batch,
                    inputs=batch_input,
                    outputs=batch_output
                )

    with gr.Tab("Model Info"):
        # f-string: literal JS braces are escaped as {{ }}.
        # Fix: the batch example URL previously read "https:https://..."
        # (doubled scheme), which would break copy-pasted integrations.
        gr.Markdown(f"""
### Model Information
- **Status:** {'✅ Loaded' if model_loaded else '❌ Failed to load'}
- **Model Type:** SetFit Email Classifier
- **Categories:** Job-related emails, Other emails
- **API Base URL:** `https://tomiwajin-email-classifier.hf.space`
### Integration with Next.js
```javascript
// Single email classification
const response = await fetch('https://tomiwajin-email-classifier.hf.space/api/classify', {{
method: 'POST',
headers: {{ 'Content-Type': 'application/json' }},
body: JSON.stringify({{ email_text: emailContent }})
}});
const result = await response.json();
// Batch classification
const batchResponse = await fetch('https://tomiwajin-email-classifier.hf.space/api/classify_batch', {{
method: 'POST',
headers: {{ 'Content-Type': 'application/json' }},
body: JSON.stringify(emailArray)
}});
const batchResults = await batchResponse.json();
```
""")
# Launch the app with API endpoints
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required on HF Spaces)
        server_port=7860,       # standard HF Spaces port
        show_api=True,          # expose the auto-generated API docs/endpoints
        share=False
    )