Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from setfit import SetFitModel | |
| from transformers import AutoTokenizer, T5ForConditionalGeneration | |
| import json | |
| import logging | |
| import re | |
| from typing import List, Dict, Any | |
| import os | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared model handles. They all start as None and are assigned by
# load_models(); the request handlers check them before running inference.
classifier_model = extractor_model = extractor_tokenizer = device = None
def load_models():
    """Load the SetFit classifier and the T5 extractor into the module globals.

    Sets ``device`` (CUDA when available, otherwise CPU), ``classifier_model``,
    ``extractor_tokenizer`` and ``extractor_model``. The extractor model is
    moved to ``device`` and switched to eval mode. When the ``HF_TOKEN``
    environment variable is set, it is passed to every Hub download, so both
    the classifier and the extractor can live in private repos.

    Returns:
        True if every model loaded, False otherwise. On failure the error and
        its traceback are logged. Globals assigned before the failing step
        keep their new values; there is no rollback.
    """
    global classifier_model, extractor_model, extractor_tokenizer, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # An empty/unset HF_TOKEN becomes None, i.e. "no explicit token".
    token = os.getenv("HF_TOKEN") or None
    try:
        classifier_name = "Tomiwajin/testClasifier"
        # `token=` is the current Hub auth argument; `use_auth_token=` is deprecated.
        classifier_model = SetFitModel.from_pretrained(classifier_name, token=token)
        logger.info(f"Classifier loaded: {classifier_name}")

        extractor_name = "Tomiwajin/email-company-role-extractor"
        # Pass the same token so a private extractor repo loads as well.
        extractor_tokenizer = AutoTokenizer.from_pretrained(extractor_name, token=token)
        extractor_model = T5ForConditionalGeneration.from_pretrained(
            extractor_name, token=token
        )
        extractor_model.to(device)
        extractor_model.eval()
        logger.info(f"Extractor loaded: {extractor_name}")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which the Space logs need.
        logger.exception(f"Model loading failed: {e}")
        return False
def parse_extraction_result(prediction):
    """Turn the extractor's raw text output into a company/role dict.

    The raw text is repaired before parsing:

    * a leading ``{`` is added when the text starts with a quote;
    * a trailing ``}`` is added when it is missing.

    NOTE(review): these repairs suggest the extractor emits bare
    ``"company": ..., "role": ...`` pairs without braces. That may be
    because T5's vocabulary has no curly braces; confirm.

    Args:
        prediction: decoded model output; anything other than ``str`` is
            treated as unparseable.

    Returns:
        ``{"company": ..., "role": ..., "success": True}`` when the repaired
        text is a JSON object. A missing, null or empty field becomes
        ``"unknown"``. Otherwise returns
        ``{"company": "unknown", "role": "unknown", "success": False}``.
        Never raises for malformed model output.
    """
    fallback = {"company": "unknown", "role": "unknown", "success": False}
    if not isinstance(prediction, str):
        return fallback

    fixed = prediction.strip()
    if fixed.startswith('"'):
        fixed = "{" + fixed
    if not fixed.endswith("}"):
        fixed += "}"

    try:
        result = json.loads(fixed)
    except json.JSONDecodeError:
        return fallback
    # json.loads may legitimately return a non-object (e.g. a number).
    if not isinstance(result, dict):
        return fallback

    return {
        "company": result.get("company") or "unknown",
        "role": result.get("role") or "unknown",
        "success": True,
    }
def classify_single_email(email_text):
    """Classify one email with the SetFit classifier.

    The text is stripped and cut to its first 1000 characters before
    inference.

    Returns:
        ``{"label": str, "score": float, "success": True}`` on success, where
        ``score`` is the largest class probability rounded to 4 decimals.
        ``{"error": str, "success": False}`` when the classifier is not
        loaded or inference raises (the error is logged).
    """
    if not classifier_model:
        return {"error": "Classifier not loaded", "success": False}

    try:
        batch = [email_text.strip()[:1000]]
        predicted = classifier_model.predict(batch)
        probs = classifier_model.predict_proba(batch)[0]
        label = str(predicted[0])
        score = round(float(max(probs)), 4)
        return {"label": label, "score": score, "success": True}
    except Exception as exc:
        logger.error(f"Classification error: {exc}")
        return {"error": str(exc), "success": False}
