indhupamula commited on
Commit
818eae0
·
verified ·
1 Parent(s): 78587d6

Upload api.py

Browse files
Files changed (1) hide show
  1. api.py +90 -0
api.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pickle
3
+ import numpy as np
4
+ import faiss
5
+ from fastapi import FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ from sentence_transformers import SentenceTransformer
9
+ import os
10
+ import requests
11
+ from bs4 import BeautifulSoup
12
+
13
+ app = FastAPI(title="SHL Assessment Recommender")
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_methods=["*"],
19
+ allow_headers=["*"]
20
+ )
21
+
22
+ print("Loading models...")
23
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
24
+ faiss_index = faiss.read_index("vector_index.faiss")
25
+
26
+ with open("assessment_data.pkl", "rb") as f:
27
+ stored_assessments = pickle.load(f)
28
+
29
+ print("Models loaded!")
30
+
31
+ class QueryInput(BaseModel):
32
+ query: str
33
+
34
+ def find_similar(query, top_k=10):
35
+ vec = embedding_model.encode([query]).astype("float32")
36
+ _, indices = faiss_index.search(vec, top_k)
37
+ results = []
38
+ for idx in indices[0]:
39
+ if idx < len(stored_assessments):
40
+ results.append(stored_assessments[idx])
41
+ return results
42
+
43
+ def balance(candidates):
44
+ p_types = ["P", "B"]
45
+ k_types = ["K", "A"]
46
+ p_items = [i for i in candidates if any(t in p_types for t in i.get("test_type", []))]
47
+ k_items = [i for i in candidates if any(t in k_types for t in i.get("test_type", []))]
48
+ o_items = [i for i in candidates if i not in p_items and i not in k_items]
49
+ result = []
50
+ result.extend(p_items[:3])
51
+ result.extend(k_items[:3])
52
+ result.extend(o_items[:4])
53
+ remaining = [i for i in candidates if i not in result]
54
+ while len(result) < 10 and remaining:
55
+ result.append(remaining.pop(0))
56
+ return result[:10]
57
+
58
+ @app.get("/health")
59
+ def health():
60
+ return {"status": "healthy"}
61
+
62
+ @app.post("/recommend")
63
+ def recommend(user_input: QueryInput):
64
+ query = user_input.query
65
+ if query.startswith("http"):
66
+ try:
67
+ r = requests.get(query, timeout=10)
68
+ soup = BeautifulSoup(r.content, "html.parser")
69
+ query = soup.get_text(separator=" ", strip=True)[:3000]
70
+ except:
71
+ pass
72
+ candidates = find_similar(query, top_k=20)
73
+ final = balance(candidates)
74
+ response = []
75
+ for item in final:
76
+ response.append({
77
+ "url": item.get("url", ""),
78
+ "name": item.get("name", ""),
79
+ "adaptive_support": item.get("adaptive_support", "No"),
80
+ "description": item.get("description", ""),
81
+ "duration": item.get("duration", 0),
82
+ "remote_support": item.get("remote_support", "No"),
83
+ "test_type": item.get("test_type", [])
84
+ })
85
+ return {"recommended_assessments": response}
86
+
87
+ if __name__ == "__main__":
88
+ import uvicorn
89
+ port = int(os.environ.get("PORT", 8000))
90
+ uvicorn.run(app, host="0.0.0.0", port=port)