File size: 1,898 Bytes
5bb3151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from gradio_client import Client
import uuid

# Point to your deployed Gradio app
client = Client("IotaCluster/embedding-model")


def get_dense_embedding(text: str):
    """Calls the /embed_dense endpoint (MiniLM)."""
    return _call_api(text, api_name="/embed_dense")


def get_sparse_embedding(text: str):
    """Calls the /embed_sparse endpoint (BM25)."""
    return _call_api(text, api_name="/embed_sparse")


def get_late_embedding(text: str):
    """Calls the /embed_colbert endpoint (ColBERT late-interaction)."""
    return _call_api(text, api_name="/embed_colbert")


def _call_api(text: str, api_name: str):
    try:
        result = client.predict(
            text=text,
            api_name=api_name
        )
        # Normalize response: return the first value in the dict or the list itself
        if isinstance(result, dict):
            return next(iter(result.values()))
        elif isinstance(result, list):
            return result
        else:
            print(f"Unexpected response from {api_name!r}:", result)
            return None
    except Exception as e:
        print(f"API request to {api_name!r} failed:", e)
        return None

def to_valid_qdrant_id(id_val):
    """Ensures the ID is a valid UUID (for Qdrant)."""
    try:
        return str(uuid.UUID(str(id_val)))
    except Exception:
        return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(id_val)))

if __name__ == "__main__":
    text = "Hello!!"

    # dense = get_dense_embedding(text)
    # print("Dense embedding:", dense)
    # print("Length:", len(dense) if dense else None)

    # sparse = get_sparse_embedding(text)
    # print("\nSparse term-weights:", sparse)

    late = get_late_embedding(text)
    print("\nLate-interaction (ColBERT) embeddings:", late)
    print("Tokens count:", len(late) if isinstance(late, list) else None)