ketannnn commited on
Commit
12fa3c2
·
1 Parent(s): 5655f74

fix: resolve asyncpg ssl connection errors and Qdrant strict payload filtering

Browse files
backend/alembic/env.py CHANGED
@@ -12,8 +12,12 @@ if config.config_file_name is not None:
12
  import sys, os
13
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
 
 
 
 
 
15
  from src.database import Base
16
- from src.models import JobDescription, Candidate, MatchResult
17
  from src.config import get_settings
18
 
19
  target_metadata = Base.metadata
@@ -22,6 +26,10 @@ target_metadata = Base.metadata
22
  def _make_async_url(url: str) -> str:
23
  url = re.sub(r"^postgresql:", "postgresql+asyncpg:", url)
24
  url = re.sub(r"[?&]channel_binding=require", "", url)
 
 
 
 
25
  return url
26
 
27
 
@@ -45,7 +53,9 @@ def do_run_migrations(connection):
45
 
46
  async def run_async_migrations() -> None:
47
  settings = get_settings()
48
- connectable = create_async_engine(_make_async_url(settings.database_url), poolclass=pool.NullPool)
 
 
49
  async with connectable.connect() as connection:
50
  await connection.run_sync(do_run_migrations)
51
  await connectable.dispose()
 
12
  import sys, os
13
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
 
15
+ # Load .env before importing settings
16
+ from dotenv import load_dotenv
17
+ load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
18
+
19
  from src.database import Base
20
+ from src.models import JobDescription, Candidate, MatchResult, Session
21
  from src.config import get_settings
22
 
23
  target_metadata = Base.metadata
 
26
  def _make_async_url(url: str) -> str:
27
  url = re.sub(r"^postgresql:", "postgresql+asyncpg:", url)
28
  url = re.sub(r"[?&]channel_binding=require", "", url)
29
+ url = re.sub(r"[?&]sslmode=[^&]*", "", url)
30
+ url = re.sub(r"[?&]connect_timeout=[^&]*", "", url)
31
+ # clean trailing ? or &
32
+ url = re.sub(r"[?&]$", "", url)
33
  return url
34
 
35
 
 
53
 
54
  async def run_async_migrations() -> None:
55
  settings = get_settings()
56
+ from src.database import _make_async_url
57
+ db_url, connect_args = _make_async_url(settings.database_url)
58
+ connectable = create_async_engine(db_url, poolclass=pool.NullPool, connect_args=connect_args)
59
  async with connectable.connect() as connection:
60
  await connection.run_sync(do_run_migrations)
61
  await connectable.dispose()
backend/main.py CHANGED
@@ -1,18 +1,21 @@
1
  import os
 
2
  from contextlib import asynccontextmanager
3
- from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  from qdrant_client import QdrantClient
6
- from qdrant_client.models import Distance, VectorParams
7
 
8
  from src.config import get_settings
9
- from src.database import engine
10
- from src.models import JobDescription, Candidate, MatchResult
11
- from src.routers import jds, candidates, matching
12
 
 
13
  settings = get_settings()
14
 
15
  _qdrant_client: QdrantClient | None = None
 
16
 
17
 
18
  def get_qdrant() -> QdrantClient:
@@ -21,24 +24,47 @@ def get_qdrant() -> QdrantClient:
21
 
22
  @asynccontextmanager
23
  async def lifespan(app: FastAPI):
24
- global _qdrant_client
25
  _qdrant_client = QdrantClient(url=settings.qdrant_url, api_key=settings.qdrant_api_key)
26
 
27
- existing = [c.name for c in _qdrant_client.get_collections().collections]
28
- if settings.collection_name not in existing:
29
- _qdrant_client.create_collection(
30
- collection_name=settings.collection_name,
31
- vectors_config=VectorParams(size=settings.vector_size, distance=Distance.COSINE),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
 
34
  app.state.qdrant = _qdrant_client
 
35
  yield
36
  _qdrant_client.close()
37
 
38
 
39
  app = FastAPI(
40
  title="TalentPulse — AI Candidate Matching",
41
- description="Two-stage retrieval + reranking pipeline for matching JDs against 100K+ candidates",
42
  version="1.0.0",
43
  lifespan=lifespan,
44
  )
@@ -51,11 +77,22 @@ app.add_middleware(
51
  allow_headers=["*"],
52
  )
53
 
 
54
  app.include_router(jds.router, prefix="/api/jds", tags=["Job Descriptions"])
55
  app.include_router(candidates.router, prefix="/api/candidates", tags=["Candidates"])
56
  app.include_router(matching.router, prefix="/api/match", tags=["Matching"])
57
 
58
 
59
  @app.get("/health")
60
- async def health():
61
- return {"status": "ok", "version": "1.0.0"}
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import logging
3
  from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI, Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.staticfiles import StaticFiles
7
  from qdrant_client import QdrantClient
8
+ from qdrant_client.models import Distance, VectorParams, PayloadSchemaType
9
 
10
  from src.config import get_settings
11
+ from src.models import JobDescription, Candidate, MatchResult, Session
12
+ from src.routers import jds, candidates, matching, sessions
 
13
 
14
+ logger = logging.getLogger(__name__)
15
  settings = get_settings()
16
 
17
  _qdrant_client: QdrantClient | None = None
18
+ _qdrant_ready: bool = False
19
 
20
 
21
  def get_qdrant() -> QdrantClient:
 
24
 
25
  @asynccontextmanager
26
  async def lifespan(app: FastAPI):
27
+ global _qdrant_client, _qdrant_ready
28
  _qdrant_client = QdrantClient(url=settings.qdrant_url, api_key=settings.qdrant_api_key)
29
 
30
+ try:
31
+ existing = [c.name for c in _qdrant_client.get_collections().collections]
32
+ if settings.collection_name not in existing:
33
+ _qdrant_client.create_collection(
34
+ collection_name=settings.collection_name,
35
+ vectors_config=VectorParams(size=settings.vector_size, distance=Distance.COSINE),
36
+ )
37
+ # Create indexing for the session_id to allow fast filtering
38
+ _qdrant_client.create_payload_index(
39
+ collection_name=settings.collection_name,
40
+ field_name="session_id",
41
+ field_schema=PayloadSchemaType.UUID,
42
+ )
43
+ # Create indexing for years_of_experience for range filtering
44
+ _qdrant_client.create_payload_index(
45
+ collection_name=settings.collection_name,
46
+ field_name="years_of_experience",
47
+ field_schema=PayloadSchemaType.FLOAT,
48
+ )
49
+ _qdrant_ready = True
50
+ logger.info("Qdrant connected — collection '%s' ready", settings.collection_name)
51
+ except Exception as exc:
52
+ _qdrant_ready = False
53
+ logger.warning(
54
+ "Qdrant unavailable at startup (%s). "
55
+ "The API will start but vector search will fail until Qdrant is reachable.",
56
+ exc,
57
  )
58
 
59
  app.state.qdrant = _qdrant_client
60
+ app.state.qdrant_ready = _qdrant_ready
61
  yield
62
  _qdrant_client.close()
63
 
64
 
65
  app = FastAPI(
66
  title="TalentPulse — AI Candidate Matching",
67
+ description="Two-stage retrieval + reranking pipeline for matching JDs against candidate sessions",
68
  version="1.0.0",
69
  lifespan=lifespan,
70
  )
 
77
  allow_headers=["*"],
78
  )
