Tomiwajin committed on
Commit
a25b8b3
·
verified ·
1 Parent(s): 3c1d5f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -527
app.py CHANGED
@@ -1,527 +1,236 @@
1
- import gradio as gr
2
- import torch
3
- from setfit import SetFitModel
4
- from transformers import AutoTokenizer, T5ForConditionalGeneration
5
- import json
6
- import logging
7
- import re
8
- from typing import List, Dict, Any
9
- import os
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- # Global model variables
15
- classifier_model = None
16
- extractor_model = None
17
- extractor_tokenizer = None
18
- device = None
19
-
20
- def load_models():
21
- """Load both classification and extraction models"""
22
- global classifier_model, extractor_model, extractor_tokenizer, device
23
-
24
- # Set device (GPU if available)
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- logger.info(f"🖥️ Using device: {device}")
27
-
28
- try:
29
- # Load classifier
30
- classifier_name = "Tomiwajin/testClasifier"
31
- token = os.getenv("HF_TOKEN")
32
-
33
- classifier_model = SetFitModel.from_pretrained(
34
- classifier_name,
35
- use_auth_token=token if token else False
36
- )
37
- logger.info(f"✅ Classifier loaded: {classifier_name}")
38
-
39
- # Load extractor
40
- extractor_name = "Tomiwajin/email-company-role-extractor"
41
- extractor_tokenizer = AutoTokenizer.from_pretrained(extractor_name)
42
- extractor_model = T5ForConditionalGeneration.from_pretrained(extractor_name)
43
- extractor_model.to(device)
44
- extractor_model.eval()
45
- logger.info(f"✅ Extractor loaded: {extractor_name}")
46
-
47
- return True
48
- except Exception as e:
49
- logger.error(f"❌ Model loading failed: {e}")
50
- return False
51
-
52
- def classify_single_email(email_text: str) -> Dict[str, Any]:
53
- """Classify a single email"""
54
- if not classifier_model:
55
- return {"error": "Classifier not loaded", "success": False}
56
-
57
- try:
58
- email_text = email_text.strip()[:1000]
59
- predictions = classifier_model.predict([email_text])
60
- probabilities = classifier_model.predict_proba([email_text])[0]
61
-
62
- predicted_label = predictions[0]
63
- confidence = max(probabilities)
64
-
65
- return {
66
- "label": str(predicted_label),
67
- "score": round(float(confidence), 4),
68
- "success": True
69
- }
70
- except Exception as e:
71
- logger.error(f"Classification error: {e}")
72
- return {"error": str(e), "success": False}
73
-
74
- def extract_job_info(email_text: str) -> Dict[str, Any]:
75
- """Extract company and role from email"""
76
- if not extractor_model or not extractor_tokenizer:
77
- return {"error": "Extractor not loaded", "success": False}
78
-
79
- try:
80
- email_text = email_text.strip()[:1000]
81
- input_text = f"extract company and role: {email_text}"
82
-
83
- inputs = extractor_tokenizer(
84
- input_text,
85
- return_tensors='pt',
86
- max_length=512,
87
- truncation=True
88
- ).to(device)
89
-
90
- with torch.no_grad():
91
- outputs = extractor_model.generate(
92
- inputs.input_ids,
93
- attention_mask=inputs.attention_mask,
94
- max_length=128,
95
- num_beams=2,
96
- early_stopping=True,
97
- pad_token_id=extractor_tokenizer.pad_token_id
98
- )
99
-
100
- prediction = extractor_tokenizer.decode(outputs[0], skip_special_tokens=True)
101
- result = parse_extraction_result(prediction)
102
- return result
103
- except Exception as e:
104
- logger.error(f"Extraction error: {e}")
105
- return {
106
- "company": "unknown",
107
- "role": "unknown",
108
- "success": False,
109
- "error": str(e)
110
- }
111
-
112
- def parse_extraction_result(prediction: str) -> Dict[str, Any]:
113
- """Parse the model output into structured result"""
114
- try:
115
- fixed = prediction.strip()
116
- if fixed.startswith('"') and not fixed.startswith('{'):
117
- fixed = '{' + fixed
118
- if not fixed.endswith('}'):
119
- fixed = fixed + '}'
120
- fixed = re.sub(r'",(\s*)"', '", "', fixed)
121
-
122
- result = json.loads(fixed)
123
- return {
124
- "company": result.get("company", "unknown"),
125
- "role": result.get("role", "unknown"),
126
- "success": True
127
- }
128
- except:
129
- return {
130
- "company": "unknown",
131
- "role": "unknown",
132
- "success": False
133
- }
134
-
135
- def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
136
- """Classify multiple emails - already batched"""
137
- if not classifier_model:
138
- return [{"error": "Model not loaded", "success": False}] * len(emails)
139
-
140
- try:
141
- cleaned_emails = [email.strip()[:1000] for email in emails]
142
- predictions = classifier_model.predict(cleaned_emails)
143
- probabilities = classifier_model.predict_proba(cleaned_emails)
144
-
145
- results = []
146
- for pred, probs in zip(predictions, probabilities):
147
- results.append({
148
- "label": str(pred),
149
- "score": round(float(max(probs)), 4),
150
- "success": True
151
- })
152
-
153
- return results
154
- except Exception as e:
155
- logger.error(f"Batch classification error: {e}")
156
- return [{"error": str(e), "success": False}] * len(emails)
157
-
158
- def extract_batch(emails: List[str]) -> List[Dict[str, Any]]:
159
- """Extract company and role from multiple emails - BATCHED for performance"""
160
- if not extractor_model or not extractor_tokenizer:
161
- return [{"error": "Extractor not loaded", "success": False}] * len(emails)
162
-
163
- if len(emails) == 0:
164
- return []
165
-
166
- try:
167
- # Prepare all inputs at once
168
- cleaned_emails = [email.strip()[:1000] for email in emails]
169
- input_texts = [f"extract company and role: {email}" for email in cleaned_emails]
170
-
171
- # Batch tokenize
172
- inputs = extractor_tokenizer(
173
- input_texts,
174
- return_tensors='pt',
175
- max_length=512,
176
- truncation=True,
177
- padding=True
178
- ).to(device)
179
-
180
- # Batch generate - process all at once
181
- with torch.no_grad():
182
- outputs = extractor_model.generate(
183
- inputs.input_ids,
184
- attention_mask=inputs.attention_mask,
185
- max_length=128,
186
- num_beams=2, # Reduced from 4 for speed
187
- early_stopping=True,
188
- pad_token_id=extractor_tokenizer.pad_token_id
189
- )
190
-
191
- # Decode all at once
192
- predictions = extractor_tokenizer.batch_decode(outputs, skip_special_tokens=True)
193
-
194
- # Parse results
195
- results = [parse_extraction_result(pred) for pred in predictions]
196
- return results
197
-
198
- except Exception as e:
199
- logger.error(f"Batch extraction error: {e}")
200
- return [{"company": "unknown", "role": "unknown", "success": False, "error": str(e)}] * len(emails)
201
-
202
- def process_batch(emails: List[str], job_labels: List[str] = None, threshold: float = 0.5) -> Dict[str, Any]:
203
- """
204
- Combined endpoint: classify emails and extract job info in ONE call.
205
- Only extracts from emails classified as job-related.
206
- """
207
- if job_labels is None:
208
- job_labels = ["applied", "rejected", "interview", "next-phase", "offer"]
209
-
210
- # Step 1: Classify all emails
211
- classifications = classify_batch_emails(emails)
212
-
213
- # Step 2: Find job-related emails
214
- job_indices = []
215
- job_emails = []
216
- for i, (email, cls) in enumerate(zip(emails, classifications)):
217
- if cls.get("success") and cls.get("label", "").lower() in job_labels and cls.get("score", 0) >= threshold:
218
- job_indices.append(i)
219
- job_emails.append(email)
220
-
221
- # Step 3: Extract only from job-related emails (batched)
222
- extractions = extract_batch(job_emails) if job_emails else []
223
-
224
- # Step 4: Combine results
225
- results = []
226
- extraction_idx = 0
227
- for i, cls in enumerate(classifications):
228
- result = {
229
- "classification": cls,
230
- "extraction": None
231
- }
232
- if i in job_indices:
233
- result["extraction"] = extractions[extraction_idx]
234
- extraction_idx += 1
235
- results.append(result)
236
-
237
- return {
238
- "results": results,
239
- "total": len(emails),
240
- "job_related": len(job_emails)
241
- }
242
-
243
- def gradio_classify(email_text: str) -> str:
244
- """Gradio interface for classification"""
245
- if not email_text.strip():
246
- return "Please enter some email text to classify."
247
-
248
- result = classify_single_email(email_text)
249
-
250
- if result.get("success"):
251
- return f"""
252
- **Classification Result:**
253
- - **Label:** {result['label']}
254
- - **Confidence:** {result['score']:.2%}
255
- """
256
- else:
257
- return f"**Error:** {result.get('error', 'Unknown error')}"
258
-
259
- def gradio_extract(email_text: str) -> str:
260
- """Gradio interface for extraction"""
261
- if not email_text.strip():
262
- return "Please enter some email text to extract from."
263
-
264
- result = extract_job_info(email_text)
265
-
266
- if result.get("success"):
267
- return f"""
268
- **Extraction Result:**
269
- - **Company:** {result['company']}
270
- - **Role:** {result['role']}
271
- """
272
- else:
273
- return f"**Error:** {result.get('error', 'Unknown error')}"
274
-
275
- def api_classify(email_text: str) -> Dict[str, Any]:
276
- """API endpoint for single classification"""
277
- return classify_single_email(email_text)
278
-
279
- def api_extract(email_text: str) -> Dict[str, Any]:
280
- """API endpoint for single extraction"""
281
- return extract_job_info(email_text)
282
-
283
- def api_classify_batch(emails_json: str) -> str:
284
- """API endpoint for batch classification"""
285
- try:
286
- emails = json.loads(emails_json)
287
- if not isinstance(emails, list):
288
- return json.dumps({"error": "Input must be a JSON array of strings"})
289
-
290
- if len(emails) > 400:
291
- return json.dumps({"error": "Maximum 400 emails per batch"})
292
-
293
- results = classify_batch_emails(emails)
294
- return json.dumps({"results": results})
295
- except json.JSONDecodeError:
296
- return json.dumps({"error": "Invalid JSON format"})
297
- except Exception as e:
298
- return json.dumps({"error": str(e)})
299
-
300
- def api_extract_batch(emails_json: str) -> str:
301
- """API endpoint for batch extraction - NOW BATCHED"""
302
- try:
303
- emails = json.loads(emails_json)
304
- if not isinstance(emails, list):
305
- return json.dumps({"error": "Input must be a JSON array of strings"})
306
-
307
- if len(emails) > 400:
308
- return json.dumps({"error": "Maximum 400 emails per batch"})
309
-
310
- results = extract_batch(emails)
311
- return json.dumps({"results": results})
312
- except json.JSONDecodeError:
313
- return json.dumps({"error": "Invalid JSON format"})
314
- except Exception as e:
315
- return json.dumps({"error": str(e)})
316
-
317
- def api_process_batch(emails_json: str, threshold: float = 0.5) -> str:
318
- """API endpoint for combined classify + extract in ONE call"""
319
- try:
320
- emails = json.loads(emails_json)
321
- if not isinstance(emails, list):
322
- return json.dumps({"error": "Input must be a JSON array of strings"})
323
-
324
- if len(emails) > 400:
325
- return json.dumps({"error": "Maximum 400 emails per batch"})
326
-
327
- results = process_batch(emails, threshold=threshold)
328
- return json.dumps(results)
329
- except json.JSONDecodeError:
330
- return json.dumps({"error": "Invalid JSON format"})
331
- except Exception as e:
332
- return json.dumps({"error": str(e)})
333
-
334
- # Load models on startup
335
- logger.info("Loading models...")
336
- models_loaded = load_models()
337
-
338
- if not models_loaded:
339
- logger.warning("Models failed to load - using dummy responses")
340
-
341
- # Create Gradio interface
342
- with gr.Blocks(title="Email Classifier & Extractor", theme=gr.themes.Soft()) as demo:
343
- gr.Markdown("# 📧 Email Classification & Extraction API")
344
- gr.Markdown("Classify job-related emails and extract company/role information.")
345
-
346
- with gr.Tab("Classification"):
347
- with gr.Row():
348
- with gr.Column():
349
- classify_input = gr.Textbox(
350
- label="Email Content",
351
- placeholder="Paste your email content here...",
352
- lines=8,
353
- max_lines=20
354
- )
355
- classify_btn = gr.Button("Classify Email", variant="primary")
356
-
357
- with gr.Column():
358
- classify_output = gr.Markdown(label="Classification Result")
359
-
360
- classify_btn.click(
361
- fn=gradio_classify,
362
- inputs=classify_input,
363
- outputs=classify_output
364
- )
365
-
366
- with gr.Tab("Extraction"):
367
- with gr.Row():
368
- with gr.Column():
369
- extract_input = gr.Textbox(
370
- label="Email Content",
371
- placeholder="Paste job application email here...",
372
- lines=8,
373
- max_lines=20
374
- )
375
- extract_btn = gr.Button("Extract Info", variant="primary")
376
-
377
- with gr.Column():
378
- extract_output = gr.Markdown(label="Extraction Result")
379
-
380
- extract_btn.click(
381
- fn=gradio_extract,
382
- inputs=extract_input,
383
- outputs=extract_output
384
- )
385
-
386
- with gr.Tab("API Testing"):
387
- gr.Markdown("### Test API Endpoints")
388
-
389
- with gr.Row():
390
- with gr.Column():
391
- gr.Markdown("**Single Classification**")
392
- api_classify_input = gr.Textbox(label="Email Text", lines=4)
393
- api_classify_btn = gr.Button("Test Classify API")
394
- api_classify_output = gr.JSON(label="Response")
395
-
396
- api_classify_btn.click(
397
- fn=api_classify,
398
- inputs=api_classify_input,
399
- outputs=api_classify_output,
400
- api_name="classify"
401
- )
402
-
403
- with gr.Column():
404
- gr.Markdown("**Single Extraction**")
405
- api_extract_input = gr.Textbox(label="Email Text", lines=4)
406
- api_extract_btn = gr.Button("Test Extract API")
407
- api_extract_output = gr.JSON(label="Response")
408
-
409
- api_extract_btn.click(
410
- fn=api_extract,
411
- inputs=api_extract_input,
412
- outputs=api_extract_output,
413
- api_name="extract"
414
- )
415
-
416
- with gr.Row():
417
- with gr.Column():
418
- gr.Markdown("**Batch Classification**")
419
- batch_classify_input = gr.Textbox(
420
- label="JSON Array of Emails",
421
- lines=6,
422
- placeholder='["Email 1", "Email 2"]'
423
- )
424
- batch_classify_btn = gr.Button("Test Batch Classify")
425
- batch_classify_output = gr.Code(label="Response", language="json")
426
-
427
- batch_classify_btn.click(
428
- fn=api_classify_batch,
429
- inputs=batch_classify_input,
430
- outputs=batch_classify_output,
431
- api_name="classify_batch"
432
- )
433
-
434
- with gr.Column():
435
- gr.Markdown("**Batch Extraction**")
436
- batch_extract_input = gr.Textbox(
437
- label="JSON Array of Emails",
438
- lines=6,
439
- placeholder='["Email 1", "Email 2"]'
440
- )
441
- batch_extract_btn = gr.Button("Test Batch Extract")
442
- batch_extract_output = gr.Code(label="Response", language="json")
443
-
444
- batch_extract_btn.click(
445
- fn=api_extract_batch,
446
- inputs=batch_extract_input,
447
- outputs=batch_extract_output,
448
- api_name="extract_batch"
449
- )
450
-
451
- with gr.Row():
452
- with gr.Column(scale=2):
453
- gr.Markdown("**🚀 Combined Process (Recommended)**")
454
- gr.Markdown("*Classify + Extract in ONE call - fastest option*")
455
- process_input = gr.Textbox(
456
- label="JSON Array of Emails",
457
- lines=6,
458
- placeholder='["Email 1", "Email 2"]'
459
- )
460
- process_threshold = gr.Slider(
461
- minimum=0.1, maximum=0.9, value=0.5, step=0.1,
462
- label="Classification Threshold"
463
- )
464
- process_btn = gr.Button("Test Process API", variant="primary")
465
- process_output = gr.Code(label="Response", language="json")
466
-
467
- process_btn.click(
468
- fn=api_process_batch,
469
- inputs=[process_input, process_threshold],
470
- outputs=process_output,
471
- api_name="process_batch"
472
- )
473
-
474
- with gr.Tab("Documentation"):
475
- gr.Markdown(f"""
476
- ### Model Status
477
- - **Status:** {'✅ Loaded' if models_loaded else '❌ Failed to load'}
478
- - **Device:** {device if device else 'Not initialized'}
479
- - **Classifier:** SetFit (job categories)
480
- - **Extractor:** T5-small (company/role)
481
-
482
- ### API Endpoints
483
-
484
- #### 1. Single Classification
485
- `/api/classify` - Returns job category label
486
-
487
- #### 2. Single Extraction
488
- `/api/extract` - Returns company and role
489
-
490
- #### 3. Batch Classification
491
- `/api/classify_batch` - Classify multiple emails (max 400)
492
-
493
- #### 4. Batch Extraction
494
- `/api/extract_batch` - Extract from multiple emails (max 400) - **NOW BATCHED!**
495
-
496
- #### 5. 🚀 Combined Process (NEW - FASTEST)
497
- `/api/process_batch` - Classify AND extract in ONE call
498
- - Only extracts from job-related emails
499
- - Reduces API calls from 2 to 1
500
- - Recommended for best performance
501
-
502
- ### Categories
503
- - `applied` - Application submitted
504
- - `rejected` - Application rejected
505
- - `interview` - Interview invitation
506
- - `next-phase` - Next round invitation
507
- - `offer` - Job offer received
508
- - `other` - Not job-related
509
-
510
- ### Usage from Next.js
511
-
512
- **Option 1: Separate calls (existing)**
513
- ```javascript
514
- const classifications = await client.predict("/classify_batch", {{
515
- emails_json: JSON.stringify(emails)
516
- }});
517
-
518
- const extractions = await client.predict("/extract_batch", {{
519
- emails_json: JSON.stringify(jobEmails)
520
- }});
521
-
522
- Option 2: Combined call (recommended - 2x faster)
523
- const results = await client.predict("/process_batch", {{
524
- emails_json: JSON.stringify(emails),
525
- threshold: 0.5
526
- }});
527
- // Returns both classification AND extraction in one call
 
1
+ import gradio as gr
2
+ import torch
3
+ from setfit import SetFitModel
4
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
5
+ import json
6
+ import logging
7
+ import re
8
+ from typing import List, Dict, Any
9
+ import os
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
# Shared module-level model handles; populated by load_models() at startup.
classifier_model = None      # SetFit model for job-email classification
extractor_model = None       # T5 seq2seq model for company/role extraction
extractor_tokenizer = None   # tokenizer paired with extractor_model
device = None                # torch.device selected in load_models() (cuda or cpu)
18
+
19
def load_models():
    """Load the SetFit classifier and T5 extractor into module globals.

    Chooses CUDA when available, otherwise CPU. Returns True on success,
    False when any model fails to download/initialize.
    """
    global classifier_model, extractor_model, extractor_tokenizer, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    try:
        # Classifier: private HF repo, so pass the token when one is set.
        clf_repo = "Tomiwajin/testClasifier"
        hf_token = os.getenv("HF_TOKEN")
        classifier_model = SetFitModel.from_pretrained(
            clf_repo,
            use_auth_token=hf_token if hf_token else False,
        )
        logger.info(f"Classifier loaded: {clf_repo}")

        # Extractor: seq2seq model moved to the chosen device in eval mode.
        ext_repo = "Tomiwajin/email-company-role-extractor"
        extractor_tokenizer = AutoTokenizer.from_pretrained(ext_repo)
        extractor_model = T5ForConditionalGeneration.from_pretrained(ext_repo)
        extractor_model.to(device)
        extractor_model.eval()
        logger.info(f"Extractor loaded: {ext_repo}")
        return True
    except Exception as e:
        logger.error(f"Model loading failed: {e}")
        return False
43
+
44
def parse_extraction_result(prediction):
    """Parse the extractor model's raw output into a structured dict.

    The T5 model emits roughly-JSON text like '{"company": "...", "role": "..."}'
    but frequently drops the opening or closing brace; this repairs those
    truncations before parsing.

    Args:
        prediction: Raw decoded string emitted by the extractor model.

    Returns:
        Dict with "company", "role" and a "success" flag. Both fields fall
        back to "unknown" (with success=False) when the text cannot be parsed.
    """
    try:
        fixed = prediction.strip()
        # Model sometimes omits the opening brace: '"company": ...'
        if fixed.startswith('"') and not fixed.startswith('{'):
            fixed = '{' + fixed
        if not fixed.endswith('}'):
            fixed = fixed + '}'
        # Normalize whitespace between key/value pairs.
        fixed = re.sub(r'",(\s*)"', '", "', fixed)
        result = json.loads(fixed)
        return {
            "company": result.get("company", "unknown"),
            "role": result.get("role", "unknown"),
            "success": True,
        }
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; Exception is the widest net we actually want here.
        return {"company": "unknown", "role": "unknown", "success": False}
60
+
61
def classify_single_email(email_text):
    """Classify one email with the SetFit model; return label + confidence."""
    if not classifier_model:
        return {"error": "Classifier not loaded", "success": False}
    try:
        # Cap input length to keep inference cheap and bounded.
        text = email_text.strip()[:1000]
        labels = classifier_model.predict([text])
        probs = classifier_model.predict_proba([text])[0]
        best = max(probs)
        return {
            "label": str(labels[0]),
            "score": round(float(best), 4),
            "success": True,
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        return {"error": str(e), "success": False}
76
+
77
def extract_job_info(email_text):
    """Run the T5 extractor on one email and return {company, role, success}."""
    if not extractor_model or not extractor_tokenizer:
        return {"error": "Extractor not loaded", "success": False}
    try:
        # Same prompt format the extractor was fine-tuned with.
        trimmed = email_text.strip()[:1000]
        encoded = extractor_tokenizer(
            f"extract company and role: {trimmed}",
            return_tensors='pt',
            max_length=512,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            generated = extractor_model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_length=128,
                num_beams=2,
                early_stopping=True,
                pad_token_id=extractor_tokenizer.pad_token_id,
            )
        decoded = extractor_tokenizer.decode(generated[0], skip_special_tokens=True)
        return parse_extraction_result(decoded)
    except Exception as e:
        logger.error(f"Extraction error: {e}")
        return {"company": "unknown", "role": "unknown", "success": False}
100
+
101
def classify_batch_emails(emails):
    """Classify a list of emails in one batched SetFit call."""
    if not classifier_model:
        return [{"error": "Model not loaded", "success": False}] * len(emails)
    try:
        texts = [email.strip()[:1000] for email in emails]
        labels = classifier_model.predict(texts)
        probabilities = classifier_model.predict_proba(texts)
        results = []
        for label, prob_row in zip(labels, probabilities):
            results.append({
                "label": str(label),
                "score": round(float(max(prob_row)), 4),
                "success": True,
            })
        return results
    except Exception as e:
        logger.error(f"Batch classification error: {e}")
        return [{"error": str(e), "success": False}] * len(emails)
115
+
116
def extract_batch(emails):
    """Extract company/role from many emails with a single generate() call."""
    if not extractor_model or not extractor_tokenizer:
        return [{"error": "Extractor not loaded", "success": False}] * len(emails)
    if len(emails) == 0:
        return []
    try:
        # Build one prompt per email, then tokenize the whole batch at once.
        prompts = [
            f"extract company and role: {email.strip()[:1000]}"
            for email in emails
        ]
        encoded = extractor_tokenizer(
            prompts,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)
        with torch.no_grad():
            generated = extractor_model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_length=128,
                num_beams=2,
                early_stopping=True,
                pad_token_id=extractor_tokenizer.pad_token_id,
            )
        decoded = extractor_tokenizer.batch_decode(generated, skip_special_tokens=True)
        return [parse_extraction_result(text) for text in decoded]
    except Exception as e:
        logger.error(f"Batch extraction error: {e}")
        return [{"company": "unknown", "role": "unknown", "success": False}] * len(emails)
142
+
143
def process_batch(emails, job_labels=None, threshold=0.5):
    """Classify every email, then extract company/role from job-related ones.

    Args:
        emails: List of raw email strings.
        job_labels: Lowercase labels that count as job-related; defaults to
            the five job-pipeline categories.
        threshold: Minimum classifier confidence required before extracting.

    Returns:
        Dict with per-email "results" (each a classification plus an
        extraction or None), and "total" / "job_related" counts.
    """
    if job_labels is None:
        job_labels = ["applied", "rejected", "interview", "next-phase", "offer"]
    label_set = set(job_labels)  # O(1) membership instead of a list scan

    classifications = classify_batch_emails(emails)

    # Pick out the emails worth running the (expensive) extractor on.
    job_indices = []
    job_emails = []
    for i, (email, cls) in enumerate(zip(emails, classifications)):
        if cls.get("success") and cls.get("label", "").lower() in label_set and cls.get("score", 0) >= threshold:
            job_indices.append(i)
            job_emails.append(email)

    extractions = extract_batch(job_emails) if job_emails else []

    # Stitch extractions back onto their source emails, preserving order.
    # A set avoids the accidental O(n^2) `i in list` lookup per email.
    job_index_set = set(job_indices)
    results = []
    ext_idx = 0
    for i, cls in enumerate(classifications):
        result = {"classification": cls, "extraction": None}
        if i in job_index_set:
            result["extraction"] = extractions[ext_idx]
            ext_idx += 1
        results.append(result)
    return {"results": results, "total": len(emails), "job_related": len(job_emails)}
163
+
164
def api_classify_batch(emails_json):
    """API wrapper: JSON-array string in, JSON results string out (max 400)."""
    try:
        parsed = json.loads(emails_json)
        if not isinstance(parsed, list):
            payload = {"error": "Input must be a JSON array"}
        elif len(parsed) > 400:
            payload = {"error": "Maximum 400 emails per batch"}
        else:
            payload = {"results": classify_batch_emails(parsed)}
        return json.dumps(payload)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})
177
+
178
def api_extract_batch(emails_json):
    """API wrapper: JSON-array string in, JSON extractions out (max 400)."""
    try:
        parsed = json.loads(emails_json)
        if not isinstance(parsed, list):
            payload = {"error": "Input must be a JSON array"}
        elif len(parsed) > 400:
            payload = {"error": "Maximum 400 emails per batch"}
        else:
            payload = {"results": extract_batch(parsed)}
        return json.dumps(payload)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})
191
+
192
def api_process_batch(emails_json, threshold=0.5):
    """API wrapper for the combined classify + extract pipeline (max 400)."""
    try:
        parsed = json.loads(emails_json)
        if not isinstance(parsed, list):
            payload = {"error": "Input must be a JSON array"}
        elif len(parsed) > 400:
            payload = {"error": "Maximum 400 emails per batch"}
        else:
            payload = process_batch(parsed, threshold=threshold)
        return json.dumps(payload)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})
205
+
206
# Models load once at import time so the app is ready before the UI serves.
logger.info("Loading models...")
models_loaded = load_models()

with gr.Blocks(title="Email Classifier & Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Email Classification & Extraction API")

    # Each tab doubles as a named API endpoint (api_name=...) that external
    # clients can call programmatically via the Gradio client.
    with gr.Tab("Batch Classification"):
        batch_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        batch_btn = gr.Button("Classify Batch")
        batch_output = gr.Code(label="Response", language="json")
        batch_btn.click(fn=api_classify_batch, inputs=batch_input, outputs=batch_output, api_name="classify_batch")

    with gr.Tab("Batch Extraction"):
        extract_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        extract_btn = gr.Button("Extract Batch")
        extract_output = gr.Code(label="Response", language="json")
        extract_btn.click(fn=api_extract_batch, inputs=extract_input, outputs=extract_output, api_name="extract_batch")

    # Combined endpoint: classify AND extract in a single round trip.
    with gr.Tab("Combined Process"):
        process_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        process_threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.5, step=0.1, label="Threshold")
        process_btn = gr.Button("Process Batch", variant="primary")
        process_output = gr.Code(label="Response", language="json")
        process_btn.click(fn=api_process_batch, inputs=[process_input, process_threshold], outputs=process_output, api_name="process_batch")

    with gr.Tab("Status"):
        # Surfaces whether load_models() succeeded at startup.
        status_text = "Loaded" if models_loaded else "Failed"
        gr.Markdown(f"**Model Status:** {status_text}")

if __name__ == "__main__":
    # NOTE(review): 0.0.0.0:7860 matches Hugging Face Spaces' default
    # serving setup — confirm if deploying elsewhere.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)