Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| import joblib | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.naive_bayes import MultinomialNB | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import classification_report, accuracy_score | |
| import re | |
| import os | |
| from typing import List, Dict, Any | |
| import logging | |
# Logging setup: request INFO-level output from the root logger and create
# the module-scoped logger used by every function in this file.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# FastAPI application object. The metadata below is what the generated
# OpenAPI pages show; Swagger UI is mounted at /docs and ReDoc at /redoc.
_APP_SETTINGS = {
    "title": "Email Attachment Classifier API",
    "description": "API to classify whether an email has attachments or not using Naive Bayes",
    "version": "1.0.0",
    "docs_url": "/docs",
    "redoc_url": "/redoc",
}
app = FastAPI(**_APP_SETTINGS)
# Add CORS middleware.
# The API is open to every origin, method and header. Credentials (cookies,
# Authorization headers) are deliberately NOT allowed: a wildcard origin
# combined with allow_credentials=True is rejected by the CORS spec and makes
# the middleware echo back whatever Origin the caller sends, which would let
# any website issue credentialed requests. Nothing in this service relies on
# cookies or auth headers, so disabling credentials loses no functionality.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class EmailInput(BaseModel):
    """Request body holding one email message to classify."""

    # Raw email text; it is preprocessed by the handler before prediction.
    message: str
class EmailBatchInput(BaseModel):
    """Request body holding several email messages to classify together."""

    # Raw email texts; the batch handler rejects more than 100 entries.
    messages: List[str]
class PredictionResponse(BaseModel):
    """Classification result for a single email message."""

    # The message exactly as the client sent it (not the preprocessed text).
    message: str
    # Predicted class as an integer (1 is reported as "Has attachment").
    prediction: int
    # Human-readable form: "Has attachment" or "No attachment".
    prediction_label: str
    # Largest of the predicted class probabilities.
    confidence: float
    # Probabilities keyed by "no_attachment" and "has_attachment".
    probabilities: Dict[str, float]
class BatchPredictionResponse(BaseModel):
    """Classification results for a batch of email messages."""

    # One entry per submitted message, in the order they were submitted.
    predictions: List[PredictionResponse]
class ModelInfo(BaseModel):
    """Summary of the classifier currently held by the service."""

    # Display name of the algorithm, e.g. "Multinomial Naive Bayes".
    model_type: str
    # Accuracy on the held-out split when trained in-process; a hard-coded
    # placeholder when the model was loaded from disk.
    accuracy: float
    # Size of the fitted TF-IDF vocabulary.
    feature_count: int
    # Number of training rows; a hard-coded placeholder for loaded models.
    training_samples: int
# Global variables
# Shared module state, filled in by load_pretrained_model() or
# load_and_train_model(). Both names stay None until a model is available;
# the request handlers check for None before serving.
model_pipeline = model_info = None
# Compiled once at import time; both patterns are used on every request.
# Anything that is not a word character, whitespace or one of , . - ! ?
_DISALLOWED_CHARS_RE = re.compile(r'[^\w\s,.\-!?]')
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text: str) -> str:
    """Normalize email text before it is handed to the classifier.

    The text is lower-cased, every character other than word characters,
    whitespace and basic punctuation (``, . - ! ?``) is replaced by a space,
    runs of whitespace are collapsed to a single space, and leading/trailing
    whitespace is removed.

    Args:
        text: Raw email text.

    Returns:
        The normalized text; an empty string if nothing remains.
    """
    text = text.lower()
    # Replace special characters first: each replacement inserts a space, so
    # collapsing whitespace must happen afterwards or the result could still
    # contain runs of spaces (e.g. "a @ b" -> "a   b").
    text = _DISALLOWED_CHARS_RE.sub(' ', text)
    text = _WHITESPACE_RE.sub(' ', text)
    return text.strip()
