File size: 2,397 Bytes
61a7bc6
5cd49b4
 
 
 
6c08c2b
5cd49b4
61a7bc6
5cd49b4
08e0651
61a7bc6
5cd49b4
 
61a7bc6
5cd49b4
 
 
 
 
 
 
 
 
 
 
 
 
61a7bc6
5cd49b4
 
 
 
 
 
 
 
 
61a7bc6
 
5cd49b4
61a7bc6
 
 
 
 
 
 
 
 
9430f1e
61a7bc6
 
 
5cd49b4
 
61a7bc6
 
82d4b08
41cb8ee
 
0f6b22c
 
41cb8ee
82d4b08
9f18df7
41cb8ee
9f18df7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import os

import chromadb
import numpy as np
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from scipy.spatial.distance import cosine

app = FastAPI()

# Hugging Face Inference API credentials; HF_API_KEY must be set in the environment.
HF_API_TOKEN = os.getenv("HF_API_KEY")
# Embedding model used for both indexing and querying.
# NOTE(review): timeout=1000000 seconds is ~11.5 days — presumably meant as
# "effectively no timeout"; confirm intent.
client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2", token=HF_API_TOKEN, timeout=1000000)

# Initialize ChromaDB
chroma_client = chromadb.PersistentClient(path="./chroma_db")  # Persistent storage
collection = chroma_client.get_or_create_collection(name="courses")

def get_embedding(text):
    """Return the embedding vector for *text* from the HF inference API.

    The API may hand back a NumPy array, a flat list of floats, a nested
    list (batch of one), or raw JSON bytes; all of them are normalized
    to a flat Python list of floats.

    Raises:
        ValueError: if the response cannot be interpreted as an embedding.
    """
    response = client.post(json={"inputs": text}, task="feature-extraction")

    # Raw HTTP payload: decode the JSON bytes before inspecting structure.
    # json.loads is the correct parser for an API response; ast.literal_eval
    # only happened to work because JSON lists look like Python lists.
    if isinstance(response, bytes):
        try:
            response = json.loads(response.decode("utf-8"))
        except (UnicodeDecodeError, ValueError) as exc:
            raise ValueError(f"Unexpected embedding format: {response!r}") from exc

    if hasattr(response, 'tolist'):
        return response.tolist()  # NumPy array (or similar) -> plain list
    if isinstance(response, list):
        if response and isinstance(response[0], list):
            return response[0]  # Batch of one: unwrap the single embedding
        return response  # Already a flat list
    # Fail fast instead of silently returning an unusable value.
    raise ValueError(f"Unexpected embedding format: {response}")

def find_similar_courses(query_text):
    """Rank stored courses by cosine similarity to *query_text*.

    Embeds the query, scores it against every embedding stored in the
    ChromaDB collection, and returns the metadata dicts of the 15
    closest courses, most similar first.
    """
    query_vec = get_embedding(query_text)

    # Pull every stored embedding plus its course metadata.
    stored = collection.get(include=["embeddings", "metadatas"])
    metadatas = stored["metadatas"]
    embeddings = np.array(stored["embeddings"])

    # Cosine similarity = 1 - cosine distance.
    scores = [1 - cosine(query_vec, vec) for vec in embeddings]

    # Indices of the 15 highest scores, best first.
    best = np.argsort(scores)[-15:][::-1]
    return [metadatas[idx] for idx in best]

@app.get("/health")
async def health():
    """Liveness probe: report that the service is running."""
    return dict(status="OK")
@app.post("/search")
async def search(query: str):  # Accept `query` directly as a parameter
    """Return the course_ids of the 15 courses most similar to *query*.

    Raises:
        HTTPException: 400 when the query is empty, so HTTP clients can
            rely on the status code instead of parsing a 200-with-error
            body (the previous behavior).
    """
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter is required.")

    top_courses = find_similar_courses(query)
    return [course["course_id"] for course in top_courses]


if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces expects the server to listen on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)