Tomiwajin committed on
Commit
82d59e5
·
verified ·
1 Parent(s): a2be272

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - HuggingFace Space for Email Classification
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
5
+ from setfit import SetFitModel
6
+ import json
7
+ import logging
8
+ from typing import List, Dict, Any
9
+
10
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level model state. All three start out unset; load_model() is the
# function that assigns `model` and `classifier`.
model = tokenizer = classifier = None
18
+
19
def load_model():
    """Load the trained SetFit model and install the module-level classifier.

    On success the globals ``model`` (the loaded ``SetFitModel``) and
    ``classifier`` are set. ``classifier`` is a callable that accepts one
    string or a list of strings and returns a list of
    ``{"label": ..., "score": ...}`` dicts, one per input text.

    Returns:
        bool: True when the model loaded, False when loading raised (the
        error is logged and the globals are left untouched).
    """
    global model, classifier
    try:
        # Replace with your actual model path/name
        model_name = "Tomiwajin/setfit_email_classifier"

        loaded = SetFitModel.from_pretrained(model_name)

        # A SetFit head is not a transformers model, so it cannot be handed to
        # transformers.pipeline(); predictions go through SetFit's own API.
        # NOTE(review): `labels` is only present on newer SetFit releases and
        # `classes_` only on scikit-learn heads -- confirm against the pinned
        # versions. When neither exists the class index is used as the label.
        label_names = getattr(loaded, "labels", None)
        if label_names is None:
            label_names = getattr(loaded.model_head, "classes_", None)

        def predict(texts):
            """Return the top label and its probability for each text."""
            batch = [texts] if isinstance(texts, str) else list(texts)
            outputs = []
            for row in loaded.predict_proba(batch):
                best = int(row.argmax())
                if label_names is not None:
                    label = str(label_names[best])
                else:
                    label = str(best)
                outputs.append({"label": label, "score": float(row[best])})
            return outputs

        model = loaded
        classifier = predict

        logger.info(f"Model {model_name} loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
36
+
37
def classify_single_email(email_text: str) -> Dict[str, Any]:
    """Classify a single email.

    Args:
        email_text: Raw email content. Surrounding whitespace is stripped and
            only the first 5000 characters are classified.

    Returns:
        On success ``{"label", "score", "success": True}`` with the score
        rounded to four decimals. On any failure (model not loaded, non-string
        input, or an exception during prediction) ``{"error", "success": False}``.
    """
    if not classifier:
        # Carry an explicit failure flag so callers that default a missing
        # "success" to True do not mistake this for a valid result.
        return {"error": "Model not loaded", "success": False}

    if not isinstance(email_text, str):
        return {"error": "email_text must be a string", "success": False}

    try:
        # Clean and truncate text
        cleaned = email_text.strip()[:5000]  # Limit length

        # Get prediction; the classifier may return a list with one entry
        prediction = classifier(cleaned)
        if isinstance(prediction, list):
            prediction = prediction[0]

        return {
            "label": prediction.get("label", "unknown"),
            "score": round(prediction.get("score", 0.0), 4),
            "success": True,
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        return {"error": str(e), "success": False}
60
+
61
def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
    """Classify multiple emails.

    Args:
        emails: Email texts to classify.

    Returns:
        One result dict per input, in input order. When no classifier is
        loaded every entry is its own ``{"error", "success": False}`` dict.
    """
    if not classifier:
        # Build a fresh dict per entry; list multiplication would repeat one
        # shared dict, so mutating a single result would change all of them.
        return [{"error": "Model not loaded", "success": False} for _ in emails]

    return [classify_single_email(email_text) for email_text in emails]
72
+
73
def gradio_classify(email_text: str) -> str:
    """Gradio interface function.

    Args:
        email_text: Text from the email input box. ``None``, empty, and
            whitespace-only values are all treated as missing input.

    Returns:
        A Markdown string: a prompt to enter text, the label and confidence
        of the classification, or an error message.
    """
    # Guard against None as well as blank text so .strip() cannot raise.
    if not email_text or not email_text.strip():
        return "Please enter some email text to classify."

    result = classify_single_email(email_text)

    if not result.get("success"):
        return f"**Error:** {result.get('error', 'Unknown error')}"

    return f"""
        **Classification Result:**
        - **Label:** {result['label']}
        - **Confidence:** {result['score']:.2%}
        """
88
+
89
def api_classify(email_text: str) -> Dict[str, Any]:
    """Return the raw classification dict for one email (used by the JSON test panel)."""
    response = classify_single_email(email_text)
    return response
92
+
93
def api_classify_batch(emails_json: str) -> str:
    """Batch API endpoint function.

    Args:
        emails_json: A JSON document that must decode to an array of strings.

    Returns:
        A JSON string: the indented list of per-email results, or a single
        ``{"error": ...}`` object when the input is not valid JSON, is not an
        array of strings, or classification raises.
    """
    try:
        emails = json.loads(emails_json)

        # Enforce the whole contract named in the error message: the value
        # must be a list and every element must be a string.
        if not isinstance(emails, list) or any(
            not isinstance(item, str) for item in emails
        ):
            return json.dumps({"error": "Input must be a JSON array of strings"})

        results = classify_batch_emails(emails)
        return json.dumps(results, indent=2)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})
106
+
107
# Load the model once, at import time, and remember whether it worked.
logger.info("Loading model...")
model_loaded = load_model()

if not model_loaded:
    # Fallback: replace the real single-email classifier with a stub that
    # always reports the same canned answer.
    logger.warning("Model failed to load - using dummy responses")

    def classify_single_email(email_text: str):
        dummy_result = {
            "label": "job",
            "score": 0.95,
            "success": True,
            "note": "Using dummy classifier",
        }
        return dummy_result
