# JAAT_DEMO / app.py
# pnorlander
# Fix float16 dtype mismatch in TitleMatch + all matchers, remove .DS_Store, parallelize line-by-line
# 2a9e653
import nltk
# NLTK resources fetched at startup: punkt / punkt_tab back sent_tokenize
# (used for line-by-line splitting below); the averaged perceptron tagger is
# presumably required by JAAT internally.
for _nltk_pkg in ("punkt_tab", "punkt", "averaged_perceptron_tagger"):
    nltk.download(_nltk_pkg, quiet=True)
# ── Compatibility shim for JAAT 0.7.x + transformers ≥ 4.45 ─────────────────
# JAAT's WageExtract.__init__ passes `max_length=128` to
#   AutoModelForTokenClassification.from_pretrained(...).
# Newer transformers routes that kwarg into the model subclass __init__,
# which rejects it ("unexpected keyword argument 'max_length'").
# `max_length` is a tokenizer-side option, so dropping it for the model load
# is safe.
from transformers import AutoModelForTokenClassification as _AMTC

_orig_amtc_fp = _AMTC.from_pretrained


def _patched_amtc_fp(*args, **kwargs):
    """Delegate to the saved ``from_pretrained`` with ``max_length`` removed."""
    forwarded = {key: val for key, val in kwargs.items() if key != "max_length"}
    return _orig_amtc_fp(*args, **forwarded)


_AMTC.from_pretrained = _patched_amtc_fp
# ── TitleMatch compat: DebertaV2Tokenizer.batch_encode_plus ─────────────────
# JAAT.TitleMatch.get_title calls feature_tokenizer.batch_encode_plus(...),
# but newer transformers versions raise "DebertaV2Tokenizer has no attribute
# batch_encode_plus" for the slow DebertaV2 tokenizer. Calling the tokenizer
# directly (`__call__`) accepts the same kwargs and returns the same
# BatchEncoding, so install an alias only when the method is missing.
# Best-effort: any failure is logged and startup continues.
try:
    from transformers.models.deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
    if not hasattr(DebertaV2Tokenizer, "batch_encode_plus"):
        def _dv2_batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
            """Alias for calling the tokenizer directly."""
            return self(batch_text_or_text_pairs, **kwargs)

        setattr(DebertaV2Tokenizer, "batch_encode_plus", _dv2_batch_encode_plus)
except Exception as exc:
    print(f"[titlematch-shim] could not patch DebertaV2Tokenizer: {exc}")
