File size: 3,542 Bytes
83ac2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6770e3b
83ac2da
 
6770e3b
83ac2da
 
4a1f813
6770e3b
83ac2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6770e3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import os
import gradio

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any

# Allow-list of embedding models: returned by GET /models and enforced by
# POST /embed, which rejects any model name that is not listed here.
MODELS = [
    "ibm-granite/granite-embedding-30m-english",
    "ibm-granite/granite-embedding-278m-multilingual"
]

# Most recently loaded SentenceTransformer instance; assigned by load_model()
# and None until the first successful load.
current_model = None
# NOTE(review): not referenced anywhere else in the visible file; possibly a
# leftover. Confirm before removing.
model = None
# FastAPI application that carries the REST routes; the Gradio UI is mounted
# onto it when this module is run as a script.
app = FastAPI()

# Name of the model currently held in ``current_model``. Tracked separately
# because ``current_model`` stores the model object, which never compares
# equal to a model-name string.
_loaded_model_name = None


def load_model(model_name: str):
    """Return a SentenceTransformer for ``model_name``, reusing the cached one.

    The most recently loaded model is kept in the module-level
    ``current_model``. A repeat request for the same name returns that
    instance instead of constructing a new one; a request for a different
    name replaces it.

    Args:
        model_name: Identifier passed to ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        ValueError: If constructing the model fails for any reason. The
            original exception is attached as ``__cause__``.
    """
    global current_model, _loaded_model_name

    # Cache hit: compare names, not the model object against a string.
    if current_model is not None and _loaded_model_name == model_name:
        return current_model

    try:
        loaded = SentenceTransformer(model_name)
    except Exception as ex:
        raise ValueError(f"Failed to load model '{model_name}': {str(ex)}") from ex

    # Only update the cache once loading has succeeded, so a failed load
    # leaves the previously cached model usable.
    current_model = loaded
    _loaded_model_name = model_name
    return current_model

def embed(document: str, model_name: str):
    """Encode ``document`` with the model named ``model_name``.

    Used both by the POST /embed route and as the Gradio event handler.

    Args:
        document: Text to embed; passed unchanged to the model's ``encode``.
        model_name: Name of the model to load via ``load_model``. A falsy
            value (``""`` or ``None``) means "no model selected".

    Returns:
        Whatever ``encode`` returns for ``document``, or ``None`` when
        ``model_name`` is falsy.

    Raises:
        ValueError: If the model cannot be loaded (raised by ``load_model``
            with its own message) or if encoding the document fails.
    """
    if not model_name:
        return None

    # load_model already reports failures as ValueError with a
    # "Failed to load model" message; re-wrapping it here would only
    # duplicate that prefix.
    encoder = load_model(model_name)

    try:
        return encoder.encode(document)
    except Exception as ex:
        # Encoding errors are distinct from load errors; label them as such.
        raise ValueError(
            f"Failed to embed document with model '{model_name}': {ex}"
        ) from ex

@app.get("/models")
async def get_models():
    """Return the list of supported model names as ``{"models": [...]}``."""
    payload = {"models": MODELS}
    return JSONResponse(content=payload)

@app.post("/embed")
async def generate_embedding(data: Dict[str, Any]):
    """Generate an embedding for the text in the request body.

    Args:
        data: JSON object with a ``"text"`` entry (the content to embed) and
            a ``"model"`` entry (one of the names in ``MODELS``).

    Returns:
        A ``JSONResponse``:
        * 200 with ``embedding``, ``dim`` and ``model`` on success;
        * 400 with ``error`` when the text is missing/empty or the model is
          not one of ``MODELS``;
        * 500 with ``error`` when loading the model or encoding fails.
    """
    text = data.get("text", "")
    model_name = data.get("model", "")

    if not text:
        return JSONResponse(
            status_code=400,
            content={"error": "No text provided"},
        )

    # The membership test also rejects a missing or empty model name, so no
    # separate truthiness check is needed afterwards.
    if model_name not in MODELS:
        message = f"Only IBM Granite embedding models can be used: {MODELS}"
        return JSONResponse(
            status_code=400,
            content={"error": message},
        )

    try:
        vector_embedding = embed(text, model_name)
        return JSONResponse(
            content={
                "embedding": vector_embedding.tolist(),
                "dim": len(vector_embedding),
                "model": model_name,
            }
        )
    except Exception as ex:
        # Boundary handler: report any load/encode/serialisation failure as
        # a 500 instead of letting it escape the route.
        return JSONResponse(
            status_code=500,
            content={"error": str(ex)},
        )

# Build the Gradio UI: a model dropdown and a text box are passed to embed(),
# and its return value is shown in a JSON component.
with gradio.Blocks(title="Aaron's Granite Text Embeddings service") as gradio_app:
    gradio.Markdown("Generate embeddings for your text using the IBM Granite embedding models.")
    
    # Model selector dropdown (allows custom input)
    # NOTE(review): because allow_custom_value=True, the name handed to
    # embed() is not checked against MODELS here, unlike POST /embed.
    # Confirm that loading arbitrary model names from the UI is intended.
    # NOTE(review): the initial value is ""; embed() returns None for a falsy
    # model name, so presumably nothing is produced until a model is chosen.
    model_dropdown = gradio.Dropdown(
        choices=MODELS,
        value="",
        label="Select Embedding Model",
        info="Choose any predefined model name",
        allow_custom_value=True
    )
    
    # Create an input text box
    text_input = gradio.Textbox(label="Enter text to embed", placeholder="Type or paste your text here...")

    # Create an output component to display the embedding
    output = gradio.JSON(label="Text Embedding", elem_classes=["json-holder"])
    
    # Add a submit button with API name
    submit_btn = gradio.Button("Generate Embedding", variant="primary")

    # Handle both button click and text submission
    # Input order matches embed(document, model_name): text first, model second.
    submit_btn.click(embed, inputs=[text_input, model_dropdown], outputs=output, api_name="predict")
    text_input.submit(embed, inputs=[text_input, model_dropdown], outputs=output)
    
if __name__ == "__main__":
    # Uvicorn is only needed when the module is executed directly.
    import uvicorn

    # Attach the Gradio UI to the FastAPI app at the root path; the name
    # gradio_app is rebound to the value returned by mount_gradio_app.
    gradio_app = gradio.mount_gradio_app(app, gradio_app, path="/")

    # Serve the combined application on all interfaces, port 7860.
    uvicorn.run(gradio_app, host="0.0.0.0", port=7860)