def extract_job_info(email_text):
    """Extract the company and role mentioned in a single email.

    The email is stripped, cut to 1000 characters and prefixed with the
    ``"extract company and role: "`` task prompt. It is then tokenized
    (max 512 tokens) and decoded with 2-beam generation (max 128 tokens).

    Returns:
        The dict produced by ``parse_extraction_result`` on success.
        ``{"error": "Extractor not loaded", "success": False}`` when the
        model or tokenizer is missing.
        ``{"company": "unknown", "role": "unknown", "success": False}`` when
        inference raises (the error is logged).
    """
    if not (extractor_model and extractor_tokenizer):
        return {"error": "Extractor not loaded", "success": False}

    try:
        snippet = email_text.strip()[:1000]
        prompt = f"extract company and role: {snippet}"
        encoded = extractor_tokenizer(
            prompt,
            return_tensors='pt',
            max_length=512,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            generated = extractor_model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_length=128,
                num_beams=2,
                early_stopping=True,
                pad_token_id=extractor_tokenizer.pad_token_id,
            )
        decoded = extractor_tokenizer.decode(generated[0], skip_special_tokens=True)
        return parse_extraction_result(decoded)
    except Exception as err:
        logger.error(f"Extraction error: {err}")
        return {"company": "unknown", "role": "unknown", "success": False}
def classify_batch_emails(emails):
    """Classify a list of emails with the SetFit classifier in one call.

    Each email is stripped and cut to 1000 characters before inference.

    Args:
        emails: list of email bodies.

    Returns:
        One independent dict per input email, in input order:

        * ``{"label": str, "score": float, "success": True}`` on success,
          where ``score`` is the largest class probability rounded to
          4 decimals;
        * ``{"error": str, "success": False}`` for every email when the
          classifier is not loaded or inference raises.

        An empty list returns ``[]`` without invoking the model.
    """
    # Build a fresh dict per email; `[d] * n` would alias one shared dict.
    if not classifier_model:
        return [{"error": "Model not loaded", "success": False} for _ in emails]
    if not emails:
        return []
    try:
        cleaned = [e.strip()[:1000] for e in emails]
        predictions = classifier_model.predict(cleaned)
        probabilities = classifier_model.predict_proba(cleaned)
        return [
            {"label": str(p), "score": round(float(max(pr)), 4), "success": True}
            for p, pr in zip(predictions, probabilities)
        ]
    except Exception as e:
        logger.exception(f"Batch classification error: {e}")
        return [{"error": str(e), "success": False} for _ in emails]
def _extract_chunk(chunk):
    """Run one tokenize -> generate -> decode pass over ``chunk`` and parse each output."""
    prompts = [f"extract company and role: {email.strip()[:1000]}" for email in chunk]
    inputs = extractor_tokenizer(
        prompts, return_tensors='pt', max_length=512,
        truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = extractor_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=128,
            num_beams=2,
            early_stopping=True,
            pad_token_id=extractor_tokenizer.pad_token_id,
        )
    predictions = extractor_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [parse_extraction_result(p) for p in predictions]


def extract_batch(emails, batch_size=16):
    """Extract company/role from a list of emails with the T5 extractor.

    Emails are processed in chunks of at most ``batch_size``. Each email is
    stripped, cut to 1000 characters and prefixed with the
    ``"extract company and role: "`` task prompt. Chunking bounds peak memory,
    which a single padded beam-search ``generate`` over hundreds of emails
    would otherwise blow up.

    Args:
        emails: list of email bodies.
        batch_size: maximum emails per ``generate`` call; values below 1 are
            treated as 1.

    Returns:
        One independent dict per input email, in input order:

        * the output of ``parse_extraction_result`` for emails whose chunk
          ran;
        * ``{"company": "unknown", "role": "unknown", "success": False}`` for
          emails whose chunk raised (the error is logged; other chunks are
          unaffected);
        * ``{"error": "Extractor not loaded", "success": False}`` for every
          email if the model or tokenizer is missing.
    """
    # Fresh dict per email; `[d] * n` would alias one shared dict.
    if not extractor_model or not extractor_tokenizer:
        return [{"error": "Extractor not loaded", "success": False} for _ in emails]
    if len(emails) == 0:
        return []

    step = max(1, int(batch_size))
    results = []
    for start in range(0, len(emails), step):
        chunk = emails[start:start + step]
        try:
            results.extend(_extract_chunk(chunk))
        except Exception as e:
            logger.exception(f"Batch extraction error: {e}")
            results.extend(
                {"company": "unknown", "role": "unknown", "success": False}
                for _ in chunk
            )
    return results
