# main.py — last commit 9430f1e "Update main.py" by shreyankisiri (verified).
import json
import os

import chromadb
import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import InferenceClient
from scipy.spatial.distance import cosine
# FastAPI application exposing the /health and /search endpoints below.
app = FastAPI()

# Hugging Face Inference API token, taken from the HF_API_KEY environment
# variable (None when the variable is unset).
HF_API_TOKEN = os.environ.get("HF_API_KEY")

# Remote embedding client for the all-MiniLM-L6-v2 sentence-transformers model.
# NOTE(review): timeout=1000000 seconds (~11.5 days) is effectively no timeout;
# a hung request pins its caller indefinitely — consider a realistic bound.
client = InferenceClient(
    token=HF_API_TOKEN,
    model="sentence-transformers/all-MiniLM-L6-v2",
    timeout=1000000,
)

# On-disk Chroma store under ./chroma_db; the "courses" collection is created
# the first time the service starts and reused afterwards.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="courses")
def get_embedding(text):
    """Embed *text* with the remote sentence-transformers model.

    Sends a feature-extraction request through the module-level ``client`` and
    normalises whatever the client hands back into one flat vector.

    Args:
        text: The string to embed.

    Returns:
        The embedding as a flat, non-empty list of numbers.

    Raises:
        ValueError: If the response cannot be interpreted as a single
            embedding vector (undecodable body, non-list payload, empty list,
            or a nesting deeper than one batch row).
    """
    # NOTE(review): InferenceClient.post is deprecated in recent huggingface_hub
    # releases in favour of client.feature_extraction(text); confirm against
    # the pinned version before upgrading.
    response = client.post(json={"inputs": text}, task="feature-extraction")
    return _to_embedding_vector(response)


def _to_embedding_vector(response):
    """Normalise a raw feature-extraction response into a flat list."""
    vector = response
    if isinstance(vector, (bytes, bytearray)):
        # The raw HTTP body is JSON. json.loads parses it exactly, including
        # JSON-only literals (true/false/null) that ast.literal_eval rejects.
        try:
            vector = json.loads(vector.decode("utf-8"))
        except ValueError as err:  # UnicodeDecodeError or JSONDecodeError
            raise ValueError(
                f"Unexpected embedding format: {response!r:.200}"
            ) from err
    if hasattr(vector, "tolist"):
        # NumPy arrays: convert to plain Python lists.
        vector = vector.tolist()
    if isinstance(vector, list) and vector and isinstance(vector[0], list):
        # A batch of one ([[...]]): keep the first row, regardless of whether
        # it arrived as bytes, an ndarray or a list, so callers always get 1-D.
        vector = vector[0]
    if not isinstance(vector, list) or not vector or isinstance(vector[0], list):
        # Fail here with a clear message instead of deep inside the
        # similarity computation. The repr is truncated to keep logs readable.
        raise ValueError(f"Unexpected embedding format: {response!r:.200}")
    return vector
def find_similar_courses(query_text, top_k=15):
    """Return metadata of the stored courses most similar to *query_text*.

    Every embedding in the ``courses`` collection is scored against the query
    embedding by cosine similarity. The metadata of the best ``top_k``
    matches is returned, most similar first.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of courses to return (default 15).

    Returns:
        Up to ``top_k`` metadata entries, exactly as Chroma stores them,
        ordered by descending similarity. The list is empty when the
        collection is empty or ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []
    query = np.asarray(get_embedding(query_text), dtype=float)

    # NOTE(review): this reads the whole collection on every request; for a
    # large collection, collection.query(...) with a cosine-space index would
    # scale better.
    results = collection.get(include=["embeddings", "metadatas"])
    courses = results["metadatas"]
    embeddings = results["embeddings"]
    # Newer chromadb versions return an ndarray here, so test the length
    # rather than truthiness (which is ambiguous for arrays).
    if embeddings is None or len(embeddings) == 0:
        return []
    stored = np.asarray(embeddings, dtype=float)

    # Vectorised cosine similarity: dot(q, e) / (|q| * |e|) for every row e.
    norms = np.linalg.norm(stored, axis=1) * np.linalg.norm(query)
    with np.errstate(divide="ignore", invalid="ignore"):
        similarities = (stored @ query) / norms
    # Zero-length vectors give NaN. np.argsort places NaN last, which the
    # descending reversal below would turn into the *top* hits, so rank them
    # below every real score instead.
    similarities = np.where(np.isfinite(similarities), similarities, -np.inf)

    top_indices = np.argsort(similarities)[::-1][:top_k]
    return [courses[i] for i in top_indices]
@app.get("/health")
async def health():
    """Liveness probe: always reports that the service is up."""
    payload = dict(status="OK")
    return payload
@app.post("/search")
async def search(query: str):  # `query` is read from the URL query string
    """Return the ``course_id`` of each course most similar to *query*.

    Args:
        query: Free-text search query, passed as a query parameter.

    Returns:
        A list of course ids ordered by descending similarity, or
        ``{"error": "Query parameter is required."}`` when the query is blank.
    """
    # A whitespace-only query has nothing meaningful to embed; treat it like "".
    if not (query and query.strip()):
        return {"error": "Query parameter is required."}
    # find_similar_courses does blocking network and disk I/O (the inference
    # API call and the Chroma read). Calling it directly inside an async
    # handler would stall the event loop, and with it /health and every other
    # request, so run it in a worker thread instead.
    top_courses = await run_in_threadpool(find_similar_courses, query)
    # Chroma yields None for records stored without metadata. Skip those, and
    # any entry lacking a course_id, rather than failing the whole request.
    return [
        course["course_id"]
        for course in top_courses
        if course and "course_id" in course
    ]
if __name__ == "__main__":
    # uvicorn is imported only here, so importing this module does not import
    # uvicorn unless the file is executed as a script.
    import uvicorn

    # Listen on all interfaces on port 7860, the port Hugging Face requires.
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)