79
 
80
+ app.include_router(sessions.router, prefix="/api/sessions", tags=["Sessions"])
81
  app.include_router(jds.router, prefix="/api/jds", tags=["Job Descriptions"])
82
  app.include_router(candidates.router, prefix="/api/candidates", tags=["Candidates"])
83
  app.include_router(matching.router, prefix="/api/match", tags=["Matching"])
84
 
85
 
86
  @app.get("/health")
87
+ async def health(request: "Request"):
88
+ qdrant_ok = getattr(request.app.state, "qdrant_ready", False)
89
+ return {
90
+ "status": "ok",
91
+ "version": "1.0.0",
92
+ "qdrant": "connected" if qdrant_ok else "unavailable",
93
+ }
94
+
95
+
96
+ static_dir = os.path.join(os.path.dirname(__file__), "static")
97
+ if os.path.isdir(static_dir):
98
+ app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
backend/src/database.py CHANGED
@@ -5,14 +5,39 @@ from sqlalchemy.orm import DeclarativeBase
5
  from .config import get_settings
6
 
7
 
8
- def _make_async_url(url: str) -> str:
9
- url = re.sub(r"^postgresql:", "postgresql+asyncpg:", url)
10
- url = re.sub(r"[?&]channel_binding=require", "", url)
11
- return url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  settings = get_settings()
15
- engine = create_async_engine(_make_async_url(settings.database_url), echo=False, pool_pre_ping=True)
 
 
 
 
 
 
16
  AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
17
 
18
 
@@ -23,3 +48,4 @@ class Base(DeclarativeBase):
23
  async def get_db() -> AsyncGenerator[AsyncSession, None]:
24
  async with AsyncSessionLocal() as session:
25
  yield session
 
 
5
  from .config import get_settings
6
 
7
 
8
+ def _make_async_url(url: str) -> tuple[str, dict]:
9
+ """Convert a standard postgres:// URL to asyncpg-compatible form.
10
+
11
+ asyncpg does NOT accept sslmode or channel_binding as URL query params.
12
+ Strip them and return connect_args with ssl=True when sslmode was present.
13
+ """
14
+ needs_ssl = bool(re.search(r"[?&]sslmode=", url))
15
+ # Switch scheme
16
+ url = re.sub(r"^postgresql(\+[^:]+)?:", "postgresql+asyncpg:", url)
17
+ # Remove unsupported query params
18
+ for param in ("sslmode", "channel_binding"):
19
+ url = re.sub(rf"[?&]{param}=[^&]*", "", url)
20
+ # Clean up trailing ? or & left behind
21
+ url = re.sub(r"\?$", "", url)
22
+ url = re.sub(r"&$", "", url)
23
+ connect_args: dict = {}
24
+ if needs_ssl:
25
+ import ssl as _ssl
26
+ ctx = _ssl.create_default_context()
27
+ ctx.check_hostname = False
28
+ ctx.verify_mode = _ssl.CERT_NONE
29
+ connect_args["ssl"] = ctx
30
+ return url, connect_args
31
 
32
 
33
  settings = get_settings()
34
+ _db_url, _connect_args = _make_async_url(settings.database_url)
35
+ engine = create_async_engine(
36
+ _db_url,
37
+ echo=False,
38
+ pool_pre_ping=True,
39
+ connect_args=_connect_args,
40
+ )
41
  AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
42
 
43
 
 
48
  async def get_db() -> AsyncGenerator[AsyncSession, None]:
49
  async with AsyncSessionLocal() as session:
50
  yield session
51
+
backend/src/matching/stage1.py CHANGED
@@ -2,39 +2,26 @@ from typing import Any
2
  from qdrant_client import QdrantClient
3
  from qdrant_client.models import Filter, FieldCondition, MatchValue, Range
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
- from sqlalchemy import select, or_
6
 
7
  from ..config import get_settings
8
  from ..models.candidate import Candidate
9
  from ..ml.embedder import embed_query
