# job-parser-space / inference.py
# Source: Rithankoushik's Hugging Face Space, commit 18be1bf (verified)
import torch
import re
import time
import json
import json5
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Model paths: base Qwen3-0.6B checkpoint plus a LoRA adapter fine-tuned
# for job-description parsing.
base_model_id = "Qwen/Qwen3-0.6B"
lora_model_id = "Rithankoushik/Qwen-0.6-Job-parser-Model"
# Load tokenizer. Qwen tokenizers ship without a pad token, so we alias it
# to EOS — generate() below passes pad_token_id explicitly and would
# otherwise warn/fail on batched padding.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id # ✅ critical fix
# Load the base model (fp16 on GPU, fp32 on CPU; device_map="auto" lets
# accelerate place the weights), then apply the LoRA adapter.
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
model = PeftModel.from_pretrained(base_model, lora_model_id, device_map="auto")
# Fold the adapter weights into the base model so inference runs without
# the PEFT indirection, and switch to eval mode (disables dropout).
model = model.merge_and_unload()
model.eval()
def extract_and_clean_json(text):
    """Extract and parse a JSON object embedded in LLM output.

    Finds the widest ``{...}`` span in *text*, normalizes common
    Python-literal leakage (``None``/``True``/``False``) and trailing
    commas, then parses it.

    Args:
        text: Raw model output, possibly with prose around the JSON.

    Returns:
        The parsed object (usually a dict), or ``None`` if no JSON-like
        span is found or parsing fails.
    """
    matches = re.findall(r"\{[\s\S]*\}", text)
    if not matches:
        return None
    json_str = matches[0]  # widest brace-delimited span
    # Use word-boundary substitutions so substrings inside string values
    # (e.g. "Nonetheless", "Trueblood") are not corrupted, unlike a bare
    # str.replace would do.
    json_str = re.sub(r"\bNone\b", "null", json_str)
    json_str = re.sub(r"\bTrue\b", "true", json_str)
    json_str = re.sub(r"\bFalse\b", "false", json_str)
    # Drop trailing commas before a closing brace/bracket.
    json_str = re.sub(r",(\s*[}\]])", r"\1", json_str)
    # Try the strict stdlib parser first; fall back to the more lenient
    # json5 (handles single quotes, comments, unquoted keys, ...).
    try:
        return json.loads(json_str)
    except Exception:
        pass
    try:
        return json5.loads(json_str)
    except Exception as e:
        print(f"JSON parse error: {e}")
        return None
def infer_from_text(jd_text: str):
    """Run structured-extraction inference on a job description.

    Builds a chat prompt, generates greedily with the merged LoRA model,
    and tries to parse the response as JSON.

    Args:
        jd_text: Free-form job description text.

    Returns:
        Tuple of (result, duration_seconds): *result* is a pretty-printed
        JSON string when parsing succeeds, otherwise the raw decoded model
        output; *duration_seconds* is the wall-clock time rounded to 2 dp.
    """
    start_time = time.time()
    system_prompt = "Extract structured information from the following job description and return it as JSON."
    user_prompt = f"Job Description:\n{jd_text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    # ✅ safer way: let the tokenizer render the model's own chat template
    # rather than hand-formatting role tags.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    raw_inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to wherever accelerate placed the (merged) model.
    device = model.device
    inputs = {k: v.to(device) for k, v in raw_inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1000,
            do_sample=False,  # greedy decoding; sampling params (temperature) are
                              # intentionally omitted — transformers ignores them
                              # here and warns if they are passed
            pad_token_id=tokenizer.pad_token_id
        )
    # Strip the prompt tokens so only the generated continuation is decoded.
    gen_tokens = out[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    duration = round(time.time() - start_time, 2)
    parsed = extract_and_clean_json(response_text)
    if parsed is not None:
        return json.dumps(parsed, indent=2), duration
    return response_text, duration