def load_and_train_model():
    """Train the Naive Bayes classifier and publish it in the module globals.

    Training data is read from ``Synthetic_Email_Dataset.csv`` in the current
    working directory (columns ``message`` and ``label``). If that file is
    missing, a small built-in sample data set is used instead.

    On success the fitted pipeline is stored in ``model_pipeline``, its
    statistics in ``model_info``, and the pipeline is written to
    ``email_classifier_model.pkl`` on a best-effort basis.

    Returns:
        True if a model was trained and published, False if training failed.
        On failure the globals are left untouched.
    """
    global model_pipeline, model_info
    try:
        # Load the dataset (assuming it's in the working directory)
        if os.path.exists('Synthetic_Email_Dataset.csv'):
            df = pd.read_csv('Synthetic_Email_Dataset.csv')
            # Rows without text or label cannot be used and would make the
            # preprocessing step fail on NaN values.
            df = df.dropna(subset=['message', 'label'])
        else:
            logger.warning("Dataset file not found, creating sample data")
            # Sample data for demonstration only: four messages repeated 100
            # times, so the train and test splits contain the same texts and
            # the reported accuracy is not meaningful.
            sample_data = {
                'label': [0, 1, 0, 1] * 100,
                'message': [
                    "Hello, You asked for it, so here is the notes. Warm wishes, David",
                    "Good morning, Just sharing the meeting agenda as requested. Cheers, Anna",
                    "Dear team, As discussed, I'm sending the manual. Regards, Emily",
                    "Hi all, Please find attached the project plan. Thanks, Michael"
                ] * 100
            }
            df = pd.DataFrame(sample_data)

        # Preprocess messages
        df['processed_message'] = df['message'].apply(preprocess_text)

        # Split data (stratified so both classes appear in each split)
        X = df['processed_message']
        y = df['label']
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # Build the pipeline in a local variable: the global must only ever
        # refer to a fitted model, otherwise the health check would report a
        # half-initialised service as healthy.
        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer(
                max_features=1000,
                ngram_range=(1, 2),
                stop_words='english',
                lowercase=True,
                min_df=1,
                max_df=0.95
            )),
            ('classifier', MultinomialNB(alpha=1.0))
        ])

        # Train model
        logger.info("Training Naive Bayes model...")
        pipeline.fit(X_train, y_train)

        # Evaluate model on the held-out split
        accuracy = accuracy_score(y_test, pipeline.predict(X_test))

        info = ModelInfo(
            model_type="Multinomial Naive Bayes",
            accuracy=round(float(accuracy), 4),
            feature_count=len(pipeline.named_steps['tfidf'].vocabulary_),
            training_samples=len(X_train)
        )
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error
        # message alone does not provide.
        logger.exception("Error in training model: %s", e)
        return False

    # Training succeeded: publish the model.
    model_pipeline, model_info = pipeline, info
    logger.info(f"Model trained successfully with accuracy: {accuracy:.4f}")
    logger.info(f"Feature count: {model_info.feature_count}")

    # Persisting is best-effort: a read-only or full disk must not discard a
    # model that is already usable in memory.
    try:
        joblib.dump(pipeline, 'email_classifier_model.pkl')
        logger.info("Model saved successfully")
    except Exception as e:
        logger.warning("Model trained but could not be saved: %s", e)
    return True
def load_pretrained_model():
    """Load a previously saved pipeline from ``email_classifier_model.pkl``.

    On success the pipeline is stored in the ``model_pipeline`` global. If
    ``model_info`` has not been set yet, it is filled with the vocabulary
    size of the loaded model plus placeholder values for accuracy and
    training-sample count (the saved file does not carry those statistics).

    Returns:
        True if a model was loaded, False if the file does not exist or
        could not be loaded. On failure the globals are left untouched.
    """
    global model_pipeline, model_info
    if not os.path.exists('email_classifier_model.pkl'):
        return False
    try:
        # NOTE(security): joblib.load unpickles the file and can therefore
        # execute arbitrary code; only load model files this service wrote
        # itself or that come from a trusted source.
        pipeline = joblib.load('email_classifier_model.pkl')
        # Also serves as a sanity check that the object is the expected
        # pipeline with a fitted 'tfidf' step before it is published.
        feature_count = len(pipeline.named_steps['tfidf'].vocabulary_)
    except Exception as e:
        logger.exception("Error loading pretrained model: %s", e)
        return False

    model_pipeline = pipeline
    logger.info("Pretrained model loaded successfully")
    # Set default model info if not available
    if model_info is None:
        model_info = ModelInfo(
            model_type="Multinomial Naive Bayes",
            accuracy=0.92,  # Default value, not measured on this model
            feature_count=feature_count,
            training_samples=320  # Default value, not read from the file
        )
    return True
# Registered as a startup hook; without the registration the function is
# never called and the service would start with no model loaded.
@app.on_event("startup")
async def startup_event():
    """Initialize the model when the application starts.

    A saved model is preferred; if none can be loaded, a new one is trained.
    If both attempts fail the error is logged and the app still starts, with
    the prediction endpoints answering 503 until a model is available.
    """
    logger.info("Starting Email Classifier API...")
    # Try the pretrained model first; train only when that fails.
    if load_pretrained_model() or load_and_train_model():
        return
    logger.error("Failed to initialize model")