10
- from ..ml.feature_builder import (
11
- skill_jaccard,
12
- yoe_match,
13
- company_quality_signal,
14
- education_match,
15
- )
16
 
17
 
18
  DEFAULT_WEIGHTS = {
19
- "semantic": 0.20,
20
- "skill": 0.35,
21
- "yoe": 0.15,
22
- "company": 0.10,
23
- "growth": 0.10,
24
- "education": 0.10,
25
  }
26
 
27
 
28
- def _build_qdrant_filter(jd: dict) -> Filter | None:
29
  conditions = []
30
- if jd.get("role_type"):
31
- conditions.append(
32
- FieldCondition(key="role_type", match=MatchValue(value=jd["role_type"]))
33
- )
34
  if jd.get("min_yoe") is not None:
35
- conditions.append(
36
- FieldCondition(key="years_of_experience", range=Range(gte=max(0, jd["min_yoe"] - 2)))
37
- )
38
  if not conditions:
39
  return None
40
  return Filter(must=conditions)
@@ -44,6 +31,7 @@ async def stage1_retrieve(
44
  jd: dict,
45
  db: AsyncSession,
46
  qdrant: QdrantClient,
 
47
  top_k: int = 200,
48
  weights: dict | None = None,
49
  ) -> list[dict[str, Any]]:
