# Hugging Face Space (status at capture time: Sleeping)
| from fastapi import FastAPI | |
| import chromadb | |
| import numpy as np | |
| from huggingface_hub import InferenceClient | |
| from scipy.spatial.distance import cosine | |
| import os | |
# FastAPI application instance.
app = FastAPI()

# NOTE(review): the env var read is "HF_API_KEY" but the variable is named
# HF_API_TOKEN -- confirm which name the deployment actually sets.
HF_API_TOKEN = os.getenv("HF_API_KEY")

# Hugging Face Inference API client used to embed text with MiniLM.
# NOTE(review): timeout=1000000 seconds is effectively "no timeout" -- confirm intentional.
client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2", token=HF_API_TOKEN, timeout=1000000)

# Initialize ChromaDB
chroma_client = chromadb.PersistentClient(path="./chroma_db")  # Persistent storage
collection = chroma_client.get_or_create_collection(name="courses")
def get_embedding(text):
    """Embed *text* via the HF Inference API and return a flat list of floats.

    The raw response format varies by client/server version: it may be a
    NumPy array, a (possibly nested) list, or raw JSON bytes.  All of those
    are normalized to a plain Python list.

    Args:
        text: the string to embed.

    Returns:
        The embedding vector as a flat list, or the raw response unchanged
        if it is in none of the recognized formats (preserves the original
        permissive fallback for unknown types).

    Raises:
        ValueError: if a bytes response cannot be decoded/parsed.
    """
    response = client.post(json={"inputs": text}, task="feature-extraction")

    # NumPy array (or anything array-like exposing .tolist()).
    if hasattr(response, "tolist"):
        return response.tolist()

    # Already a list: the API returns [[...]] for a single input, so unwrap
    # one nesting level when present.
    if isinstance(response, list):
        if response and isinstance(response[0], list):
            return response[0]
        return response

    # Raw bytes: the body is JSON, so parse it with json.loads rather than
    # ast.literal_eval, and catch only the parse errors (the original bare
    # except silently swallowed everything).
    if isinstance(response, bytes):
        import json
        try:
            return json.loads(response.decode("utf-8"))
        except (UnicodeDecodeError, ValueError) as exc:
            raise ValueError(f"Unexpected embedding format: {response}") from exc

    # Unknown type: keep the original best-effort behavior and pass it through.
    return response
def find_similar_courses(query_text):
    """Return metadata for the 15 stored courses most similar to *query_text*.

    Similarity is cosine similarity between the query embedding and every
    embedding persisted in the ChromaDB collection.
    """
    query_vec = get_embedding(query_text)

    # Pull every stored embedding together with its course metadata.
    stored = collection.get(include=["embeddings", "metadatas"])
    metadatas = stored["metadatas"]
    embeddings = np.array(stored["embeddings"])

    # Cosine similarity = 1 - cosine distance, one score per stored course.
    scores = []
    for emb in embeddings:
        scores.append(1 - cosine(query_vec, emb))

    # Rank descending by score and keep the top 15.
    ranking = np.argsort(scores)[-15:][::-1]
    return [metadatas[i] for i in ranking]
async def health():
    """Liveness probe: report that the service is up."""
    # NOTE(review): no @app.get decorator is visible here -- it may have been
    # lost in formatting; confirm the route is actually registered.
    payload = {"status": "OK"}
    return payload
async def search(query: str):  # Accept `query` directly as a parameter
    """Return the course_ids of the courses most similar to *query*."""
    # NOTE(review): no @app.get decorator is visible here -- it may have been
    # lost in formatting; confirm the route is actually registered.
    if query:
        matches = find_similar_courses(query)
        return [match["course_id"] for match in matches]
    return {"error": "Query parameter is required."}
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) # Hugging Face requires port 7860 |