# coderound / backend/src/routers/matching.py
# Last commit by ketannnn (88462d6): feat: implement matching router and stage 2 neural reranking logic
import uuid
import json
import logging
import redis.asyncio as redis
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
# Module-level logger for this router's endpoints.
logger = logging.getLogger(__name__)
from ..database import get_db
from ..config import get_settings
from ..models.jd import JobDescription
from ..models.candidate import Candidate
from ..models.match_result import MatchResult
from ..schemas.match import (
MatchResponse, MatchedCandidate, ComponentScores, GapItem,
CandidateDetailResponse, ReRankRequest,
)
from ..matching.stage1 import stage1_retrieve
from ..matching.stage2 import stage2_rerank
from ..matching.llm_explainer import generate_explanation
from ..matching.scorer import rerank_with_weights
# Router holding the matching endpoints defined below; presumably included by
# the app with a URL prefix — confirm in the app factory.
router = APIRouter()
def _get_qdrant(request: Request):
    """Return the Qdrant client stored on the application state.

    The client is read from ``request.app.state.qdrant``; it is presumably
    attached there at application startup (confirm in the app factory).
    """
    app_state = request.app.state
    return app_state.qdrant
async def _load_jd(jd_id: uuid.UUID, db: AsyncSession) -> JobDescription:
    """Fetch a job description by primary key, rejecting unusable ones.

    Raises:
        HTTPException(404): no JD with ``jd_id`` exists.
        HTTPException(202): the JD's ``status`` is still ``"processing"``.
            NOTE(review): 202 is a success status, so FastAPI's default
            handler sends it as a normal 202 response carrying a
            ``{"detail": ...}`` body rather than as an error. Clients
            presumably poll on it, so confirm before changing the code.
    """
    stmt = select(JobDescription).where(JobDescription.id == jd_id)
    jd = (await db.execute(stmt)).scalar_one_or_none()
    if jd is None:
        raise HTTPException(status_code=404, detail="JD not found")
    if jd.status == "processing":
        raise HTTPException(
            status_code=202,
            detail="JD is still being processed, try again shortly",
        )
    return jd
def _build_jd_dict(jd: JobDescription) -> dict:
    """Flatten the JD columns used by the matching pipeline into a plain dict.

    The id is stringified. ``required_skills`` falls back to ``[]`` and
    ``custom_weights`` to ``{}`` when the column value is falsy (e.g. NULL).
    """
    return dict(
        id=str(jd.id),
        title=jd.title,
        raw_text=jd.raw_text,
        required_skills=jd.required_skills or [],
        min_yoe=jd.min_yoe,
        max_yoe=jd.max_yoe,
        role_type=jd.role_type,
        engineer_type=jd.engineer_type,
        location=jd.location,
        remote_allowed=jd.remote_allowed,
        custom_weights=jd.custom_weights or {},
    )
def _to_matched_candidate(item: dict, rank: int) -> MatchedCandidate:
    """Convert a ranked candidate record into the ``MatchedCandidate`` schema.

    Args:
        item: candidate record. ``candidate_id`` (a UUID string) is required;
            every other key is optional.
        rank: 1-based position to report for this candidate.

    ``programming_languages``, ``component_scores`` and ``gaps`` become empty
    containers when missing or falsy. Before this change, ``gaps`` only
    defaulted when the key was absent, so an explicit ``None`` raised
    ``TypeError`` while iterating it.
    """
    gaps = item.get("gaps") or []
    component_scores = item.get("component_scores") or {}
    return MatchedCandidate(
        candidate_id=uuid.UUID(item["candidate_id"]),
        rank=rank,
        name=item.get("name"),
        email=item.get("email"),
        role_type=item.get("role_type"),
        engineer_type=item.get("engineer_type"),
        years_of_experience=item.get("years_of_experience"),
        most_recent_company=item.get("most_recent_company"),
        parsed_summary=item.get("parsed_summary"),
        programming_languages=item.get("programming_languages") or [],
        # NOTE(review): the 0.5 default applies only when the key is absent;
        # an explicit None (e.g. a NULL column) is passed through unchanged.
        growth_velocity=item.get("growth_velocity", 0.5),
        stage1_score=item.get("stage1_score", 0),
        stage2_score=item.get("stage2_score"),
        final_score=item.get("final_score", 0),
        component_scores=ComponentScores(**component_scores),
        gaps=[GapItem(**gap) for gap in gaps],
    )