# Registered at "/" so the landing page is actually served; HTMLResponse
# makes the OpenAPI schema advertise text/html for this route.
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a static HTML landing page describing the API.

    Returns:
        An HTMLResponse (status 200) listing the available endpoints, the
        interactive documentation links and the meaning of the labels.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Email Attachment Classifier API</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 40px; }
            .header { color: #2c3e50; }
            .endpoint { background-color: #f8f9fa; padding: 15px; margin: 10px 0; border-radius: 5px; }
            .method { color: #27ae60; font-weight: bold; }
            code { background-color: #e9ecef; padding: 2px 4px; border-radius: 3px; }
        </style>
    </head>
    <body>
        <h1 class="header">📧 Email Attachment Classifier API</h1>
        <p>This API classifies whether an email message indicates an attachment or not using Naive Bayes classifier.</p>
        <h2>Available Endpoints:</h2>
        <div class="endpoint">
            <h3><span class="method">GET</span> /info</h3>
            <p>Get model information and statistics</p>
        </div>
        <div class="endpoint">
            <h3><span class="method">POST</span> /predict</h3>
            <p>Predict single email message</p>
            <p><strong>Body:</strong> <code>{"message": "Your email content here"}</code></p>
        </div>
        <div class="endpoint">
            <h3><span class="method">POST</span> /predict-batch</h3>
            <p>Predict multiple email messages</p>
            <p><strong>Body:</strong> <code>{"messages": ["Email 1", "Email 2", ...]}</code></p>
        </div>
        <div class="endpoint">
            <h3><span class="method">GET</span> /health</h3>
            <p>Check API health status</p>
        </div>
        <h2>Interactive Documentation:</h2>
        <p>Visit <a href="/docs">/docs</a> for Swagger UI or <a href="/redoc">/redoc</a> for ReDoc</p>
        <h2>Labels:</h2>
        <ul>
            <li><strong>0:</strong> No attachment mentioned</li>
            <li><strong>1:</strong> Attachment mentioned</li>
        </ul>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)
# Registered at the path the landing page documents for the health check.
@app.get("/health")
async def health_check():
    """Report whether the service has a model available.

    Returns:
        A dict with ``status`` ("healthy" or "unhealthy") and a short
        ``message``. The HTTP status is 200 in both cases; callers must
        inspect the ``status`` field.
    """
    if model_pipeline is None:
        return {"status": "unhealthy", "message": "Model not loaded"}
    return {"status": "healthy", "message": "API is running"}
# Registered at the path the landing page documents for model statistics.
@app.get("/info", response_model=ModelInfo)
async def get_model_info():
    """Return the statistics of the currently active model.

    Returns:
        The ModelInfo stored by the model loader/trainer.

    Raises:
        HTTPException: 503 if no model information has been set yet.
    """
    if model_info is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    return model_info
# Registered at the path the landing page documents for single predictions.
@app.post("/predict", response_model=PredictionResponse)
async def predict_single(email: EmailInput):
    """Classify a single email message.

    Args:
        email: Request body with the raw message text.

    Returns:
        A PredictionResponse echoing the original message together with the
        predicted class, its label, the confidence and both probabilities.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if prediction fails.
    """
    if model_pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # The model was trained on preprocessed text, so apply the same
        # normalization to the input.
        processed_message = preprocess_text(email.message)

        prediction = model_pipeline.predict([processed_message])[0]
        probabilities = model_pipeline.predict_proba([processed_message])[0]

        # NOTE(review): index 0/1 is taken to be label 0/1, which holds when
        # the training labels are exactly 0 and 1 - confirm for custom data.
        return PredictionResponse(
            message=email.message,
            prediction=int(prediction),
            prediction_label="Has attachment" if prediction == 1 else "No attachment",
            confidence=float(max(probabilities)),
            probabilities={
                "no_attachment": float(probabilities[0]),
                "has_attachment": float(probabilities[1]),
            },
        )
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
# Registered at the path the landing page documents for batch predictions.
@app.post("/predict-batch", response_model=BatchPredictionResponse)
async def predict_batch(emails: EmailBatchInput):
    """Classify several email messages in one call.

    Args:
        emails: Request body with the list of raw message texts.

    Returns:
        A BatchPredictionResponse with one PredictionResponse per message,
        in input order. An empty input list yields an empty result.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if more than 100
            messages are submitted; 500 if prediction fails.
    """
    if model_pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if len(emails.messages) > 100:
        raise HTTPException(status_code=400, detail="Maximum 100 messages per batch")
    if not emails.messages:
        # The estimator rejects zero-sample input, which would otherwise
        # surface as a 500; an empty batch simply has no predictions.
        return BatchPredictionResponse(predictions=[])
    try:
        # Same normalization as used for training.
        processed_messages = [preprocess_text(msg) for msg in emails.messages]

        batch_predictions = model_pipeline.predict(processed_messages)
        batch_probabilities = model_pipeline.predict_proba(processed_messages)

        # NOTE(review): index 0/1 is taken to be label 0/1, which holds when
        # the training labels are exactly 0 and 1 - confirm for custom data.
        predictions = [
            PredictionResponse(
                message=message,
                prediction=int(prediction),
                prediction_label="Has attachment" if prediction == 1 else "No attachment",
                confidence=float(max(probabilities)),
                probabilities={
                    "no_attachment": float(probabilities[0]),
                    "has_attachment": float(probabilities[1]),
                },
            )
            for message, prediction, probabilities in zip(
                emails.messages, batch_predictions, batch_probabilities
            )
        ]
        return BatchPredictionResponse(predictions=predictions)
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: uvicorn is only needed when the file is run
    # directly, so it is imported here rather than at the top of the module.
    import uvicorn

    # Serve the app on all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")