@@ -53,7 +41,7 @@ async def stage1_retrieve(
53
  jd_text = f"{jd.get('title', '')} {jd.get('raw_text', '')}"
54
  query_vector = embed_query(jd_text)
55
 
56
- qdrant_filter = _build_qdrant_filter(jd)
57
  search_results = qdrant.search(
58
  collection_name=settings.collection_name,
59
  query_vector=query_vector.tolist(),
@@ -68,9 +56,7 @@ async def stage1_retrieve(
68
  qdrant_ids = [r.id for r in search_results]
69
  score_by_qdrant_id = {r.id: float(r.score) for r in search_results}
70
 
71
- result = await db.execute(
72
- select(Candidate).where(Candidate.qdrant_id.in_(qdrant_ids))
73
- )
74
  candidates = {c.qdrant_id: c for c in result.scalars().all()}
75
 
76
  jd_skills = jd.get("required_skills") or []
@@ -84,11 +70,8 @@ async def stage1_retrieve(
84
  continue
85
 
86
  cosine_sim = score_by_qdrant_id[qid]
87
-
88
  all_cand_skills = (
89
- (cand.programming_languages or [])
90
- + (cand.backend_frameworks or [])
91
- + (cand.frontend_technologies or [])
92
  )
93
  if cand.parsed_skills:
94
  all_cand_skills.extend([s.strip() for s in cand.parsed_skills.split(",") if s.strip()])
@@ -97,45 +80,35 @@ async def stage1_retrieve(
97
  "semantic": cosine_sim,
98
  "skill": skill_jaccard(jd_skills, all_cand_skills),
99
  "yoe": yoe_match(min_yoe, max_yoe, cand.years_of_experience),
100
- "company": company_quality_signal(
101
- {
102
- "most_recent_company_is_funded": cand.most_recent_company_is_funded,
103
- "most_recent_company_is_product_company": cand.most_recent_company_is_product_company,
104
- "most_recent_company_total_funding": cand.most_recent_company_total_funding,
105
- }
106
- ),
107
  "growth": float(cand.growth_velocity or 0.5),
108
- "education": education_match(
109
- {
110
- "degree": cand.degree,
111
- "education_status": cand.education_status,
112
- }
113
- ),
114
  }
115
 
116
  total = sum(w.get(k, 0) * v for k, v in components.items())
117
-
118
- scored.append(
119
- {
120
- "candidate_id": str(cand.id),
121
- "qdrant_id": qid,
122
- "name": cand.name,
123
- "email": cand.email,
124
- "role_type": cand.role_type,
125
- "engineer_type": cand.engineer_type,
126
- "years_of_experience": cand.years_of_experience,
127
- "most_recent_company": cand.most_recent_company,
128
- "parsed_summary": cand.parsed_summary,
129
- "parsed_skills": cand.parsed_skills,
130
- "parsed_work_experience": cand.parsed_work_experience or [],
131
- "programming_languages": cand.programming_languages or [],
132
- "backend_frameworks": cand.backend_frameworks or [],
133
- "frontend_technologies": cand.frontend_technologies or [],
134
- "growth_velocity": cand.growth_velocity,
135
- "stage1_score": round(total, 4),
136
- "component_scores": {k: round(v, 4) for k, v in components.items()},
137
- }
138
- )
139
 
140
  scored.sort(key=lambda x: x["stage1_score"], reverse=True)
141
  return scored[:50]
 
2
  from qdrant_client import QdrantClient
3
  from qdrant_client.models import Filter, FieldCondition, MatchValue, Range
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy import select
6
 
7
  from ..config import get_settings
8
  from ..models.candidate import Candidate
9
  from ..ml.embedder import embed_query
10
+ from ..ml.feature_builder import skill_jaccard, yoe_match, company_quality_signal, education_match
 
 
 
 
 
11
 
12
 
13
  DEFAULT_WEIGHTS = {
14
+ "semantic": 0.20, "skill": 0.35, "yoe": 0.15,
15
+ "company": 0.10, "growth": 0.10, "education": 0.10,
 
 
 
 
16
  }
17
 
18
 
19
+ def _build_qdrant_filter(jd: dict, session_id: str | None) -> Filter | None:
20
  conditions = []
21
+ if session_id:
22
+ conditions.append(FieldCondition(key="session_id", match=MatchValue(value=session_id)))
 
 
23
  if jd.get("min_yoe") is not None:
24
+ conditions.append(FieldCondition(key="years_of_experience", range=Range(gte=max(0, jd["min_yoe"] - 2))))
 
 
25
  if not conditions:
26
  return None
27
  return Filter(must=conditions)
 
31
  jd: dict,
32
  db: AsyncSession,
33
  qdrant: QdrantClient,
34
+ session_id: str | None = None,
35
  top_k: int = 200,
36
  weights: dict | None = None,
37
  ) -> list[dict[str, Any]]:
 
41
  jd_text = f"{jd.get('title', '')} {jd.get('raw_text', '')}"
42
  query_vector = embed_query(jd_text)
43
 
44
+ qdrant_filter = _build_qdrant_filter(jd, session_id)
45
  search_results = qdrant.search(
46
  collection_name=settings.collection_name,
47
  query_vector=query_vector.tolist(),
 
56
  qdrant_ids = [r.id for r in search_results]
57
  score_by_qdrant_id = {r.id: float(r.score) for r in search_results}
58
 
59
+ result = await db.execute(select(Candidate).where(Candidate.qdrant_id.in_(qdrant_ids)))
 
 
60
  candidates = {c.qdrant_id: c for c in result.scalars().all()}
61
 
62
  jd_skills = jd.get("required_skills") or []
 
70
  continue
71
 
72
  cosine_sim = score_by_qdrant_id[qid]
 
73
  all_cand_skills = (
74
+ (cand.programming_languages or []) + (cand.backend_frameworks or []) + (cand.frontend_technologies or [])
 
 
75
  )
76
  if cand.parsed_skills:
77
  all_cand_skills.extend([s.strip() for s in cand.parsed_skills.split(",") if s.strip()])
 
80
  "semantic": cosine_sim,
81
  "skill": skill_jaccard(jd_skills, all_cand_skills),
82
  "yoe": yoe_match(min_yoe, max_yoe, cand.years_of_experience),
83
+ "company": company_quality_signal({
84
+ "most_recent_company_is_funded": cand.most_recent_company_is_funded,
85
+ "most_recent_company_is_product_company": cand.most_recent_company_is_product_company,
86
+ "most_recent_company_total_funding": cand.most_recent_company_total_funding,
87
+ }),
 
 
88
  "growth": float(cand.growth_velocity or 0.5),
89
+ "education": education_match({"degree": cand.degree, "education_status": cand.education_status}),
 
 
 
 
 
90
  }
91
 
92
  total = sum(w.get(k, 0) * v for k, v in components.items())
93
+ scored.append({
94
+ "candidate_id": str(cand.id),
95
+ "qdrant_id": qid,
96
+ "name": cand.name,
97
+ "email": cand.email,
98
+ "role_type": cand.role_type,
99
+ "engineer_type": cand.engineer_type,
100
+ "years_of_experience": cand.years_of_experience,
101
+ "most_recent_company": cand.most_recent_company,
102
+ "parsed_summary": cand.parsed_summary,
103
+ "parsed_skills": cand.parsed_skills,
104
+ "parsed_work_experience": cand.parsed_work_experience or [],
105
+ "programming_languages": cand.programming_languages or [],
106
+ "backend_frameworks": cand.backend_frameworks or [],
107
+ "frontend_technologies": cand.frontend_technologies or [],
108
+ "growth_velocity": cand.growth_velocity,
109
+ "stage1_score": round(total, 4),
110
+ "component_scores": {k: round(v, 4) for k, v in components.items()},
111
+ })
 
 
 
112
 
113
  scored.sort(key=lambda x: x["stage1_score"], reverse=True)
114
  return scored[:50]
backend/src/routers/matching.py CHANGED
@@ -1,6 +1,6 @@
1
  import uuid
2
  from datetime import datetime, timezone
3
- from fastapi import APIRouter, Depends, HTTPException, Request
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
  from sqlalchemy import select, delete
6
 
@@ -8,7 +8,10 @@ from ..database import get_db
8
  from ..models.jd import JobDescription
9
  from ..models.candidate import Candidate
10
  from ..models.match_result import MatchResult
11
- from ..schemas.match import MatchResponse, MatchedCandidate, ComponentScores, GapItem, CandidateDetailResponse, ReRankRequest
 
 
 
12
  from ..matching.stage1 import stage1_retrieve
13
  from ..matching.stage2 import stage2_rerank
14
  from ..matching.llm_explainer import generate_explanation
@@ -31,39 +34,64 @@ async def _load_jd(jd_id: uuid.UUID, db: AsyncSession) -> JobDescription:
31
  return jd
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @router.post("/{jd_id}", response_model=MatchResponse)
35
  async def trigger_match(
36
  jd_id: uuid.UUID,
37
  request: Request,
 
38
  db: AsyncSession = Depends(get_db),
39
  ):
40
  jd = await _load_jd(jd_id, db)
41
  qdrant = _get_qdrant(request)
 
 
42
 
43
- jd_dict = {
44
- "id": str(jd.id),
45
- "title": jd.title,
46
- "raw_text": jd.raw_text,
47
- "required_skills": jd.required_skills or [],
48
- "min_yoe": jd.min_yoe,
49
- "max_yoe": jd.max_yoe,
50
- "role_type": jd.role_type,
51
- "engineer_type": jd.engineer_type,
52
- "location": jd.location,
53
- "remote_allowed": jd.remote_allowed,
54
- }
55
-
56
- shortlist = await stage1_retrieve(jd_dict, db, qdrant)
57
  final_ranked = await stage2_rerank(jd_dict, shortlist)
58
 
59
- await db.execute(delete(MatchResult).where(MatchResult.jd_id == jd_id))
 
 
 
 
 
60
 
61
- match_records = []
62
  for i, item in enumerate(final_ranked):
63
  mr = MatchResult(
64
- id=uuid.uuid4(),
65
- jd_id=jd_id,
66
  candidate_id=uuid.UUID(item["candidate_id"]),
 
67
  rank=i + 1,
68
  stage1_score=item.get("stage1_score", 0),
69
  stage2_score=item.get("stage2_score"),
@@ -71,54 +99,40 @@ async def trigger_match(
71
  component_scores=item.get("component_scores", {}),
72
  gaps=item.get("gaps", []),
73
  )
74
- match_records.append(mr)
75
  db.add(mr)
76
 
77
  await db.commit()
78
 
79
- results = []
80
- for i, item in enumerate(final_ranked):
81
- results.append(
82
- MatchedCandidate(
83
- candidate_id=uuid.UUID(item["candidate_id"]),
84
- rank=i + 1,
85
- name=item.get("name"),
86
- email=item.get("email"),
87
- role_type=item.get("role_type"),
88
- engineer_type=item.get("engineer_type"),
89
- years_of_experience=item.get("years_of_experience"),
90
- most_recent_company=item.get("most_recent_company"),
91
- parsed_summary=item.get("parsed_summary"),
92
- programming_languages=item.get("programming_languages") or [],
93
- growth_velocity=item.get("growth_velocity", 0.5),
94
- stage1_score=item.get("stage1_score", 0),
95
- stage2_score=item.get("stage2_score"),
96
- final_score=item.get("final_score", 0),
97
- component_scores=ComponentScores(**item.get("component_scores", {})),
98
- gaps=[GapItem(**g) for g in item.get("gaps", [])],
99
- )
100
- )
101
-
102
  return MatchResponse(
103
- jd_id=jd_id,
104
- jd_title=jd.title,
105
  jd_quality=jd.jd_quality or {},
106
- total_matched=len(results),
107
- results=results,
108
  weights_used={"semantic": 0.20, "skill": 0.35, "yoe": 0.15, "company": 0.10, "growth": 0.10, "education": 0.10},
 
109
  )
110
 
111
 
112
  @router.get("/{jd_id}", response_model=MatchResponse)
113
- async def get_match_results(jd_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
 
 
 
 
114
  jd = await _load_jd(jd_id, db)
115
 
116
- result = await db.execute(
117
  select(MatchResult, Candidate)
118
  .join(Candidate, MatchResult.candidate_id == Candidate.id)
119
  .where(MatchResult.jd_id == jd_id)
120
- .order_by(MatchResult.rank)
121
  )
 
 
 
 
 
 
 
122
  rows = result.all()
123
 
124
  if not rows:
@@ -126,33 +140,73 @@ async def get_match_results(jd_id: uuid.UUID, db: AsyncSession = Depends(get_db)
126
 
127
  results = []
128
  for mr, cand in rows:
129
- results.append(
130
- MatchedCandidate(
131
- candidate_id=cand.id,
132
- rank=mr.rank or 0,
133
- name=cand.name,
134
- email=cand.email,
135
- role_type=cand.role_type,
136
- engineer_type=cand.engineer_type,
137
- years_of_experience=cand.years_of_experience,
138
- most_recent_company=cand.most_recent_company,
139
- parsed_summary=cand.parsed_summary,
140
- programming_languages=cand.programming_languages or [],
141
- growth_velocity=cand.growth_velocity,
142
- stage1_score=mr.stage1_score,
143
- stage2_score=mr.stage2_score,
144
- final_score=mr.final_score,
145
- component_scores=ComponentScores(**(mr.component_scores or {})),
146
- gaps=[GapItem(**g) for g in (mr.gaps or [])],
147
- )
148
- )
149
 
150
  return MatchResponse(
151
- jd_id=jd_id,
152
- jd_title=jd.title,
153
- jd_quality=jd.jd_quality or {},
154
- total_matched=len(results),
155
- results=results,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
 
158
 
@@ -160,19 +214,18 @@ async def get_match_results(jd_id: uuid.UUID, db: AsyncSession = Depends(get_db)
160
  async def get_candidate_detail(
161
  jd_id: uuid.UUID,
162
  candidate_id: uuid.UUID,
 
163
  db: AsyncSession = Depends(get_db),
164
  ):
165
  jd = await _load_jd(jd_id, db)
166
 
167
- mr_result = await db.execute(
168
- select(MatchResult).where(
169
- MatchResult.jd_id == jd_id,
170
- MatchResult.candidate_id == candidate_id,
171
- )
172
- )
173
  mr = mr_result.scalar_one_or_none()
174
  if not mr:
175
- raise HTTPException(status_code=404, detail="Match result not found for this JD/candidate pair")
176
 
177
  cand_result = await db.execute(select(Candidate).where(Candidate.id == candidate_id))
178
  cand = cand_result.scalar_one_or_none()
@@ -180,19 +233,9 @@ async def get_candidate_detail(
180
  raise HTTPException(status_code=404, detail="Candidate not found")
181
 
182
  if not mr.explanation:
183
- jd_dict = {
184
- "id": str(jd.id),
185
- "title": jd.title,
186
- "raw_text": jd.raw_text,
187
- "required_skills": jd.required_skills or [],
188
- "min_yoe": jd.min_yoe,
189
- "engineer_type": jd.engineer_type,
190
- "location": jd.location,
191
- "remote_allowed": jd.remote_allowed,
192
- }
193
  cand_dict = {
194
- "parsed_summary": cand.parsed_summary,
195
- "parsed_skills": cand.parsed_skills,
196
  "years_of_experience": cand.years_of_experience,
197
  "programming_languages": cand.programming_languages or [],
198
  "backend_frameworks": cand.backend_frameworks or [],
@@ -206,113 +249,28 @@ async def get_candidate_detail(
206
  await db.commit()
207
 
208
  return CandidateDetailResponse(
209
- jd_id=jd_id,
210
- candidate_id=candidate_id,
211
- rank=mr.rank,
212
  final_score=mr.final_score,
213
  component_scores=ComponentScores(**(mr.component_scores or {})),
214
  gaps=[GapItem(**g) for g in (mr.gaps or [])],
215
  explanation=mr.explanation,
216
  candidate={
217
- "name": cand.name,
218
- "email": cand.email,
219
- "role_type": cand.role_type,
220
- "engineer_type": cand.engineer_type,
221
- "years_of_experience": cand.years_of_experience,
222
- "most_recent_company": cand.most_recent_company,
223
- "parsed_summary": cand.parsed_summary,
224
- "parsed_skills": cand.parsed_skills,
225
- "parsed_work_experience": cand.parsed_work_experience or [],
226
  "programming_languages": cand.programming_languages or [],
227
  "backend_frameworks": cand.backend_frameworks or [],
228
- "gen_ai_experience": cand.gen_ai_experience,
229
- "growth_velocity": cand.growth_velocity,
230
- "looking_for": cand.looking_for,
231
- "open_to_working_at": cand.open_to_working_at,
232
  "is_actively_or_passively_looking": cand.is_actively_or_passively_looking,
233
  "most_recent_company_is_funded": cand.most_recent_company_is_funded,
234
  "most_recent_company_is_product_company": cand.most_recent_company_is_product_company,
235
  "most_recent_company_total_funding": cand.most_recent_company_total_funding,
236
  },
237
  jd={
238
- "title": jd.title,
239
- "required_skills": jd.required_skills or [],
240
- "min_yoe": jd.min_yoe,
241
- "role_type": jd.role_type,
242
- "engineer_type": jd.engineer_type,
243
- "location": jd.location,
244
  },
245
  )
246
-
247
-
248
- @router.post("/{jd_id}/rerank", response_model=MatchResponse)
249
- async def rerank_results(
250
- jd_id: uuid.UUID,
251
- payload: ReRankRequest,
252
- db: AsyncSession = Depends(get_db),
253
- ):
254
- jd = await _load_jd(jd_id, db)
255
-
256
- result = await db.execute(
257
- select(MatchResult, Candidate)
258
- .join(Candidate, MatchResult.candidate_id == Candidate.id)
259
- .where(MatchResult.jd_id == jd_id)
260
- .order_by(MatchResult.rank)
261
- )
262
- rows = result.all()
263
-
264
- if not rows:
265
- raise HTTPException(status_code=404, detail="No match results found.")
266
-
267
- items = []
268
- for mr, cand in rows:
269
- items.append({
270
- "candidate_id": str(cand.id),
271
- "name": cand.name,
272
- "email": cand.email,
273
- "role_type": cand.role_type,
274
- "engineer_type": cand.engineer_type,
275
- "years_of_experience": cand.years_of_experience,
276
- "most_recent_company": cand.most_recent_company,
277
- "parsed_summary": cand.parsed_summary,
278
- "programming_languages": cand.programming_languages or [],
279
- "growth_velocity": cand.growth_velocity,
280
- "stage1_score": mr.stage1_score,
281
- "stage2_score": mr.stage2_score,
282
- "final_score": mr.final_score,
283
- "component_scores": mr.component_scores or {},
284
- "gaps": mr.gaps or [],
285
- })
286
-
287
- reranked = rerank_with_weights(items, payload.weights)
288
-
289
- results = [
290
- MatchedCandidate(
291
- candidate_id=uuid.UUID(item["candidate_id"]),
292
- rank=item["rank"],
293
- name=item.get("name"),
294
- email=item.get("email"),
295
- role_type=item.get("role_type"),
296
- engineer_type=item.get("engineer_type"),
297
- years_of_experience=item.get("years_of_experience"),
298
- most_recent_company=item.get("most_recent_company"),
299
- parsed_summary=item.get("parsed_summary"),
300
- programming_languages=item.get("programming_languages") or [],
301
- growth_velocity=item.get("growth_velocity", 0.5),
302
- stage1_score=item.get("stage1_score", 0),
303
- stage2_score=item.get("stage2_score"),
304
- final_score=item.get("final_score", 0),
305
- component_scores=ComponentScores(**(item.get("component_scores") or {})),
306
- gaps=[GapItem(**g) for g in item.get("gaps", [])],
307
- )
308
- for item in reranked
309
- ]
310
-
311
- return MatchResponse(
312
- jd_id=jd_id,
313
- jd_title=jd.title,
314
- jd_quality=jd.jd_quality or {},
315
- total_matched=len(results),
316
- results=results,
317
- weights_used=payload.weights,
318
- )
 
1
  import uuid
2
  from datetime import datetime, timezone
3
+ from fastapi import APIRouter, Depends, HTTPException, Request, Query
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
  from sqlalchemy import select, delete
6
 
 
8
  from ..models.jd import JobDescription
9
  from ..models.candidate import Candidate
10
  from ..models.match_result import MatchResult
11
+ from ..schemas.match import (
12
+ MatchResponse, MatchedCandidate, ComponentScores, GapItem,
13
+ CandidateDetailResponse, ReRankRequest,
14
+ )
15
  from ..matching.stage1 import stage1_retrieve
16
  from ..matching.stage2 import stage2_rerank
17
  from ..matching.llm_explainer import generate_explanation
 
34
  return jd
35
 
36
 
37
+ def _build_jd_dict(jd: JobDescription) -> dict:
38
+ return {
39
+ "id": str(jd.id), "title": jd.title, "raw_text": jd.raw_text,
40
+ "required_skills": jd.required_skills or [], "min_yoe": jd.min_yoe,
41
+ "max_yoe": jd.max_yoe, "role_type": jd.role_type,
42
+ "engineer_type": jd.engineer_type, "location": jd.location,
43
+ "remote_allowed": jd.remote_allowed,
44
+ }
45
+
46
+
47
+ def _to_matched_candidate(item: dict, rank: int) -> MatchedCandidate:
48
+ return MatchedCandidate(
49
+ candidate_id=uuid.UUID(item["candidate_id"]),
50
+ rank=rank,
51
+ name=item.get("name"),
52
+ email=item.get("email"),
53
+ role_type=item.get("role_type"),
54
+ engineer_type=item.get("engineer_type"),
55
+ years_of_experience=item.get("years_of_experience"),
56
+ most_recent_company=item.get("most_recent_company"),
57
+ parsed_summary=item.get("parsed_summary"),
58
+ programming_languages=item.get("programming_languages") or [],
59
+ growth_velocity=item.get("growth_velocity", 0.5),
60
+ stage1_score=item.get("stage1_score", 0),
61
+ stage2_score=item.get("stage2_score"),
62
+ final_score=item.get("final_score", 0),
63
+ component_scores=ComponentScores(**(item.get("component_scores") or {})),
64
+ gaps=[GapItem(**g) for g in item.get("gaps", [])],
65
+ )
66
+
67
+
68
  @router.post("/{jd_id}", response_model=MatchResponse)
69
  async def trigger_match(
70
  jd_id: uuid.UUID,
71
  request: Request,
72
+ session_id: uuid.UUID | None = Query(None, description="Candidate session to match against"),
73
  db: AsyncSession = Depends(get_db),
74
  ):
75
  jd = await _load_jd(jd_id, db)
76
  qdrant = _get_qdrant(request)
77
+ jd_dict = _build_jd_dict(jd)
78
+ sid_str = str(session_id) if session_id else None
79
 
80
+ shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  final_ranked = await stage2_rerank(jd_dict, shortlist)
82
 
83
+ await db.execute(
84
+ delete(MatchResult).where(
85
+ MatchResult.jd_id == jd_id,
86
+ MatchResult.session_id == session_id if session_id else MatchResult.session_id.is_(None),
87
+ )
88
+ )
89
 
 
90
  for i, item in enumerate(final_ranked):
91
  mr = MatchResult(
92
+ id=uuid.uuid4(), jd_id=jd_id,
 
93
  candidate_id=uuid.UUID(item["candidate_id"]),
94
+ session_id=session_id,
95
  rank=i + 1,
96
  stage1_score=item.get("stage1_score", 0),
97
  stage2_score=item.get("stage2_score"),
 
99
  component_scores=item.get("component_scores", {}),
100
  gaps=item.get("gaps", []),
101
  )
 
102
  db.add(mr)
103
 
104
  await db.commit()
105
 
106
+ results = [_to_matched_candidate(item, i + 1) for i, item in enumerate(final_ranked)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  return MatchResponse(
108
+ jd_id=jd_id, jd_title=jd.title,
 
109
  jd_quality=jd.jd_quality or {},
110
+ total_matched=len(results), results=results,
 
111
  weights_used={"semantic": 0.20, "skill": 0.35, "yoe": 0.15, "company": 0.10, "growth": 0.10, "education": 0.10},
112
+ session_id=session_id,
113
  )
114
 
115
 
116
  @router.get("/{jd_id}", response_model=MatchResponse)
117
+ async def get_match_results(
118
+ jd_id: uuid.UUID,
119
+ session_id: uuid.UUID | None = Query(None),
120
+ db: AsyncSession = Depends(get_db),
121
+ ):
122
  jd = await _load_jd(jd_id, db)
123
 
124
+ q = (
125
  select(MatchResult, Candidate)
126
  .join(Candidate, MatchResult.candidate_id == Candidate.id)
127
  .where(MatchResult.jd_id == jd_id)
 
128
  )
129
+ if session_id:
130
+ q = q.where(MatchResult.session_id == session_id)
131
+ else:
132
+ q = q.where(MatchResult.session_id.is_(None))
133
+ q = q.order_by(MatchResult.rank)
134
+
135
+ result = await db.execute(q)
136
  rows = result.all()
137
 
138
  if not rows:
 
140
 
141
  results = []
142
  for mr, cand in rows:
143
+ item = {
144
+ "candidate_id": str(cand.id), "name": cand.name, "email": cand.email,
145
+ "role_type": cand.role_type, "engineer_type": cand.engineer_type,
146
+ "years_of_experience": cand.years_of_experience,
147
+ "most_recent_company": cand.most_recent_company,
148
+ "parsed_summary": cand.parsed_summary,
149
+ "programming_languages": cand.programming_languages or [],
150
+ "growth_velocity": cand.growth_velocity,
151
+ "stage1_score": mr.stage1_score, "stage2_score": mr.stage2_score,
152
+ "final_score": mr.final_score,
153
+ "component_scores": mr.component_scores or {}, "gaps": mr.gaps or [],
154
+ }
155
+ results.append(_to_matched_candidate(item, mr.rank or 0))
 
 
 
 
 
 
 
156
 
157
  return MatchResponse(
158
+ jd_id=jd_id, jd_title=jd.title, jd_quality=jd.jd_quality or {},
159
+ total_matched=len(results), results=results, session_id=session_id,
160
+ )
161
+
162
+
163
+ @router.post("/{jd_id}/rerank", response_model=MatchResponse)
164
+ async def rerank_results(
165
+ jd_id: uuid.UUID,
166
+ payload: ReRankRequest,
167
+ session_id: uuid.UUID | None = Query(None),
168
+ db: AsyncSession = Depends(get_db),
169
+ ):
170
+ jd = await _load_jd(jd_id, db)
171
+
172
+ q = (
173
+ select(MatchResult, Candidate)
174
+ .join(Candidate, MatchResult.candidate_id == Candidate.id)
175
+ .where(MatchResult.jd_id == jd_id)
176
+ )
177
+ if session_id:
178
+ q = q.where(MatchResult.session_id == session_id)
179
+ else:
180
+ q = q.where(MatchResult.session_id.is_(None))
181
+ q = q.order_by(MatchResult.rank)
182
+
183
+ result = await db.execute(q)
184
+ rows = result.all()
185
+ if not rows:
186
+ raise HTTPException(status_code=404, detail="No match results found.")
187
+
188
+ items = [
189
+ {
190
+ "candidate_id": str(cand.id), "name": cand.name, "email": cand.email,
191
+ "role_type": cand.role_type, "engineer_type": cand.engineer_type,
192
+ "years_of_experience": cand.years_of_experience,
193
+ "most_recent_company": cand.most_recent_company,
194
+ "parsed_summary": cand.parsed_summary,
195
+ "programming_languages": cand.programming_languages or [],
196
+ "growth_velocity": cand.growth_velocity,
197
+ "stage1_score": mr.stage1_score, "stage2_score": mr.stage2_score,
198
+ "final_score": mr.final_score,
199
+ "component_scores": mr.component_scores or {}, "gaps": mr.gaps or [],
200
+ }
201
+ for mr, cand in rows
202
+ ]
203
+
204
+ reranked = rerank_with_weights(items, payload.weights)
205
+ results = [_to_matched_candidate(item, item["rank"]) for item in reranked]
206
+ return MatchResponse(
207
+ jd_id=jd_id, jd_title=jd.title, jd_quality=jd.jd_quality or {},
208
+ total_matched=len(results), results=results,
209
+ weights_used=payload.weights, session_id=session_id,
210
  )
211
 
212
 
 
214
  async def get_candidate_detail(
215
  jd_id: uuid.UUID,
216
  candidate_id: uuid.UUID,
217
+ session_id: uuid.UUID | None = Query(None),
218
  db: AsyncSession = Depends(get_db),
219
  ):
220
  jd = await _load_jd(jd_id, db)
221
 
222
+ q = select(MatchResult).where(MatchResult.jd_id == jd_id, MatchResult.candidate_id == candidate_id)
223
+ if session_id:
224
+ q = q.where(MatchResult.session_id == session_id)
225
+ mr_result = await db.execute(q)
 
 
226
  mr = mr_result.scalar_one_or_none()
227
  if not mr:
228
+ raise HTTPException(status_code=404, detail="Match result not found")
229
 
230
  cand_result = await db.execute(select(Candidate).where(Candidate.id == candidate_id))
231
  cand = cand_result.scalar_one_or_none()
 
233
  raise HTTPException(status_code=404, detail="Candidate not found")
234
 
235
  if not mr.explanation:
236
+ jd_dict = _build_jd_dict(jd)
 
 
 
 
 
 
 
 
 
237
  cand_dict = {
238
+ "parsed_summary": cand.parsed_summary, "parsed_skills": cand.parsed_skills,
 
239
  "years_of_experience": cand.years_of_experience,
240
  "programming_languages": cand.programming_languages or [],
241
  "backend_frameworks": cand.backend_frameworks or [],
 
249
  await db.commit()
250
 
251
  return CandidateDetailResponse(
252
+ jd_id=jd_id, candidate_id=candidate_id, rank=mr.rank,
 
 
253
  final_score=mr.final_score,
254
  component_scores=ComponentScores(**(mr.component_scores or {})),
255
  gaps=[GapItem(**g) for g in (mr.gaps or [])],
256
  explanation=mr.explanation,
257
  candidate={
258
+ "name": cand.name, "email": cand.email, "role_type": cand.role_type,
259
+ "engineer_type": cand.engineer_type, "years_of_experience": cand.years_of_experience,
260
+ "most_recent_company": cand.most_recent_company, "parsed_summary": cand.parsed_summary,
261
+ "parsed_skills": cand.parsed_skills, "parsed_work_experience": cand.parsed_work_experience or [],
 
 
 
 
 
262
  "programming_languages": cand.programming_languages or [],
263
  "backend_frameworks": cand.backend_frameworks or [],
264
+ "gen_ai_experience": cand.gen_ai_experience, "growth_velocity": cand.growth_velocity,
265
+ "looking_for": cand.looking_for, "open_to_working_at": cand.open_to_working_at,
 
 
266
  "is_actively_or_passively_looking": cand.is_actively_or_passively_looking,
267
  "most_recent_company_is_funded": cand.most_recent_company_is_funded,
268
  "most_recent_company_is_product_company": cand.most_recent_company_is_product_company,
269
  "most_recent_company_total_funding": cand.most_recent_company_total_funding,
270
  },
271
  jd={
272
+ "title": jd.title, "required_skills": jd.required_skills or [],
273
+ "min_yoe": jd.min_yoe, "role_type": jd.role_type,
274
+ "engineer_type": jd.engineer_type, "location": jd.location,
 
 
 
275
  },
276
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/test_db.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from src.database import engine
3
+
4
+ async def test():
5
+ async with engine.begin() as conn:
6
+ await conn.run_sync(lambda *args: print('DB Connection OK'))
7
+
8
+ asyncio.run(test())