import gradio as gr
from sentence_transformers import util as _st_util, SentenceTransformer as _ST
from nltk.tokenize import sent_tokenize as _sent_tokenize
import time
# ── Embedding cache shim ────────────────────────────────────────────────────
# JAAT.TaskMatch / SkillMatch / AIMatch each call embedding_model.encode() on
# a large fixed corpus inside __init__. On CPU that's ~30 min of encoding at
# every cold start — enough to fail the Space's 30-min launch health check.
# We precompute those tensors once and host them at pnorlander/jaat-embeddings;
# here we download them plus the corresponding corpus lists, and monkey-patch
# SentenceTransformer.encode so that when JAAT feeds it a known corpus we
# return the cached tensor instead of re-encoding. Query-time encode() calls
# (small inputs) fall through to the real encoder.
#
# Lookup is value-based (not hash-based), so the cache tolerates any corpus
# ordering β€” e.g. SkillMatch builds its list via `list(set(...))` whose
# iteration order varies between Python versions.
import json
import os
# Hugging Face Hub repo holding the precomputed corpus embeddings; override
# with the JAAT_EMBEDDINGS_REPO environment variable.
EMBEDDINGS_REPO_ID = os.getenv("JAAT_EMBEDDINGS_REPO", "pnorlander/jaat-embeddings")
# Filled by _load_embedding_cache(). Each entry is a dict:
#   {"name": <manifest key>, "pos": {item_text: row_idx}, "tensor": <Tensor>}
_emb_caches = list()
def _load_embedding_cache():
    """Populate ``_emb_caches`` from the precomputed-embeddings repo.

    Downloads ``manifest.json`` from ``EMBEDDINGS_REPO_ID``. For every manifest
    entry it downloads the tensor file (``entry["file"]``) and the matching
    corpus list (``entry["corpus_file"]``). When their lengths agree, it
    appends ``{"name", "pos", "tensor"}`` to ``_emb_caches``, where ``pos``
    maps item text to its row index.

    Best-effort by design. A missing, unreadable or malformed manifest leaves
    the cache empty. A bad entry is skipped. Either way the patched ``encode``
    simply falls through to live encoding. Nothing here raises.
    """
    import torch
    from huggingface_hub import hf_hub_download

    try:
        manifest_path = hf_hub_download(repo_id=EMBEDDINGS_REPO_ID, filename="manifest.json")
    except Exception as e:
        print(f"[emb-cache] manifest download failed ({e}); falling back to live encoding.")
        return
    try:
        # JSON is UTF-8 by spec; don't depend on the container's locale.
        with open(manifest_path, encoding="utf-8") as f:
            manifest = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"[emb-cache] manifest unreadable ({e}); falling back to live encoding.")
        return
    if not isinstance(manifest, dict):
        print("[emb-cache] manifest is not a JSON object; falling back to live encoding.")
        return

    for name, entry in manifest.items():
        corpus_file = entry.get("corpus_file") if isinstance(entry, dict) else None
        if not corpus_file:
            print(f"[emb-cache] {name}: no corpus_file in manifest; skipping")
            continue
        try:
            pt_path = hf_hub_download(repo_id=EMBEDDINGS_REPO_ID, filename=entry["file"])
            corpus_path = hf_hub_download(repo_id=EMBEDDINGS_REPO_ID, filename=corpus_file)
            tensor = torch.load(pt_path, map_location="cpu", weights_only=True)
            with open(corpus_path, encoding="utf-8") as f:
                corpus = json.load(f)
            # Row i of the tensor must be the embedding of corpus[i].
            if len(corpus) != tensor.shape[0]:
                print(f"[emb-cache] {name}: corpus/tensor size mismatch "
                      f"({len(corpus)} vs {tensor.shape[0]}); skipping")
                continue
            pos = {item: i for i, item in enumerate(corpus)}
            _emb_caches.append({"name": name, "pos": pos, "tensor": tensor})
            print(f"[emb-cache] loaded {name}: {tuple(tensor.shape)} ({len(pos)} items)")
        except Exception as e:
            print(f"[emb-cache] {name} failed: {e}")


_load_embedding_cache()
import functools

_orig_encode = _ST.encode


def _cached_rows(pos, sentences):
    """Return each sentence's row index in a cache's ``pos`` map, or None on any miss.

    Unhashable items (``TypeError`` on lookup) count as a miss. They then fall
    through to the real encoder instead of crashing the cache lookup.
    """
    rows = []
    for s in sentences:
        try:
            rows.append(pos[s])
        except (KeyError, TypeError):
            return None
    return rows


@functools.wraps(_orig_encode)
def _patched_encode(self, sentences, *args, **kwargs):
    """``SentenceTransformer.encode`` with a lookup in the precomputed caches.

    Only lists of at least 100 items are checked; small query-time calls go
    straight to the real encoder. If every item is found in one cache, that
    cache's rows are returned in the caller's order. The result is a tensor
    when ``convert_to_tensor=True`` is passed as a keyword, otherwise
    ``.numpy()`` of it. Any miss falls through to the original ``encode``.
    """
    import torch

    if isinstance(sentences, list) and len(sentences) >= 100:
        for entry in _emb_caches:
            rows = _cached_rows(entry["pos"], sentences)
            if rows is None:
                continue
            reordered = entry["tensor"][torch.tensor(rows)]
            print(f"[emb-cache] HIT on {entry['name']} ({len(sentences)} items) — skipping encode")
            if kwargs.get("convert_to_tensor", False):
                return reordered
            return reordered.numpy()
    return _orig_encode(self, sentences, *args, **kwargs)


