"""Course similarity search service.

Exposes a small FastAPI app that embeds a query string through the Hugging
Face inference client and ranks the embeddings stored in a ChromaDB
collection by cosine similarity to that query.
"""

import json
import os

import chromadb
import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import InferenceClient
from scipy.spatial.distance import cosine

app = FastAPI()

HF_API_TOKEN = os.getenv("HF_API_KEY")
# NOTE(review): timeout=1000000 is presumably in seconds, i.e. effectively
# unbounded -- confirm this is intended.
client = InferenceClient(
    model="sentence-transformers/all-MiniLM-L6-v2",
    token=HF_API_TOKEN,
    timeout=1000000,
)

# Initialize ChromaDB with persistent on-disk storage.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="courses")


def get_embedding(text: str) -> list:
    """Return the embedding of *text* as a flat list.

    The raw client response is normalised before being returned:

    * ``bytes`` payloads are decoded as UTF-8 JSON,
    * objects exposing ``tolist()`` (e.g. NumPy arrays) are converted,
    * a list of lists is reduced to its first inner list.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a list (an empty list is passed through unchanged).

    Raises:
        ValueError: If the response cannot be decoded or is not list-like.
    """
    raw = client.post(json={"inputs": text}, task="feature-extraction")

    if isinstance(raw, (bytes, bytearray)):
        try:
            embedding = json.loads(raw.decode("utf-8"))
        except ValueError as err:
            # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
            raise ValueError(f"Unexpected embedding format: {raw}") from err
    elif hasattr(raw, "tolist"):
        embedding = raw.tolist()
    else:
        embedding = raw

    if not isinstance(embedding, list):
        raise ValueError(f"Unexpected embedding format: {raw}")

    # Apply the same nested-list handling regardless of how the response
    # arrived, so the caller always gets a single vector.
    if embedding and isinstance(embedding[0], list):
        return embedding[0]
    return embedding


def find_similar_courses(query_text: str, top_k: int = 15) -> list:
    """Return the metadata of the stored courses most similar to a query.

    Args:
        query_text: Free-text query to embed and compare.
        top_k: Maximum number of courses to return (default 15).

    Returns:
        Course metadata entries ordered from most to least similar. Empty
        when the collection holds no embeddings or ``top_k`` is not positive.
    """
    if top_k <= 0:
        # ``[-0:]`` would select every element, so handle this explicitly.
        return []

    query_embedding = get_embedding(query_text)

    # Retrieve stored embeddings together with their metadata.
    results = collection.get(include=["embeddings", "metadatas"])
    courses = results["metadatas"]
    stored_embeddings = results["embeddings"]

    # ``len`` is used instead of truthiness because the embeddings may be an
    # array, whose truth value is ambiguous.
    if courses is None or stored_embeddings is None or len(stored_embeddings) == 0:
        return []

    # scipy's ``cosine`` is a distance; similarity is its complement.
    similarities = [
        1 - cosine(query_embedding, emb) for emb in np.asarray(stored_embeddings)
    ]

    # argsort is ascending: take the last ``top_k`` and reverse for best-first.
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [courses[i] for i in top_indices]


@app.get("/health")
async def health():
    """Liveness probe; always reports OK."""
    return {"status": "OK"}


@app.post("/search")
async def search(query: str):
    """Return the ``course_id`` of the courses most similar to *query*.

    Args:
        query: The search text, accepted directly as a parameter.

    Returns:
        A list of course ids, best match first, or an ``{"error": ...}``
        payload when the query is empty or whitespace-only.
    """
    if not query or not query.strip():
        return {"error": "Query parameter is required."}

    # The embedding request and the database read are blocking calls; run
    # them in a worker thread so the event loop stays responsive.
    top_courses = await run_in_threadpool(find_similar_courses, query)

    # Skip entries without a course id instead of failing the whole request.
    return [
        course["course_id"]
        for course in top_courses
        if course and "course_id" in course
    ]


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face requires port 7860