"""Gradio app exposing email classification and job-info extraction APIs.

Loads a SetFit classifier and a T5 company/role extractor once at import
time, then serves three JSON batch endpoints (classify / extract /
combined) through a Gradio Blocks UI.
"""

import json
import logging
import os
import re
from typing import Any, Dict, List

import gradio as gr
import torch
from setfit import SetFitModel
from transformers import AutoTokenizer, T5ForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum characters of an email body fed to either model.
MAX_EMAIL_CHARS = 1000
# Hard cap on the number of emails accepted by the JSON batch APIs.
MAX_BATCH_SIZE = 400
# Classifier labels treated as "job related" by the combined endpoint.
JOB_LABELS = ["applied", "rejected", "interview", "next-phase", "offer"]

# Models are loaded once by load_models() and shared by every handler.
classifier_model = None
extractor_model = None
extractor_tokenizer = None
device = None


def load_models() -> bool:
    """Load the SetFit classifier and T5 extractor into module globals.

    Returns:
        True on success, False on any failure (details are logged).
    """
    global classifier_model, extractor_model, extractor_tokenizer, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)
    try:
        classifier_name = "Tomiwajin/testClasifier"
        token = os.getenv("HF_TOKEN")
        # `use_auth_token` is deprecated in huggingface_hub; `token=None`
        # falls back to any locally cached credentials.
        classifier_model = SetFitModel.from_pretrained(classifier_name, token=token)
        logger.info("Classifier loaded: %s", classifier_name)

        extractor_name = "Tomiwajin/email-company-role-extractor"
        extractor_tokenizer = AutoTokenizer.from_pretrained(extractor_name)
        extractor_model = T5ForConditionalGeneration.from_pretrained(extractor_name)
        extractor_model.to(device)
        extractor_model.eval()
        logger.info("Extractor loaded: %s", extractor_name)
        return True
    except Exception as e:
        # Boundary handler: the app still starts and reports "Failed" status.
        logger.error("Model loading failed: %s", e)
        return False


def parse_extraction_result(prediction: str) -> Dict[str, Any]:
    """Parse the T5 extractor's output into ``{company, role, success}``.

    The model emits JSON-ish text that is sometimes missing its braces;
    the common damage is repaired before parsing. Any parse failure
    yields the "unknown" fallback rather than raising.
    """
    try:
        fixed = prediction.strip()
        # The model frequently drops the opening/closing brace — restore them.
        if fixed.startswith('"') and not fixed.startswith('{'):
            fixed = '{' + fixed
        if not fixed.endswith('}'):
            fixed = fixed + '}'
        # Normalize spacing after commas between quoted fields.
        fixed = re.sub(r'",(\s*)"', '", "', fixed)
        result = json.loads(fixed)
        return {
            "company": result.get("company", "unknown"),
            "role": result.get("role", "unknown"),
            "success": True,
        }
    except (json.JSONDecodeError, AttributeError, TypeError):
        # JSONDecodeError: unparseable text; AttributeError: parsed value is
        # not a dict (no .get); TypeError: non-string prediction.
        return {"company": "unknown", "role": "unknown", "success": False}


def classify_single_email(email_text: str) -> Dict[str, Any]:
    """Classify one email; returns ``{label, score, success}`` or an error dict."""
    if not classifier_model:
        return {"error": "Classifier not loaded", "success": False}
    try:
        email_text = email_text.strip()[:MAX_EMAIL_CHARS]
        predictions = classifier_model.predict([email_text])
        probabilities = classifier_model.predict_proba([email_text])[0]
        return {
            "label": str(predictions[0]),
            "score": round(float(max(probabilities)), 4),
            "success": True,
        }
    except Exception as e:
        logger.error("Classification error: %s", e)
        return {"error": str(e), "success": False}


def extract_job_info(email_text: str) -> Dict[str, Any]:
    """Extract company and role from one email via the T5 model."""
    if not extractor_model or not extractor_tokenizer:
        return {"error": "Extractor not loaded", "success": False}
    try:
        email_text = email_text.strip()[:MAX_EMAIL_CHARS]
        input_text = f"extract company and role: {email_text}"
        inputs = extractor_tokenizer(
            input_text, return_tensors='pt', max_length=512, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = extractor_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=128,
                num_beams=2,
                early_stopping=True,
                pad_token_id=extractor_tokenizer.pad_token_id,
            )
        prediction = extractor_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return parse_extraction_result(prediction)
    except Exception as e:
        logger.error("Extraction error: %s", e)
        return {"company": "unknown", "role": "unknown", "success": False}


def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
    """Classify a list of emails; one result dict per input, in order."""
    if not classifier_model:
        return [{"error": "Model not loaded", "success": False}] * len(emails)
    # Mirror extract_batch: nothing to do for an empty batch.
    if not emails:
        return []
    try:
        cleaned = [e.strip()[:MAX_EMAIL_CHARS] for e in emails]
        predictions = classifier_model.predict(cleaned)
        probabilities = classifier_model.predict_proba(cleaned)
        return [
            {"label": str(p), "score": round(float(max(pr)), 4), "success": True}
            for p, pr in zip(predictions, probabilities)
        ]
    except Exception as e:
        logger.error("Batch classification error: %s", e)
        return [{"error": str(e), "success": False}] * len(emails)


