File size: 1,689 Bytes
7a884df
 
 
f6f57d2
7a884df
 
 
 
 
f6f57d2
 
 
7a884df
 
f6f57d2
7a884df
 
 
 
 
f6f57d2
 
 
 
 
 
7a884df
 
f6f57d2
 
7a884df
 
 
 
 
 
f6f57d2
 
 
 
7a884df
 
 
 
 
 
 
 
 
 
 
 
f6f57d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a884df
 
 
f6f57d2
7a884df
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# The FastAPI application; the route decorators further down register on it.
app = FastAPI()

# Load the embedding model once, at import time, so both endpoints share a
# single instance. The two banner lines bracket the load in the server log.
_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

print("Loading embedding model...")
model = SentenceTransformer(_MODEL_ID)
print("Embedding model loaded!")

# -------- REQUEST MODELS --------

class EmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings``.

    ``input`` is either a single string or a list of strings; the handler
    treats a bare string as a one-element batch.
    """

    # Union order is part of pydantic's validation behaviour; keep ``str`` first.
    input: str | list[str]


class RouterRequest(BaseModel):
    """Request body for ``POST /v1/router``.

    The router embeds ``query`` and every entry of ``candidates``, scores each
    candidate against the query, and returns the highest-scoring one.
    """

    # Text to be routed.
    query: str
    # Options to choose between; an empty list cannot produce a match.
    candidates: list[str]


# -------- EMBEDDINGS API --------

@app.post("/v1/embeddings")
async def create_embeddings(req: EmbeddingRequest):
    """Embed one or more texts and return them in a list envelope.

    A bare string in ``req.input`` is treated as a one-element batch. Vectors
    are requested with ``normalize_embeddings=True`` and each result is tagged
    with its position in the input batch.

    Returns:
        dict with ``"object": "list"``, a ``"data"`` list of
        ``{"embedding": [...], "index": i}`` entries in input order, and the
        fixed model label ``"auric-embedding"``.
    """
    batch = [req.input] if isinstance(req.input, str) else req.input

    # NOTE: model.encode is a synchronous call inside an ``async def`` handler,
    # so it runs on the event-loop thread and other requests wait while it works.
    vectors = model.encode(batch, normalize_embeddings=True).tolist()

    return {
        "object": "list",
        "data": [
            {"embedding": vector, "index": position}
            for position, vector in enumerate(vectors)
        ],
        "model": "auric-embedding",
    }


# -------- SEMANTIC ROUTER API --------

@app.post("/v1/router")
async def semantic_router(req: RouterRequest):
    """Return the candidate most similar to ``req.query``.

    Query and candidates are embedded with ``normalize_embeddings=True``, so the
    dot product below is their cosine similarity. On ties ``np.argmax`` picks
    the first (lowest-index) candidate.

    Returns:
        dict with the echoed ``"query"``, the winning ``"best_match"`` string,
        and its similarity ``"score"`` as a Python float.

    Raises:
        HTTPException: 422 when ``req.candidates`` is empty. Without this
            guard the empty batch makes ``np.dot``/``np.argmax`` raise and the
            client sees an opaque 500.
    """
    if not req.candidates:
        raise HTTPException(
            status_code=422,
            detail="candidates must contain at least one string",
        )

    # NOTE: model.encode is synchronous and runs on the event-loop thread.
    query_embedding = model.encode(
        req.query,
        normalize_embeddings=True
    )

    candidate_embeddings = model.encode(
        req.candidates,
        normalize_embeddings=True
    )

    # One score per candidate: a str input to encode yields a single vector and
    # a list yields one row per entry (per sentence-transformers docs), so this
    # is (n_candidates, dim) @ (dim,) -> (n_candidates,).
    scores = np.dot(candidate_embeddings, query_embedding)

    best_index = int(np.argmax(scores))

    return {
        "query": req.query,
        "best_match": req.candidates[best_index],
        # Cast from a NumPy scalar so the response is JSON-serialisable.
        "score": float(scores[best_index])
    }


# -------- SERVER START --------

if __name__ == "__main__":
    # Direct execution only: serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")