_ST.encode = _patched_encode
from JAAT import JAAT
# Newer sentence-transformers ships gte-small/gte-large weights as float16,
# so SentenceTransformer.encode() returns float16 tensors. Corpus embeddings
# (from pickle or from the precomputed cache) are float32, and torch.mm
# refuses to mix dtypes. TitleMatch is therefore patched to cast its embedding
# model and its title embeddings to float32 right after its own __init__.
import functools

_orig_titlematch_init = JAAT.TitleMatch.__init__


@functools.wraps(_orig_titlematch_init)
def _patched_titlematch_init(self, *args, **kwargs):
    """Run the original ``TitleMatch.__init__``, then force float32."""
    _orig_titlematch_init(self, *args, **kwargs)
    import torch

    self.title_embed = self.title_embed.to(torch.float32)
    self.embedding_model = self.embedding_model.float()
    print("[titlematch-shim] title_embed & embedding_model → float32")


JAAT.TitleMatch.__init__ = _patched_titlematch_init
# ── Initialize JAAT modules once at startup ──────────────────────────────────
print("Loading JAAT modules (this may take a moment)...")
task_matcher = JAAT.TaskMatch()
title_matcher = JAAT.TitleMatch()
firm_extractor = JAAT.FirmExtract()
wage_extractor = JAAT.WageExtract()
skill_matcher = JAAT.SkillMatch()
ai_matcher = JAAT.AIMatch()

# Newer sentence-transformers may load these embedding models in float16
# while the corpus embeddings are float32. torch.mm rejects mixed dtypes, so
# cast each matcher's model and its corpus tensor to float32.
import torch as _torch

for _m, _corpus_attr in (
    (task_matcher, "task_embed"),
    (skill_matcher, "skill_embed"),
    (ai_matcher, "ai_embed"),
):
    _m.embedding_model = _m.embedding_model.float()
    setattr(_m, _corpus_attr, getattr(_m, _corpus_attr).to(_torch.float32))
print("[dtype-shim] TaskMatch/SkillMatch/AIMatch embedding models & corpus → float32")
# Human-readable label for each JAAT.JobTag class name. Insertion order is the
# order the attributes appear in the JobTag table.
TAG_LABELS = {
    "CitizenshipReq": "Citizenship Required",
    "GovContract": "Government Contract",
    "VisaExclude": "Visa Excluded",
    "VisaInclude": "Visa Sponsorship",
    "WorkAuthReq": "Work Auth Required",
    "driverslicense": "Driver's License",
    "ind_contractor": "Independent Contractor",
    "proflicenses": "Professional License",
    "wfh": "Work From Home",
    "yesunion": "Union Position",
}
# JobTag class names, derived from the label map so the two cannot drift.
JOBTAG_CLASSES = list(TAG_LABELS)
# One JAAT.JobTag classifier per class name, keyed by that name.
job_taggers = {}
for _cls in JOBTAG_CLASSES:
    job_taggers[_cls] = JAAT.JobTag(class_name=_cls)
print("All modules loaded. Ready!")
def format_status(tool_states):
    """Build a markdown status summary of the pipeline.

    Args:
        tool_states: mapping of tool name -> state. Expected states are
            "pending", "running", "done" and "error". Iteration order of the
            mapping is the display order.

    Returns:
        One ``"<icon> **<tool>** — <state>"`` line per tool, separated by blank
        lines so each renders as its own markdown paragraph. An unrecognised
        state gets a neutral "•" icon rather than raising KeyError.
    """
    icons = {"pending": "⏳", "running": "🔄", "done": "✅", "error": "❌"}
    return "\n\n".join(
        f"{icons.get(state, '•')} **{tool}** — {state}"
        for tool, state in tool_states.items()
    )