def process_batch(emails, job_labels=None, threshold=0.5):
    """Classify every email, then extract company/role for the job-related ones.

    An email counts as job-related when all three conditions hold:

    * its classification succeeded;
    * its label, compared case-insensitively, is in ``job_labels``;
    * its score is at least ``threshold``.

    Only job-related emails are sent to ``extract_batch``.

    Args:
        emails: list of email bodies.
        job_labels: labels that count as job-related. Defaults to
            ``["applied", "rejected", "interview", "next-phase", "offer"]``.
        threshold: minimum classifier score required for extraction.

    Returns:
        ``{"results": [...], "total": int, "job_related": int}``. Each entry
        of ``results`` is ``{"classification": dict, "extraction": dict | None}``
        in input order; ``extraction`` is None for emails that were not
        job-related.
    """
    if job_labels is None:
        job_labels = ["applied", "rejected", "interview", "next-phase", "offer"]
    # Lower-cased set: O(1) membership and case-insensitive matching.
    wanted_labels = {label.lower() for label in job_labels}

    classifications = classify_batch_emails(emails)
    is_job = [
        bool(
            cls.get("success")
            and cls.get("label", "").lower() in wanted_labels
            and cls.get("score", 0) >= threshold
        )
        for cls in classifications
    ]
    job_emails = [email for email, flag in zip(emails, is_job) if flag]
    extractions = iter(extract_batch(job_emails) if job_emails else [])

    # Extractions come back in job_emails order, so hand them out sequentially
    # to the flagged rows instead of searching an index list per row.
    results = [
        {
            "classification": cls,
            "extraction": next(extractions, None) if flag else None,
        }
        for cls, flag in zip(classifications, is_job)
    ]
    return {"results": results, "total": len(emails), "job_related": len(job_emails)}
| def api_classify_batch(emails_json): | |
| try: | |
| emails = json.loads(emails_json) | |
| if not isinstance(emails, list): | |
| return json.dumps({"error": "Input must be a JSON array"}) | |
| if len(emails) > 400: | |
| return json.dumps({"error": "Maximum 400 emails per batch"}) | |
| results = classify_batch_emails(emails) | |
| return json.dumps({"results": results}) | |
| except json.JSONDecodeError: | |
| return json.dumps({"error": "Invalid JSON format"}) | |
| except Exception as e: | |
| return json.dumps({"error": str(e)}) | |
| def api_extract_batch(emails_json): | |
| try: | |
| emails = json.loads(emails_json) | |
| if not isinstance(emails, list): | |
| return json.dumps({"error": "Input must be a JSON array"}) | |
| if len(emails) > 400: | |
| return json.dumps({"error": "Maximum 400 emails per batch"}) | |
| results = extract_batch(emails) | |
| return json.dumps({"results": results}) | |
| except json.JSONDecodeError: | |
| return json.dumps({"error": "Invalid JSON format"}) | |
| except Exception as e: | |
| return json.dumps({"error": str(e)}) | |
| def api_process_batch(emails_json, threshold=0.5): | |
| try: | |
| emails = json.loads(emails_json) | |
| if not isinstance(emails, list): | |
| return json.dumps({"error": "Input must be a JSON array"}) | |
| if len(emails) > 400: | |
| return json.dumps({"error": "Maximum 400 emails per batch"}) | |
| results = process_batch(emails, threshold=threshold) | |
| return json.dumps(results) | |
| except json.JSONDecodeError: | |
| return json.dumps({"error": "Invalid JSON format"}) | |
| except Exception as e: | |
| return json.dumps({"error": str(e)}) | |
# Models are loaded once, at import time, before the UI is built; every
# handler below reads the globals that load_models() fills in.
logger.info("Loading models...")
models_loaded = load_models()

# Gradio UI. Each tab wires a JSON textbox -> button -> JSON code view to one
# api_* handler; the handler receives the raw textbox string and parses it.
# `api_name` sets the endpoint name under which Gradio exposes each handler.
# NOTE(review): `theme=` on gr.Blocks and `show_api=` on launch() are
# version-sensitive Gradio parameters; confirm them against the pinned Gradio
# version (the Space page shows a runtime error).
with gr.Blocks(title="Email Classifier & Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Email Classification & Extraction API")
    with gr.Tab("Batch Classification"):
        batch_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        batch_btn = gr.Button("Classify Batch")
        batch_output = gr.Code(label="Response", language="json")
        batch_btn.click(fn=api_classify_batch, inputs=batch_input, outputs=batch_output, api_name="classify_batch")
    with gr.Tab("Batch Extraction"):
        extract_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        extract_btn = gr.Button("Extract Batch")
        extract_output = gr.Code(label="Response", language="json")
        extract_btn.click(fn=api_extract_batch, inputs=extract_input, outputs=extract_output, api_name="extract_batch")
    with gr.Tab("Combined Process"):
        process_input = gr.Textbox(label="JSON Array of Emails", lines=6, placeholder='["email1", "email2"]')
        # The slider value is passed as the `threshold` argument of api_process_batch.
        process_threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.5, step=0.1, label="Threshold")
        process_btn = gr.Button("Process Batch", variant="primary")
        process_output = gr.Code(label="Response", language="json")
        process_btn.click(fn=api_process_batch, inputs=[process_input, process_threshold], outputs=process_output, api_name="process_batch")
    with gr.Tab("Status"):
        # Computed once when the UI is built; it is not refreshed afterwards.
        status_text = "Loaded" if models_loaded else "Failed"
        gr.Markdown(f"**Model Status:** {status_text}")

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)