# resume-llm-api/src/inference.py
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Tuple, Union
import re
import os
class ResumeInferenceEngine:
"""Inference engine for resume extraction and matching"""
def __init__(self, model_path: str = "models/checkpoints/final"):
"""Load fine-tuned model and tokenizer"""
print(f"Loading model from {model_path}...")
# CPU-only environments (common on Windows laptops) can hit PEFT/accelerate
# offload edge cases with device_map="auto", so prefer a plain CPU load there.
# RAM optimization: load in bfloat16 to roughly halve memory versus float32;
# a 16 GB host is tight for a full float32 load of this model.
dtype = torch.bfloat16 if hasattr(torch, "bfloat16") else torch.float32
device_map = "auto" if torch.cuda.is_available() else None
low_cpu_mem_usage = True  # stream weights in to avoid peak-RAM spikes while loading
adapter_config_path = os.path.join(model_path, "adapter_config.json")
is_adapter = os.path.exists(adapter_config_path)
# Prefer tokenizer saved alongside adapter/model (the notebook saves tokenizer to final/)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if is_adapter:
from peft import PeftModel
with open(adapter_config_path, "r", encoding="utf-8") as f:
adapter_cfg = json.load(f)
base_model_name = adapter_cfg.get("base_model_name_or_path") or adapter_cfg.get("base_model") or "microsoft/phi-2"
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=dtype,  # torch_dtype is accepted across transformers versions
device_map=device_map,
low_cpu_mem_usage=low_cpu_mem_usage,
trust_remote_code=True,
)
# DEBUG: Inspect the safetensors file before loading
adapter_file = os.path.join(model_path, "adapter_model.safetensors")
if os.path.exists(adapter_file):
size = os.path.getsize(adapter_file)
print(f"DEBUG: adapter_model.safetensors size: {size} bytes")
with open(adapter_file, "rb") as f:
header = f.read(100)
print(f"DEBUG: adapter_model.safetensors header: {header}")
else:
print(f"DEBUG: adapter_model.safetensors NOT FOUND at {adapter_file}")
self.model = PeftModel.from_pretrained(base_model, model_path)
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,  # torch_dtype is accepted across transformers versions
device_map=device_map,
low_cpu_mem_usage=low_cpu_mem_usage,
trust_remote_code=True,
)
self.model.eval()
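# Optional speed-up (an assumption, not part of the original flow): for LoRA
# adapters, PEFT's merge_and_unload() folds the adapter weights into the base
# model, removing the adapter indirection on every forward pass:
#
#     if is_adapter:
#         self.model = self.model.merge_and_unload()
#
# Merging trades a one-time memory spike for lower per-token latency on CPU.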
def extract_resume(self, resume_text: str) -> Dict:
"""Extract structured information from resume"""
prompt = f"""Instruction: Extract structured information from the resume. Return valid JSON with fields: name, email, phone, skills, experience, education, certifications.
Input:
{resume_text}
Output:"""
# Speed optimization: keep the generation budget small so requests finish within Heroku's 30 s router timeout (_generate further caps new tokens at 150).
output = self._generate(prompt, max_length=200)
return self._parse_json_output(output)
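# Example (illustrative only; the exact fields depend on the fine-tune):
#
#     engine = ResumeInferenceEngine()
#     engine.extract_resume("Jane Doe\njane@example.com\nSkills: Python, SQL")
#     # -> {"name": "Jane Doe", "email": "jane@example.com",
#     #     "skills": ["Python", "SQL"], ..., "raw_output": "..."}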
def match_resume_to_job(self, resume_text: str, job_description: str) -> Dict:
"""Match resume to job description"""
prompt = f"""Instruction: Compare the resume against the job description and provide a match score (0-100) with reasoning. Return valid JSON with fields: match_score, matching_skills, missing_skills, recommendation.
Input:
Resume:
{resume_text}
Job Description:
{job_description}
Output:"""
# Use a lower temperature to improve format adherence.
output = self._generate(prompt, max_length=256, temperature=0.3)
return self._parse_json_output(output)
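# Example (illustrative): a well-formed response looks like
#
#     {"match_score": 72.0,
#      "matching_skills": ["Python", "Docker"],
#      "missing_skills": ["Kubernetes"],
#      "recommendation": "...",
#      "raw_output": "..."}
#
# Scores emitted as fractions (e.g. 0.72) are rescaled to 0-100 by
# _parse_json_output.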
def _generate(self, prompt: str, max_length: int = 512, temperature: float = 0.7) -> str:
"""Generate text from prompt"""
# When using device_map="auto", pick the device of the first parameter.
input_device = next(self.model.parameters()).device
tokenized = self.tokenizer(prompt, return_tensors="pt")
tokenized = {k: v.to(input_device) for k, v in tokenized.items()}
input_len = tokenized["input_ids"].shape[1]
# Interpret max_length as a generation budget (max_new_tokens) for backward compat.
# Cap at 150 to be safe for CPU inference latency.
max_new_tokens = max(32, min(150, int(max_length)))
with torch.inference_mode():
sequences = self.model.generate(
**tokenized,
max_new_tokens=max_new_tokens,
min_new_tokens=8,
temperature=temperature,
top_p=0.95,
num_beams=1,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Decode ONLY the generated continuation; avoids returning an empty string when the
# prompt already contains the delimiter text (e.g., "Output:").
gen_tokens = sequences[0][input_len:]
gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
if gen_text:
return gen_text
# Fallback: full decode so callers can see what happened.
full_text = self.tokenizer.decode(sequences[0], skip_special_tokens=True)
return full_text.strip()
def _parse_json_output(self, output: str) -> Dict:
"""Extract JSON from model output"""
def _split_skills(v: Union[str, List[str], None]) -> List[str]:
if v is None:
return []
if isinstance(v, list):
return [str(s).strip() for s in v if str(s).strip()]
v = str(v).strip()
if not v or v.lower() in {"none", "n/a", "na"}:
return []
return [s.strip() for s in v.split(",") if s.strip()]
def _normalize(d: Dict) -> Dict:
if not isinstance(d, dict):
return {"raw_output": output}
# Normalize match_score to 0-100
if "match_score" in d:
try:
score = d["match_score"]
if isinstance(score, str):
score = float(re.findall(r"[0-9]*\.?[0-9]+", score)[0])
else:
score = float(score)
if score <= 1.0:
score *= 100.0
d["match_score"] = score
except Exception:
pass
# Normalize skills fields to lists
if "matching_skills" in d:
d["matching_skills"] = _split_skills(d.get("matching_skills"))
if "missing_skills" in d:
d["missing_skills"] = _split_skills(d.get("missing_skills"))
# Preserve raw output for debugging
d.setdefault("raw_output", output)
return d
try:
# Try to find JSON in the output
json_match = re.search(r'\{.*\}', output, re.DOTALL)
if json_match:
json_str = json_match.group(0)
return _normalize(json.loads(json_str))
except json.JSONDecodeError:
pass
# Fallback: parse simple key:value lines (common when the model doesn't emit JSON).
# Example:
# match_score: 0.85
# matching_skills: Python, TensorFlow
if isinstance(output, str):
kv = {}
for raw_line in output.splitlines():
line = raw_line.strip()
if not line or ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip().strip('"').strip("'").lower()
value = value.strip().strip('"').strip("'")
if not key:
continue
kv[key] = value
if kv:
# Reuse the same normalization as the JSON path (score scaling,
# skill-list splitting, raw_output preservation).
return _normalize(kv)
# Fallback: try to parse a match score from plain text.
m = re.search(r"match\s*score\s*[:=]\s*([0-9]*\.?[0-9]+)", output or "", flags=re.IGNORECASE)
if m:
score = float(m.group(1))
if score <= 1.0:
score *= 100.0
return {"match_score": score, "raw_output": output}
# Return structured response if parsing fails
return {"error": "Failed to parse output", "raw_output": output}
def batch_extract(self, resumes: List[str]) -> List[Dict]:
"""Extract from multiple resumes"""
results = []
for i, resume in enumerate(resumes):
print(f"Processing resume {i+1}/{len(resumes)}...")
results.append(self.extract_resume(resume))
return results
def batch_match(self, resume_pairs: List[Tuple[str, str]]) -> List[Dict]:
"""Match multiple resume-job pairs"""
results = []
for i, (resume, job) in enumerate(resume_pairs):
print(f"Processing pair {i+1}/{len(resume_pairs)}...")
results.append(self.match_resume_to_job(resume, job))
return results
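# Batch usage sketch (hypothetical inputs):
#
#     engine = ResumeInferenceEngine()
#     profiles = engine.batch_extract([resume_a, resume_b])
#     matches = engine.batch_match([(resume_a, job_1), (resume_b, job_1)])
#
# Both helpers run sequentially, so on CPU expect roughly one model call's
# latency per item.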
# Flask API for serving predictions
def create_api(model_path: str = "models/checkpoints/final"):
"""Create Flask API for inference"""
from flask import Flask, request, jsonify
from flask_cors import CORS
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
engine = ResumeInferenceEngine(model_path)
@app.route("/extract", methods=["POST"])
def extract():
"""Extract information from resume"""
data = request.get_json(silent=True) or {}  # tolerate missing or invalid JSON bodies
resume = data.get("resume", "")
if not resume:
return jsonify({"error": "Resume text required"}), 400
result = engine.extract_resume(resume)
return jsonify(result)
@app.route("/match", methods=["POST"])
def match():
"""Match resume to job description"""
data = request.get_json(silent=True) or {}  # tolerate missing or invalid JSON bodies
resume = data.get("resume", "")
job = data.get("job_description", "")
if not resume or not job:
return jsonify({"error": "Resume and job description required"}), 400
result = engine.match_resume_to_job(resume, job)
return jsonify(result)
@app.route("/health", methods=["GET"])
def health():
return jsonify({"status": "healthy"})
return app
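# Example requests (assumes the API is running locally on port 5000):
#
#     curl -X POST http://localhost:5000/extract \
#          -H "Content-Type: application/json" \
#          -d '{"resume": "Jane Doe\njane@example.com\nSkills: Python"}'
#
#     curl -X POST http://localhost:5000/match \
#          -H "Content-Type: application/json" \
#          -d '{"resume": "...", "job_description": "..."}'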
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="cli", help="Mode: cli, api, or batch")
parser.add_argument("--model-path", default="models/checkpoints/final", help="Path to model")
parser.add_argument("--task", choices=["extract", "match"], default="extract")
parser.add_argument("--resume-file", help="Path to resume file")
parser.add_argument("--job-file", help="Path to job description file")
parser.add_argument("--port", type=int, default=5000, help="API port")
args = parser.parse_args()
# Build the engine only for CLI mode; create_api() constructs its own,
# so creating one here unconditionally would load the model twice.
if args.mode == "cli":
engine = ResumeInferenceEngine(args.model_path)
if args.task == "extract":
with open(args.resume_file) as f:
resume = f.read()
result = engine.extract_resume(resume)
print(json.dumps(result, indent=2))
elif args.task == "match":
with open(args.resume_file) as f:
resume = f.read()
with open(args.job_file) as f:
job = f.read()
result = engine.match_resume_to_job(resume, job)
print(json.dumps(result, indent=2))
elif args.mode == "api":
app = create_api(args.model_path)
print(f"Starting API on port {args.port}...")
app.run(host="0.0.0.0", port=args.port, debug=False)
if __name__ == "__main__":
main()
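# Example invocations (file paths are placeholders):
#
#     python src/inference.py --task extract --resume-file resume.txt
#     python src/inference.py --task match --resume-file resume.txt --job-file job.txt
#     python src/inference.py --mode api --port 5000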