# ── Line-by-line attribution ────────────────────────────────────────────────
# JAAT's TaskMatch/SkillMatch/AIMatch all split the input with the same
# preprocessing before classifying each sentence. Replicating that split lets
# us line the candidate positives back up with their source sentence.
def _jaat_split(text):
    """Split ``text`` into sentences the way JAAT's matchers do.

    Newlines become sentence breaks (". "), semicolons become periods, and
    spaced list/bullet markers (" + ", " * ", " - ", " • ", " · ") as well as
    "--" and "**" become ". ". The replacements are applied in the order
    listed, because "**" must be handled after " * ". NLTK's
    ``sent_tokenize`` does the final split.
    """
    t = ". ".join(text.split("\n"))
    for a, b in [(";", "."), (" + ", ". "), (" * ", ". "), (" - ", ". "),
                 (" • ", ". "), (" · ", ". "), ("--", ". "), ("**", ". ")]:
        t = t.replace(a, b)
    return _sent_tokenize(t.strip())
def _attribute(sentences, matcher, corpus_attr, label_fn, threshold,
               min_words=0, max_words=10**9):
    """Map sentence indices to the corpus items they match.

    The pipeline has three steps:

    1. Keep sentences whose word count lies in ``[min_words, max_words]``.
    2. Classify them with ``matcher.pipe``; only ``"LABEL_1"`` predictions
       continue.
    3. Embed the survivors with ``matcher.embedding_model`` and take the
       top-1 ``semantic_search`` hit against ``getattr(matcher, corpus_attr)``.

    Args:
        sentences: list of sentence strings (see ``_jaat_split``).
        matcher: a JAAT TaskMatch/SkillMatch/AIMatch instance.
        corpus_attr: attribute name of the matcher's corpus embeddings.
        label_fn: ``label_fn(corpus_id) -> str`` display text for a hit.
        threshold: minimum hit score for a match to be kept.
        min_words, max_words: inclusive word-count bounds for candidates.

    Returns:
        ``{sentence_idx: [label, ...]}`` for sentences with a kept hit, or
        ``{}`` when nothing qualifies.
    """
    import torch

    candidates = [
        (i, s) for i, s in enumerate(sentences)
        if min_words <= len(s.split()) <= max_words
    ]
    if not candidates:
        return {}

    preds = list(matcher.pipe([s for _, s in candidates]))
    positives = [
        (i, s) for (i, s), p in zip(candidates, preds)
        if p.get("label") == "LABEL_1"
    ]
    if not positives:
        return {}

    corpus = getattr(matcher, corpus_attr)
    q = matcher.embedding_model.encode(
        [s for _, s in positives], convert_to_tensor=True, batch_size=64
    )
    # semantic_search multiplies query and corpus, so they must share device
    # and dtype. Align to the corpus tensor itself rather than comparing
    # matcher.device to the literal string "cuda" (which misses "cuda:0" or
    # a torch.device).
    if isinstance(corpus, torch.Tensor):
        q = q.to(device=corpus.device, dtype=corpus.dtype)
    elif getattr(matcher, "device", "cpu") == "cuda":
        q = q.to("cuda")

    hits = _st_util.semantic_search(
        corpus_embeddings=corpus,
        query_embeddings=q,
        top_k=1,
    )
    out = {}
    for (sent_idx, _), h in zip(positives, hits):
        # h is empty when the corpus is empty; skip rather than IndexError.
        if h and h[0]["score"] >= threshold:
            out.setdefault(sent_idx, []).append(label_fn(h[0]["corpus_id"]))
    return out