@router.post("/{jd_id}", response_model=MatchResponse)
async def trigger_match(
    jd_id: uuid.UUID,
    request: Request,
    session_id: uuid.UUID | None = Query(None, description="Candidate session to match against"),
    stage1_top_k: int = Query(
        100, ge=1, description="How many candidates to retrieve from vector DB (Stage 1)"
    ),
    stage2_top_k: int = Query(
        100, ge=1, description="How many Stage 1 candidates to pass to the neural reranker (Stage 2)"
    ),
    db: AsyncSession = Depends(get_db),
):
    """Run the two-stage match for a JD, persist the ranking, and return it.

    Stage 1 (``stage1_retrieve``) pulls up to ``stage1_top_k`` candidates from
    the vector DB. The first ``stage2_top_k`` of that shortlist are then
    reranked by ``stage2_rerank``. Existing ``MatchResult`` rows for the same
    ``jd_id`` and ``session_id`` (NULL when no session is given) are deleted
    and replaced in one transaction. Explanations for the top 20 rows are
    then queued in the background on a best-effort basis.

    Both ``*_top_k`` values must be >= 1; FastAPI answers 422 otherwise.
    Previously a negative ``stage2_top_k`` silently dropped candidates
    through negative slicing.

    Raises:
        HTTPException: 404 / 202 from ``_load_jd``.
        Exception: any retrieval, rerank or DB error is logged with its
            traceback, the transaction is rolled back, and the error is
            re-raised.
    """
    # Outside the try: a missing or still-processing JD is an expected client
    # outcome, not a FATAL pipeline failure that deserves a traceback.
    jd = await _load_jd(jd_id, db)
    # Capture response fields before commit so they never depend on reloading
    # attributes that a commit may expire (see the session's expire_on_commit).
    jd_title = jd.title
    jd_quality = jd.jd_quality or {}

    try:
        qdrant = _get_qdrant(request)
        jd_dict = _build_jd_dict(jd)
        sid_str = str(session_id) if session_id else None

        # Stage 1: Retrieve top-K from vector DB
        logger.info(f"[trigger_match] JD={jd_id} | Stage 1 starting (top_k={stage1_top_k})")
        shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str, top_k=stage1_top_k)
        logger.info(f"[trigger_match] JD={jd_id} | Stage 1 complete — {len(shortlist)} candidates retrieved")

        # Stage 2: Neural cross-encoder reranker
        rerank_input = shortlist[:stage2_top_k]
        logger.info(f"[trigger_match] JD={jd_id} | Stage 2 starting (reranking {len(rerank_input)} candidates)")
        final_ranked = await stage2_rerank(jd_dict, rerank_input)
        logger.info(f"[trigger_match] JD={jd_id} | Stage 2 complete — {len(final_ranked)} candidates ranked")

        # Persist results: replace the previous ranking for this (JD, session).
        session_filter = (
            (MatchResult.session_id == session_id)
            if session_id
            else MatchResult.session_id.is_(None)
        )
        await db.execute(delete(MatchResult).where(MatchResult.jd_id == jd_id, session_filter))

        # IDs are generated here, so they can be collected without touching
        # the ORM objects again after commit.
        match_ids: list[uuid.UUID] = []
        for rank, item in enumerate(final_ranked, start=1):
            match_id = uuid.uuid4()
            db.add(
                MatchResult(
                    id=match_id,
                    jd_id=jd_id,
                    candidate_id=uuid.UUID(item["candidate_id"]),
                    session_id=session_id,
                    rank=rank,
                    stage1_score=item.get("stage1_score", 0),
                    stage2_score=item.get("stage2_score"),
                    final_score=item.get("final_score", 0),
                    component_scores=item.get("component_scores", {}),
                    gaps=item.get("gaps", []),
                )
            )
            match_ids.append(match_id)
        await db.commit()
        logger.info(f"[trigger_match] JD={jd_id} | {len(match_ids)} match results saved to DB")
    except Exception as exc:
        # Log the FULL traceback so it appears in HF container logs
        logger.exception(f"[trigger_match] FATAL — JD={jd_id} session={session_id} | {type(exc).__name__}: {exc}")
        await db.rollback()
        raise

    from ..workers.explain import generate_top_explanations

    # Pre-generate LLM explanations for the top 20 in the background. This is
    # only an optimisation and the ranking is already committed, so a failure
    # to enqueue (e.g. broker unavailable) is logged instead of failing the
    # whole request.
    top_20_ids = [str(match_id) for match_id in match_ids[:20]]
    if top_20_ids:
        try:
            generate_top_explanations.delay(top_20_ids)
        except Exception:
            logger.exception(f"[trigger_match] JD={jd_id} | failed to queue explanations; continuing")

    results = [
        _to_matched_candidate(item, rank) for rank, item in enumerate(final_ranked, start=1)
    ]
    return MatchResponse(
        jd_id=jd_id,
        jd_title=jd_title,
        jd_quality=jd_quality,
        total_matched=len(results),
        results=results,
        # NOTE(review): hard-coded defaults. jd_dict (passed to stage2_rerank)
        # carries jd.custom_weights, so if stage 2 honours them these values
        # may not describe what was actually applied — confirm.
        weights_used={"semantic": 0.20, "skill": 0.35, "yoe": 0.15, "company": 0.10, "growth": 0.10, "education": 0.10},
        session_id=session_id,
    )
