// EDU_Recommender — server/src/embeddings.js
// Commit 5bd3663 (Omarrran): "Add EduRecommender HuggingFace Spaces app"
/**
* Embedding-based retrieval via HuggingFace Inference API.
*
* Uses sentence-transformers/all-MiniLM-L6-v2 through the HF
* sentence-similarity pipeline to compute cosine similarities
* between the user profile and content items directly.
*
* Falls back to a local term-frequency (bag-of-words) cosine similarity
* when HF_TOKEN is not set or the API call fails.
*/
// Hugging Face Inference configuration. When HF_TOKEN is absent from the
// environment it resolves to "", which disables the remote call.
const EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2";
const HF_INFERENCE_URL = `https://router.huggingface.co/hf-inference/models/${EMBEDDING_MODEL}`;
const { HF_TOKEN = "" } = process.env;
// ---------------------------------------------------------------------------
// Text representation helpers (matches Python version exactly)
// ---------------------------------------------------------------------------
/**
 * Build the text used to score a content item against a user profile.
 *
 * Format: "<title>. <description> Tags: <tag1>, <tag2>, ..." (mirrors the
 * Python implementation's format). A missing `tags` field is treated as an
 * empty list, the same way `userText` treats a missing `interest_tags`, so a
 * single incomplete item cannot make the whole retrieval throw.
 *
 * @param {object} item - Content item with `title`, `description` and
 *   (optionally) a `tags` array.
 * @returns {string}
 */
export function contentText(item) {
  const tags = item.tags || [];
  return `${item.title}. ${item.description} Tags: ${tags.join(", ")}`;
}
/**
 * Build the text that represents a user profile for similarity scoring.
 *
 * Produces "Goal: <goal>. Interests: <a>, <b>. Learning style: <style>.";
 * a missing `interest_tags` list yields an empty interests section.
 *
 * @param {object} profile - Profile with `goal`, `interest_tags` and
 *   `learning_style` fields.
 * @returns {string}
 */
export function userText(profile) {
  const interests = (profile.interest_tags || []).join(", ");
  const parts = [
    `Goal: ${profile.goal}. `,
    `Interests: ${interests}. `,
    `Learning style: ${profile.learning_style}.`,
  ];
  return parts.join("");
}
// ---------------------------------------------------------------------------
// HuggingFace Inference API — sentence-similarity pipeline
// ---------------------------------------------------------------------------
/**
* Compute cosine similarities between a source sentence and a list of
* target sentences using the HF sentence-similarity pipeline.
*
* Returns an array of floats (one per target sentence).
*/
/**
 * Score each target sentence against `sourceSentence` using the Hugging Face
 * sentence-similarity pipeline for EMBEDDING_MODEL.
 *
 * @param {string} sourceSentence - Text every target is compared to.
 * @param {string[]} sentences - Target texts, scored in order.
 * @param {{ timeoutMs?: number }} [options]
 * @param {number} [options.timeoutMs=15000] - Abort the request after this
 *   many milliseconds so a stalled API call cannot block the caller (and its
 *   local fallback) indefinitely.
 * @returns {Promise<number[]>} One score per target sentence, in input order.
 * @throws {Error} On a non-2xx response, a network failure or timeout, or a
 *   response that is not an array of finite numbers with exactly one entry
 *   per sentence.
 */
async function hfSentenceSimilarity(sourceSentence, sentences, { timeoutMs = 15000 } = {}) {
  // Nothing to score: skip the network round-trip entirely.
  if (sentences.length === 0) return [];

  const resp = await fetch(HF_INFERENCE_URL, {
    method: "POST",
    headers: {
      Authorization: `Bearer ${HF_TOKEN}`,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      inputs: {
        source_sentence: sourceSentence,
        sentences,
      },
    }),
    signal: AbortSignal.timeout(timeoutMs),
  });
  if (!resp.ok) {
    const body = await resp.text();
    throw new Error(`HF Similarity API ${resp.status}: ${body}`);
  }
  const data = await resp.json();
  // Callers pair scores with items by index, so a short, long or non-numeric
  // array would silently misattribute scores. Reject it so the caller can
  // fall back instead.
  const isValid =
    Array.isArray(data) &&
    data.length === sentences.length &&
    data.every((score) => typeof score === "number" && Number.isFinite(score));
  if (!isValid) {
    throw new Error("Unexpected similarity response shape.");
  }
  return data;
}
// ---------------------------------------------------------------------------
// Local fallback: TF bag-of-words cosine similarity
// ---------------------------------------------------------------------------
/**
 * Split text into lowercase word tokens for the bag-of-words fallback.
 *
 * Letters (including combining marks) and digits from any script are kept, so
 * "café" stays one token instead of being cut at the accent. Hyphens inside
 * words are kept too ("machine-learning"). Every other character acts as a
 * separator. Tokens containing no letter or digit, such as a stand-alone
 * " - " dash, are dropped because they would otherwise match across unrelated
 * texts and inflate their similarity.
 *
 * @param {string} text
 * @returns {string[]} Tokens in order of appearance (duplicates kept).
 */