def _attribution_table(job_text):
    """Render a markdown table pairing each sentence with its matches.

    ``job_text`` is split with ``_jaat_split``. Task, Skill and AI
    attribution then run in three worker threads via ``_attribute``, each
    with its own word-count window. A column with no hit for a sentence
    shows "—". Pipes and newlines in both sentences and match labels are
    escaped so corpus text cannot break the table.
    """
    from concurrent.futures import ThreadPoolExecutor

    sents = _jaat_split(job_text)
    if not sents:
        return "_No sentences to attribute._"

    def task_label(cid):
        row = task_matcher.tasks.iloc[cid]
        return f"**{row['Task ID']}** {row['Task']}"

    def skill_label(cid):
        skill = skill_matcher.skills[cid]
        return f"**{skill_matcher.skill_map[skill]}** {skill}"

    def ai_label(cid):
        concept = ai_matcher.ai[cid]
        return f"**{ai_matcher.ai_map[concept]}** {concept}"

    # The three attributions are independent, so run them side by side.
    with ThreadPoolExecutor(max_workers=3) as pool:
        task_fut = pool.submit(
            _attribute, sents, task_matcher, "task_embed", task_label,
            task_matcher.threshold, 1, 48,
        )
        skill_fut = pool.submit(
            _attribute, sents, skill_matcher, "skill_embed", skill_label,
            skill_matcher.threshold, 1, 48,
        )
        ai_fut = pool.submit(
            _attribute, sents, ai_matcher, "ai_embed", ai_label,
            ai_matcher.threshold, 4, 64,
        )
        task_hits = task_fut.result()
        skill_hits = skill_fut.result()
        ai_hits = ai_fut.result()

    def esc(s):
        return s.replace("|", "\\|").replace("\n", " ").strip()

    def cell(hits, idx):
        return "<br>".join(esc(label) for label in hits.get(idx, [])) or "—"

    lines = [
        "| # | Sentence | Tasks (O*NET) | Skills (ESCO) | AI Concepts |",
        "|---|----------|---------------|---------------|-------------|",
    ]
    for i, s in enumerate(sents):
        lines.append(
            f"| {i+1} | {esc(s)} | {cell(task_hits, i)} "
            f"| {cell(skill_hits, i)} | {cell(ai_hits, i)} |"
        )
    return "\n".join(lines)
# Panel text shown for Task/Skill/AI in line-by-line mode.
_LINE_BY_LINE_NOTE = "_See Line-by-line Attribution below._"


def _format_firm(job_text):
    """Return the firm FirmExtract finds in ``job_text``, or "Not detected"."""
    tagged = firm_extractor.pipe(job_text)
    firm = firm_extractor.extract_firm(tagged, return_one=True, return_score=False)
    return firm if firm else "Not detected"


def _format_amount(label, value):
    """Render one wage bound as ``**label:** $1,234.00``.

    Thousands separators are stripped before parsing. Values that are not
    strings, or that ``float`` rejects, are shown verbatim.
    """
    try:
        return f"**{label}:** ${float(value.replace(',', '')):,.2f}"
    except (ValueError, AttributeError):
        return f"**{label}:** {value}"


def _format_wage(wage):
    """Format the value returned by ``wage_extractor.get_wage`` as markdown.

    A non-empty dict has its "min", "max" and "frequency" keys rendered when
    truthy. A string is shown as-is. Anything else is "Not found".
    """
    if isinstance(wage, dict) and wage:
        parts = []
        if wage.get("min"):
            parts.append(_format_amount("Min", wage["min"]))
        if wage.get("max"):
            parts.append(_format_amount("Max", wage["max"]))
        if wage.get("frequency"):
            parts.append(f"**Frequency:** {wage['frequency']}")
        return " | ".join(parts) if parts else "Not found"
    if isinstance(wage, str):
        return wage
    return "Not found"


def _format_title(job_title):
    """Return the top TitleMatch hit for ``job_title`` as an O*NET link + score."""
    title = job_title.strip()
    if not title:
        return "No job title provided (enter one above to use TitleMatch)"
    titles = title_matcher.get_title(title)
    if not titles:
        return "No match found"
    # get_title returns (onet_code, score, value, features)
    onet_code, score = titles[0][0], titles[0][1]
    score_pct = f"{float(score) * 100:.1f}%"
    onet_url = f"https://www.onetonline.org/link/summary/{onet_code}"
    return (
        f"**O*NET Code:** [{onet_code}]({onet_url})\n\n"
        f"**Score:** {score_pct}"
    )


def _format_tasks(tasks):
    """Markdown table of ``(Task ID, Description)`` rows from ``get_tasks``."""
    if not tasks:
        return "No O*NET tasks matched in this ad."
    rows = "\n".join(f"| {t[0]} | {t[1]} |" for t in tasks)
    return "| Task ID | Description |\n|---------|-------------|\n" + rows