def extract_batch(emails: List[str]) -> List[Dict[str, Any]]:
    """Extract company/role for a list of emails in a single padded batch."""
    if not extractor_model or not extractor_tokenizer:
        return [{"error": "Extractor not loaded", "success": False}] * len(emails)
    if len(emails) == 0:
        return []
    try:
        cleaned = [e.strip()[:MAX_EMAIL_CHARS] for e in emails]
        input_texts = [f"extract company and role: {e}" for e in cleaned]
        inputs = extractor_tokenizer(
            input_texts,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding=True,
        ).to(device)
        with torch.no_grad():
            outputs = extractor_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=128,
                num_beams=2,
                early_stopping=True,
                pad_token_id=extractor_tokenizer.pad_token_id,
            )
        predictions = extractor_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [parse_extraction_result(p) for p in predictions]
    except Exception as e:
        logger.error("Batch extraction error: %s", e)
        return [{"company": "unknown", "role": "unknown", "success": False}] * len(emails)


def process_batch(emails: List[str], job_labels: List[str] = None,
                  threshold: float = 0.5) -> Dict[str, Any]:
    """Classify every email, then run extraction only on job-related ones.

    Args:
        emails: raw email bodies.
        job_labels: lowercase labels counted as job-related
            (defaults to JOB_LABELS).
        threshold: minimum classifier score for an email to be extracted.

    Returns:
        ``{"results": [...], "total": N, "job_related": M}`` where each
        result pairs a classification with an extraction (or None).
    """
    if job_labels is None:
        job_labels = JOB_LABELS
    classifications = classify_batch_emails(emails)

    # Select indices that pass both the label filter and the score threshold.
    job_indices = set()
    job_emails = []
    for i, (email, cls) in enumerate(zip(emails, classifications)):
        if (cls.get("success")
                and cls.get("label", "").lower() in job_labels
                and cls.get("score", 0) >= threshold):
            job_indices.add(i)
            job_emails.append(email)

    extractions = extract_batch(job_emails) if job_emails else []
    # Extractions come back in job_emails order; consume them sequentially.
    extraction_iter = iter(extractions)
    results = []
    for i, cls in enumerate(classifications):
        result = {"classification": cls, "extraction": None}
        if i in job_indices:
            result["extraction"] = next(extraction_iter)
        results.append(result)
    return {"results": results, "total": len(emails), "job_related": len(job_emails)}


def _parse_email_batch(emails_json: str):
    """Validate a JSON payload shared by the three batch API endpoints.

    Returns:
        ``(emails, None)`` on success, ``(None, error_message)`` on a
        validation failure. Propagates json.JSONDecodeError to the caller.
    """
    emails = json.loads(emails_json)
    if not isinstance(emails, list):
        return None, "Input must be a JSON array"
    if len(emails) > MAX_BATCH_SIZE:
        return None, f"Maximum {MAX_BATCH_SIZE} emails per batch"
    if not all(isinstance(e, str) for e in emails):
        return None, "All emails must be strings"
    return emails, None


def api_classify_batch(emails_json: str) -> str:
    """JSON API: classify a batch of emails; returns a JSON string."""
    try:
        emails, error = _parse_email_batch(emails_json)
        if error:
            return json.dumps({"error": error})
        return json.dumps({"results": classify_batch_emails(emails)})
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})


def api_extract_batch(emails_json: str) -> str:
    """JSON API: extract company/role for a batch of emails; returns JSON."""
    try:
        emails, error = _parse_email_batch(emails_json)
        if error:
            return json.dumps({"error": error})
        return json.dumps({"results": extract_batch(emails)})
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})


def api_process_batch(emails_json: str, threshold: float = 0.5) -> str:
    """JSON API: combined classify-then-extract pipeline; returns JSON."""
    try:
        emails, error = _parse_email_batch(emails_json)
        if error:
            return json.dumps({"error": error})
        return json.dumps(process_batch(emails, threshold=threshold))
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON format"})
    except Exception as e:
        return json.dumps({"error": str(e)})


logger.info("Loading models...")
models_loaded = load_models()

with gr.Blocks(title="Email Classifier & Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Email Classification & Extraction API")

    with gr.Tab("Batch Classification"):
        batch_input = gr.Textbox(label="JSON Array of Emails", lines=6,
                                 placeholder='["email1", "email2"]')
        batch_btn = gr.Button("Classify Batch")
        batch_output = gr.Code(label="Response", language="json")
        batch_btn.click(fn=api_classify_batch, inputs=batch_input,
                        outputs=batch_output, api_name="classify_batch")

    with gr.Tab("Batch Extraction"):
        extract_input = gr.Textbox(label="JSON Array of Emails", lines=6,
                                   placeholder='["email1", "email2"]')
        extract_btn = gr.Button("Extract Batch")
        extract_output = gr.Code(label="Response", language="json")
        extract_btn.click(fn=api_extract_batch, inputs=extract_input,
                          outputs=extract_output, api_name="extract_batch")

    with gr.Tab("Combined Process"):
        process_input = gr.Textbox(label="JSON Array of Emails", lines=6,
                                   placeholder='["email1", "email2"]')
        process_threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.5,
                                      step=0.1, label="Threshold")
        process_btn = gr.Button("Process Batch", variant="primary")
        process_output = gr.Code(label="Response", language="json")
        process_btn.click(fn=api_process_batch,
                          inputs=[process_input, process_threshold],
                          outputs=process_output, api_name="process_batch")

    with gr.Tab("Status"):
        status_text = "Loaded" if models_loaded else "Failed"
        gr.Markdown(f"**Model Status:** {status_text}")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)