@router.get("/{jd_id}", response_model=MatchResponse)
async def get_match_results(
    jd_id: uuid.UUID,
    session_id: uuid.UUID | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return stored match results for a JD, re-scored with its saved weights.

    If Redis holds an entry under ``match_v2:{jd_id}:{session_id or 'none'}``,
    it is decoded and returned as-is. The cache is best-effort: any error
    reading or decoding it is logged and the database path is used instead.
    The per-request Redis client is closed after use; previously it was never
    closed, which leaked its connection pool on every request.

    NOTE(review): nothing in this module writes or invalidates that key, so a
    cached entry could outlive a re-run of ``trigger_match`` or a ``rerank`` —
    confirm where it is populated.

    Without ``session_id``, only rows whose ``session_id`` is NULL are returned.
    """
    settings = get_settings()
    cache_key = f"match_v2:{jd_id}:{session_id or 'none'}"
    cache = redis.Redis.from_url(settings.redis_url)
    try:
        # The async context manager closes the client (and its pool) on exit.
        async with cache:
            cached = await cache.get(cache_key)
        if cached:
            return json.loads(cached)
    except Exception as exc:
        # Cache is an optimisation only; fall back to the database.
        logger.warning("[get_match_results] cache lookup failed for %s: %s", cache_key, exc)

    jd = await _load_jd(jd_id, db)
    weights = jd.custom_weights or {}

    session_filter = (
        (MatchResult.session_id == session_id)
        if session_id
        else MatchResult.session_id.is_(None)
    )
    stmt = (
        select(MatchResult, Candidate)
        .join(Candidate, MatchResult.candidate_id == Candidate.id)
        .where(MatchResult.jd_id == jd_id, session_filter)
        .order_by(MatchResult.rank)
    )
    rows = (await db.execute(stmt)).all()

    if not rows:
        return MatchResponse(
            jd_id=jd_id,
            jd_title=jd.title,
            jd_quality=jd.jd_quality or {},
            weights_used=weights,
            total_matched=0,
            results=[],
            session_id=session_id,
        )

    items = [
        {
            "candidate_id": str(cand.id), "name": cand.name, "email": cand.email,
            "role_type": cand.role_type, "engineer_type": cand.engineer_type,
            "years_of_experience": cand.years_of_experience,
            "most_recent_company": cand.most_recent_company,
            "parsed_summary": cand.parsed_summary,
            "programming_languages": cand.programming_languages or [],
            "growth_velocity": cand.growth_velocity,
            "stage1_score": mr.stage1_score, "stage2_score": mr.stage2_score,
            "final_score": mr.final_score,
            "component_scores": mr.component_scores or {}, "gaps": mr.gaps or [],
        }
        for mr, cand in rows
    ]

    # Re-score the stored rows with the JD's saved weights. rerank_with_weights
    # assigns each item's "rank"; it presumably also maps the stored scores onto
    # the percentage scale shown to users — confirm in matching.scorer.
    reranked = rerank_with_weights(items, weights)
    final_results = [_to_matched_candidate(item, item["rank"]) for item in reranked]
    return MatchResponse(
        jd_id=jd_id,
        jd_title=jd.title,
        jd_quality=jd.jd_quality or {},
        weights_used=weights,
        total_matched=len(items),
        results=final_results,
        session_id=session_id,
    )
@router.post("/{jd_id}/rerank", response_model=MatchResponse)
async def rerank_results(
    jd_id: uuid.UUID,
    payload: ReRankRequest,
    session_id: uuid.UUID | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Re-score stored match results with caller-supplied weights and save them.

    ``payload.weights`` is applied through ``rerank_with_weights`` and then
    persisted as the JD's ``custom_weights``. Persisting now happens only
    after the rerank succeeded. Previously the weights were committed first,
    so a 404 or a failing rerank left them saved, and invalid weights would
    then also break later ``get_match_results`` calls.

    Without ``session_id``, only rows whose ``session_id`` is NULL are used.

    Raises:
        HTTPException(404): JD missing or no stored results for this session.
        HTTPException(202): JD still processing (from ``_load_jd``).
    """
    jd = await _load_jd(jd_id, db)

    session_filter = (
        (MatchResult.session_id == session_id)
        if session_id
        else MatchResult.session_id.is_(None)
    )
    stmt = (
        select(MatchResult, Candidate)
        .join(Candidate, MatchResult.candidate_id == Candidate.id)
        .where(MatchResult.jd_id == jd_id, session_filter)
        .order_by(MatchResult.rank)
    )
    rows = (await db.execute(stmt)).all()
    if not rows:
        raise HTTPException(status_code=404, detail="No match results found.")

    items = [
        {
            "candidate_id": str(cand.id), "name": cand.name, "email": cand.email,
            "role_type": cand.role_type, "engineer_type": cand.engineer_type,
            "years_of_experience": cand.years_of_experience,
            "most_recent_company": cand.most_recent_company,
            "parsed_summary": cand.parsed_summary,
            "programming_languages": cand.programming_languages or [],
            "growth_velocity": cand.growth_velocity,
            "stage1_score": mr.stage1_score, "stage2_score": mr.stage2_score,
            "final_score": mr.final_score,
            "component_scores": mr.component_scores or {}, "gaps": mr.gaps or [],
        }
        for mr, cand in rows
    ]
    reranked = rerank_with_weights(items, payload.weights)
    results = [_to_matched_candidate(item, item["rank"]) for item in reranked]
    response = MatchResponse(
        jd_id=jd_id,
        jd_title=jd.title,
        jd_quality=jd.jd_quality or {},
        total_matched=len(results),
        results=results,
        weights_used=payload.weights,
        session_id=session_id,
    )

    # Persist last: only weights that reranked successfully are saved, and
    # every read of `jd` above happens before a commit can expire it.
    jd.custom_weights = payload.weights
    await db.commit()
    return response
@router.post("/{jd_id}/candidates/{candidate_id}/explain")
async def trigger_explanation(
    jd_id: uuid.UUID,
    candidate_id: uuid.UUID,
    session_id: uuid.UUID | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Queue LLM explanation generation for one candidate's match result.

    Without ``session_id``, only the session-less result (NULL ``session_id``)
    is considered, matching how results are stored and listed without a
    session. Previously no session filter was applied at all, so a candidate
    matched both globally and in a session made ``scalar_one_or_none`` raise
    ``MultipleResultsFound`` (a 500).

    Returns:
        ``{"status": "queued"}`` once the task has been enqueued.

    Raises:
        HTTPException(404): no matching ``MatchResult`` row.
    """
    # Local import: the task was never bound at module scope, so referencing
    # it here previously raised NameError on every call.
    from ..workers.explain import generate_top_explanations

    session_filter = (
        (MatchResult.session_id == session_id)
        if session_id
        else MatchResult.session_id.is_(None)
    )
    stmt = select(MatchResult).where(
        MatchResult.jd_id == jd_id,
        MatchResult.candidate_id == candidate_id,
        session_filter,
    )
    mr = (await db.execute(stmt)).scalar_one_or_none()
    if mr is None:
        raise HTTPException(status_code=404, detail="Match result not found")

    generate_top_explanations.delay([str(mr.id)])
    return {"status": "queued"}
@router.get("/{jd_id}/{candidate_id}", response_model=CandidateDetailResponse)
async def get_candidate_detail(
    jd_id: uuid.UUID,
    candidate_id: uuid.UUID,
    session_id: uuid.UUID | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return one candidate's stored match result with candidate and JD details.

    Without ``session_id``, only the session-less result (NULL ``session_id``)
    is considered, consistent with how the results list filters. Previously
    no session filter was applied, so a candidate matched both globally and
    in a session made ``scalar_one_or_none`` raise ``MultipleResultsFound``
    (a 500).

    Raises:
        HTTPException(404): JD, match result, or candidate not found.
        HTTPException(202): JD still processing (from ``_load_jd``).
    """
    jd = await _load_jd(jd_id, db)

    session_filter = (
        (MatchResult.session_id == session_id)
        if session_id
        else MatchResult.session_id.is_(None)
    )
    mr_stmt = select(MatchResult).where(
        MatchResult.jd_id == jd_id,
        MatchResult.candidate_id == candidate_id,
        session_filter,
    )
    mr = (await db.execute(mr_stmt)).scalar_one_or_none()
    if mr is None:
        raise HTTPException(status_code=404, detail="Match result not found")

    cand_stmt = select(Candidate).where(Candidate.id == candidate_id)
    cand = (await db.execute(cand_stmt)).scalar_one_or_none()
    if cand is None:
        raise HTTPException(status_code=404, detail="Candidate not found")

    return CandidateDetailResponse(
        jd_id=jd_id, candidate_id=candidate_id, rank=mr.rank,
        final_score=mr.final_score,
        component_scores=mr.component_scores or {},
        gaps=mr.gaps or [],
        explanation=mr.explanation,
        candidate={
            "name": cand.name, "email": cand.email, "role_type": cand.role_type,
            "engineer_type": cand.engineer_type, "years_of_experience": cand.years_of_experience,
            "most_recent_company": cand.most_recent_company, "parsed_summary": cand.parsed_summary,
            "parsed_skills": cand.parsed_skills, "parsed_work_experience": cand.parsed_work_experience or [],
            "programming_languages": cand.programming_languages or [],
            "backend_frameworks": cand.backend_frameworks or [],
            "gen_ai_experience": cand.gen_ai_experience, "growth_velocity": cand.growth_velocity,
            "looking_for": cand.looking_for, "open_to_working_at": cand.open_to_working_at,
            "is_actively_or_passively_looking": cand.is_actively_or_passively_looking,
            "most_recent_company_is_funded": cand.most_recent_company_is_funded,
            "most_recent_company_is_product_company": cand.most_recent_company_is_product_company,
            "most_recent_company_total_funding": cand.most_recent_company_total_funding,
        },
        jd={
            "title": jd.title, "required_skills": jd.required_skills or [],
            "min_yoe": jd.min_yoe, "role_type": jd.role_type,
            "engineer_type": jd.engineer_type, "location": jd.location,
        },
    )