def _format_skills(skills):
    """Markdown table of ``(ESCO code, skill)`` from ``get_skills`` items ``(skill, code)``."""
    if not skills:
        return "No skills matched in this ad."
    rows = "\n".join(f"| {s[1]} | {s[0]} |" for s in skills)
    return "| ESCO Code | Skill |\n|-----------|-------|\n" + rows


def _format_ai(ai_result):
    """Markdown summary + table for the value returned by ``ai_matcher.get_ai``."""
    if ai_result and isinstance(ai_result, (list, tuple)) and len(ai_result) >= 3:
        # NOTE(review): the guard admits 3+ items but this unpack needs exactly
        # 5; other lengths raise ValueError, which analyze() reports as an
        # AIMatch error.
        matched_ai, count, avg_score, _binary_scores, match_scores = ai_result
        if matched_ai:
            # Sort by the last 5 characters of the code, descending (stable).
            ranked = sorted(
                zip(matched_ai, match_scores),
                key=lambda x: x[0][1][-5:],
                reverse=True,
            )
            rows = "\n".join(
                f"| {code} | {statement} | {ms} |"
                for (statement, code), ms in ranked
            )
            return (
                f"**AI Concepts Found:** {count} | **Avg Score:** {avg_score}\n\n"
                "| Code | Statement | Match Score |\n|------|-----------|-------------|\n"
                + rows
            )
    return "No AI-related concepts detected."


def _format_tags(job_text):
    """Markdown table with one ✅/— row per JobTag classifier."""
    rows = []
    for cls, tagger in job_taggers.items():
        pred = tagger.get_tag(job_text)
        icon = "✅" if pred[1] else "—"
        rows.append(f"| {TAG_LABELS.get(cls, cls)} | {icon} |")
    return "| Attribute | Detected |\n|-----------|----------|\n" + "\n".join(rows)


def analyze(job_title, job_text, mode="Summary", progress=gr.Progress(track_tqdm=True)):
    """Run every JAAT module on one job ad, streaming progress to the UI.

    This is a generator. Each yield is a 9-tuple
    ``(status_md, firm, wage, title, tasks, skills, ai, tags, line_by_line)``
    in the order of the ``outputs=`` list wired to the Analyze button.

    Each tool yields a "running" snapshot before it starts and a
    "done"/"error" snapshot after it finishes. A failing tool shows
    ``Error: <exc>`` in its panel and the remaining tools still run.

    Args:
        job_title: title text for TitleMatch (may be blank).
        job_text: full ad text; blank input short-circuits with a prompt.
        mode: "Summary" (per-module tables) or "Line-by-line". In line-by-line
            mode the Task/Skill/AI panels point to a per-sentence attribution
            table, which is built last.
        progress: Gradio progress tracker, not used directly. It is declared
            (with ``track_tqdm=True``) so Gradio can mirror tqdm progress bars.
    """
    # Gradio normally passes "" for empty textboxes; tolerate None anyway.
    job_title = job_title or ""
    job_text = job_text or ""
    line_by_line = (mode == "Line-by-line")
    if not job_text.strip():
        yield "Please paste a job ad first.", "", "", "", "", "", "", "", ""
        return

    def summary_only(fn):
        """In line-by-line mode, swap a summary step for the pointer note."""
        return (lambda: _LINE_BY_LINE_NOTE) if line_by_line else fn

    # (tool name, zero-arg callable returning that tool's markdown).
    # List order is both the execution order and the output-panel order.
    steps = [
        ("FirmExtract", lambda: _format_firm(job_text)),
        ("WageExtract", lambda: _format_wage(wage_extractor.get_wage(job_text))),
        ("TitleMatch", lambda: _format_title(job_title)),
        ("TaskMatch", summary_only(lambda: _format_tasks(task_matcher.get_tasks(job_text)))),
        ("SkillMatch", summary_only(lambda: _format_skills(skill_matcher.get_skills(job_text)))),
        ("AIMatch", summary_only(lambda: _format_ai(ai_matcher.get_ai(job_text)))),
        ("JobTag", lambda: _format_tags(job_text)),
    ]
    states = {name: "pending" for name, _ in steps}
    outputs = {name: "" for name, _ in steps}
    line_out = ""

    def snapshot():
        # Reads the current states/outputs/line_out each time it is called.
        return (format_status(states), *outputs.values(), line_out)

    for name, run in steps:
        states[name] = "running"
        yield snapshot()
        try:
            outputs[name] = run()
            states[name] = "done"
        except Exception as e:
            outputs[name] = f"Error: {e}"
            states[name] = "error"
        yield snapshot()

    # ── Line-by-line Attribution (Tasks/Skills/AI) ──────────────────────
    if line_by_line:
        try:
            line_out = _attribution_table(job_text)
        except Exception as e:
            line_out = f"Error building attribution table: {e}"
        yield snapshot()