115
+
116
# Build the Gradio UI: three tabs (interactive classifier, API docs with test
# panels, and model information).
with gr.Blocks(title="Email Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📧 Email Classification API")
    gr.Markdown("Classify emails as job-related or other categories using a trained SetFit model.")

    # --- Tab 1: classify one pasted email and show a Markdown summary ---
    with gr.Tab("Single Email Classification"):
        with gr.Row():
            with gr.Column():
                email_input = gr.Textbox(
                    label="Email Content",
                    placeholder="Paste your email content here (subject + body)...",
                    lines=8,
                    max_lines=20,
                )
                classify_btn = gr.Button("Classify Email", variant="primary")

            with gr.Column():
                result_output = gr.Markdown(label="Classification Result")

        classify_btn.click(fn=gradio_classify, inputs=email_input, outputs=result_output)

    # --- Tab 2: endpoint documentation plus two in-browser test panels ---
    with gr.Tab("API Endpoints"):
        gr.Markdown("""
        ## API Usage

        ### Single Email Classification
        **POST** `/api/classify`
        ```json
        {
          "email_text": "Your email content here..."
        }
        ```

        ### Batch Email Classification
        **POST** `/api/classify-batch`
        ```json
        {
          "emails": [
            "Email 1 content...",
            "Email 2 content...",
            "Email 3 content..."
          ]
        }
        ```

        ### Example Response
        ```json
        {
          "label": "job",
          "score": 0.9234,
          "success": true
        }
        ```
        """)

        with gr.Row():
            # Left panel: single-email call, result shown as JSON.
            with gr.Column():
                gr.Markdown("### Test Single API")
                api_input = gr.Textbox(label="Email Text", lines=4)
                api_btn = gr.Button("Test API")
                api_output = gr.JSON(label="API Response")

                api_btn.click(fn=api_classify, inputs=api_input, outputs=api_output)

            # Right panel: batch call taking a JSON array, result shown as code.
            with gr.Column():
                gr.Markdown("### Test Batch API")
                batch_input = gr.Textbox(
                    label="JSON Array of Emails",
                    lines=6,
                    placeholder='["Email 1 content", "Email 2 content"]',
                )
                batch_btn = gr.Button("Test Batch API")
                batch_output = gr.Code(label="Batch API Response", language="json")

                batch_btn.click(fn=api_classify_batch, inputs=batch_input, outputs=batch_output)

    # --- Tab 3: load status and a client-side integration snippet ---
    with gr.Tab("Model Info"):
        gr.Markdown(f"""
        ### Model Information
        - **Status:** {'✅ Loaded' if model_loaded else '❌ Failed to load'}
        - **Model Type:** SetFit Email Classifier
        - **Categories:** Job-related emails, Other emails
        - **API Base URL:** `https://your-space-name.hf.space`

        ### Integration with Next.js
        ```javascript
        // Single email classification
        const response = await fetch('https://your-space-name.hf.space/api/classify', {{
          method: 'POST',
          headers: {{ 'Content-Type': 'application/json' }},
          body: JSON.stringify({{ email_text: emailContent }})
        }});
        const result = await response.json();

        // Batch classification
        const batchResponse = await fetch('https://your-space-name.hf.space/api/classify-batch', {{
          method: 'POST',
          headers: {{ 'Content-Type': 'application/json' }},
          body: JSON.stringify({{ emails: emailArray }})
        }});
        const batchResults = await batchResponse.json();
        ```
        """)
231
+
232
+ # Set up API endpoints
233
def setup_api_routes(app):
    """Register the JSON classification routes on a FastAPI application.

    Adds ``POST /api/classify`` (body: ``{"email_text": str}``) and
    ``POST /api/classify-batch`` (body: ``{"emails": [str, ...]}``, at most
    100 entries).

    Args:
        app: The FastAPI application to attach the routes to.
    """
    from fastapi import HTTPException
    from pydantic import BaseModel

    class EmailRequest(BaseModel):
        email_text: str

    class BatchEmailRequest(BaseModel):
        emails: List[str]

    # The handlers are plain `def`, not `async def`: classification is
    # blocking work, and FastAPI runs sync handlers in a worker thread so the
    # event loop stays free while a request is being classified.
    @app.post("/api/classify")
    def classify_endpoint(request: EmailRequest):
        result = classify_single_email(request.email_text)
        # Any result carrying an "error" is a failure, even when it has no
        # explicit "success" flag.
        if "error" in result or not result.get("success", True):
            raise HTTPException(status_code=500, detail=result.get("error", "Classification failed"))
        return result

    @app.post("/api/classify-batch")
    def classify_batch_endpoint(request: BatchEmailRequest):
        if len(request.emails) > 100:  # Limit batch size
            raise HTTPException(status_code=400, detail="Maximum 100 emails per batch")

        results = classify_batch_emails(request.emails)
        return {"results": results}
258
+
259
# Launch the app
if __name__ == "__main__":
    # The FastAPI application is only available from launch(), so start the
    # server without blocking, attach the custom routes, then block.
    fastapi_app, _local_url, _share_url = demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
        share=False,
        prevent_thread_lock=True,
    )

    # Register the custom routes and move them to the front of the routing
    # table so they are matched before any route Gradio registered itself.
    # NOTE(review): some Gradio versions appear to expose their own
    # `/api/{api_name}` route, which would shadow `/api/classify` if it
    # matched first -- confirm against the pinned Gradio version.
    routes = fastapi_app.router.routes
    first_custom = len(routes)
    setup_api_routes(fastapi_app)
    custom_routes = routes[first_custom:]
    del routes[first_custom:]
    routes[:0] = custom_routes

    # Keep the process alive now that launch() no longer blocks.
    demo.block_thread()
270
+ )