function tokenise(text) {
  return text
    .toLowerCase()
    .replace(/[^\p{L}\p{M}\p{N}\s-]/gu, " ")
    .split(/\s+/)
    .filter((token) => /[\p{L}\p{N}]/u.test(token));
}
/**
 * Term-frequency cosine similarity between `sourceText` and each target text.
 *
 * Each text is tokenised and represented as a sparse map of token -> count.
 * The dot product only involves tokens the two texts share, so no global
 * vocabulary or dense per-document vectors are built. Cost is linear in the
 * total number of tokens instead of (documents x vocabulary size). Counts
 * are integers and arithmetic is float64.
 *
 * @param {string} sourceText - Text every target is compared to.
 * @param {string[]} targetTexts - Texts to score.
 * @returns {number[]} One similarity per target, in input order, in [0, 1]
 *   up to floating-point rounding; 0 when either side has no tokens.
 */
function localCosineSimilarities(sourceText, targetTexts) {
  const termCounts = (text) => {
    const counts = new Map();
    for (const token of tokenise(text)) {
      counts.set(token, (counts.get(token) || 0) + 1);
    }
    return counts;
  };
  const norm = (counts) => {
    let sumSq = 0;
    for (const c of counts.values()) sumSq += c * c;
    return Math.sqrt(sumSq);
  };

  const source = termCounts(sourceText);
  const sourceNorm = norm(source);

  return targetTexts.map((text) => {
    const target = termCounts(text);
    const denom = sourceNorm * norm(target);
    if (denom === 0) return 0;

    // Walk the smaller map and look each token up in the larger one.
    const [small, large] = source.size <= target.size ? [source, target] : [target, source];
    let dot = 0;
    for (const [token, count] of small) {
      const other = large.get(token);
      if (other !== undefined) dot += count * other;
    }
    return dot / denom;
  });
}
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
/**
* Retrieve top-K content items most similar to the user profile.
* Uses HF Inference API for real semantic similarity when available.
*
* @param {object} profile - User profile object
* @param {object[]} items - Content items array
* @param {number} k - Number of results to return
* @param {number[]} excludeIds - IDs to exclude (already viewed)
* @returns {{ candidates: object[], similarities: number[], method: string }}
*/
/**
 * Retrieve the top-K content items most similar to the user profile.
 *
 * Scores come from the HF sentence-similarity API when HF_TOKEN is set, there
 * is at least one eligible item, and the call succeeds. Otherwise they come
 * from the local bag-of-words fallback.
 *
 * @param {object} profile - User profile object (fields read by `userText`).
 * @param {object[]} items - Content items; each needs an `id` plus the fields
 *   read by `contentText`.
 * @param {number} [k=5] - Maximum number of results; values <= 0 yield none.
 * @param {Array} [excludeIds=[]] - IDs to exclude (e.g. already viewed).
 * @returns {Promise<{ candidates: object[], similarities: number[], method: string }>}
 *   `candidates` are shallow copies of the chosen items with an added
 *   `_simScore`, sorted by descending score. `similarities` lists the same
 *   scores in the same order. `method` is "all-MiniLM-L6-v2" or
 *   "local-fallback".
 */
export async function retrieveTopK(profile, items, k = 5, excludeIds = []) {
  const excluded = new Set(excludeIds);
  const eligibleItems = items.filter((item) => !excluded.has(item.id));
  const sourceStr = userText(profile);
  const targetStrs = eligibleItems.map((item) => contentText(item));

  let similarities = null;
  let method = "local-fallback";
  // No eligible items means nothing to score, so skip the network call.
  if (HF_TOKEN && targetStrs.length > 0) {
    try {
      similarities = await hfSentenceSimilarity(sourceStr, targetStrs);
      method = "all-MiniLM-L6-v2";
    } catch (err) {
      console.warn("[embeddings] HF API failed:", err.message, "— using local fallback.");
    }
  }
  if (!similarities) {
    similarities = localCosineSimilarities(sourceStr, targetStrs);
  }

  // A negative k would make slice() drop items from the end instead of
  // returning none, so clamp it at zero.
  const limit = Math.max(0, k);
  const scored = eligibleItems
    .map((item, idx) => ({ item, score: similarities[idx] || 0 }))
    .sort((a, b) => b.score - a.score)
    .slice(0, limit);

  return {
    candidates: scored.map((s) => ({ ...s.item, _simScore: s.score })),
    similarities: scored.map((s) => s.score),
    method,
  };
}