# Markdown citation notice (with an embedded BibTeX entry) rendered in the UI
# via gr.Markdown(CITATION) after the results panels.
CITATION = """**Software & Data Citation**
If you use JAAT in your research, please cite:
```bibtex
@article{meisenbacher2025extracting,
title={Extracting O*NET Features from the NLx Corpus to Build
Public Use Aggregate Labor Market Data},
author={Meisenbacher, Stephen and Nestorov, Svetlozar
and Norlander, Peter},
journal={arXiv preprint arXiv:2510.01470},
year={2025}
}
```
"""
# ── Gradio UI ───────────────────────────────────────────────────────────────
# Layout: inputs + mode selector on the left, live pipeline status on the
# right, then one panel per JAAT module. analyze() streams updates into the
# nine output components in the order listed in `outputs=` below.
with gr.Blocks(title="JAAT — Job Ad Analysis Toolkit") as demo:
    gr.Markdown("""
# JAAT — Job Ad Analysis Toolkit
Paste a job advertisement to extract O*NET tasks, skills, title match, firm name, wages, and job tags.
[GitHub](https://github.com/Job-Ad-Research-at-QSB-LUC/JAAT)
""")
    with gr.Row():
        with gr.Column(scale=2):
            job_title = gr.Textbox(
                label="Job Title (used by TitleMatch)",
                placeholder='e.g. "Software Engineer" or "Registered Nurse"',
                lines=1,
            )
            job_text = gr.Textbox(
                label="Full Job Advertisement Text",
                placeholder="Paste the full text of a job posting here, then click Analyze...",
                lines=12,
            )
            mode = gr.Radio(
                choices=["Summary", "Line-by-line"],
                value="Summary",
                label="Display mode",
                info="Summary = unique matches in tables. Line-by-line = each "
                     "sentence shown with the Tasks / Skills / AI Concepts it "
                     "triggered.",
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            pipeline_status = gr.Markdown("Pipeline status will appear here.")

    gr.Markdown("---")
    gr.Markdown("### Results")
    with gr.Row():
        with gr.Column():
            gr.Markdown("**FirmExtract**")
            firm_output = gr.Markdown()
        with gr.Column():
            gr.Markdown("**WageExtract**")
            wage_output = gr.Markdown()
        with gr.Column():
            gr.Markdown("**TitleMatch**")
            title_output = gr.Markdown()
    with gr.Accordion("TaskMatch — O*NET Tasks", open=True):
        task_output = gr.Markdown()
    with gr.Accordion("SkillMatch — ESCO Skills", open=True):
        skill_output = gr.Markdown()
    with gr.Accordion("AIMatch — AI Concepts", open=True):
        ai_output = gr.Markdown()
    with gr.Accordion("JobTag — Job Attributes", open=True):
        tag_output = gr.Markdown()
    with gr.Accordion(
        "Line-by-line Attribution (Tasks / Skills / AI)",
        open=True,
    ):
        line_output = gr.Markdown(
            "Select **Line-by-line** mode and click Analyze to see each "
            "sentence of the ad paired with the O*NET tasks, ESCO skills, "
            "and AI concepts it triggered."
        )
    gr.Markdown(CITATION)

    analyze_btn.click(
        fn=analyze,
        inputs=[job_title, job_text, mode],
        outputs=[
            pipeline_status, firm_output, wage_output, title_output,
            task_output, skill_output, ai_output, tag_output, line_output,
        ],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)