# Author: Aaron Ploetz
# Commit message: fixing gradio
# Commit: 4a1f813
import json
import os
import threading
from typing import List, Dict, Any

import gradio
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from sentence_transformers import SentenceTransformer
# Embedding model names this service offers; the /embed endpoint rejects
# any "model" value that is not in this list.
MODELS = [
    "ibm-granite/granite-embedding-30m-english",
    "ibm-granite/granite-embedding-278m-multilingual",
]

# Module-level model-cache globals, both empty until a model is loaded.
current_model = None
model = None

# The FastAPI application the REST endpoints below are registered on.
app = FastAPI()
# Serializes the cache check and the model swap in load_model(); callbacks may
# run on worker threads (Gradio runs synchronous handlers off the event loop).
_model_lock = threading.Lock()


def load_model(model_name: str):
    """Return a SentenceTransformer for *model_name*, reusing the cached one.

    Exactly one model is cached at a time in two module globals:
    ``current_model`` holds the cached model's *name* and ``model`` holds the
    loaded instance. Requesting a different name loads that model and
    replaces the cache. If loading fails, the existing cache is left intact.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        ValueError: if ``SentenceTransformer`` fails to load the model; the
            underlying exception is chained as ``__cause__``.
    """
    global current_model, model
    with _model_lock:
        # Compare against the cached *name*. Previously the instance itself was
        # compared to the name string, so the cache never hit and every call
        # reloaded the model.
        if model is not None and current_model == model_name:
            return model
        try:
            loaded = SentenceTransformer(model_name)
        except Exception as ex:
            raise ValueError(f"Failed to load model '{model_name}': {str(ex)}") from ex
        # Only update the cache once loading has succeeded.
        model = loaded
        current_model = model_name
        # Return the local so a later swap by another thread cannot change
        # what this caller receives.
        return loaded
def embed(document: str, model_name: str):
    """Encode *document* with the model named *model_name*.

    Used both by the REST endpoint and directly as the Gradio callback.

    Args:
        document: Text to encode.
        model_name: Model identifier; an empty/falsy value means "no model
            selected".

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for *document*, or
        ``None`` when *model_name* is empty.

    Raises:
        ValueError: if the model cannot be loaded (raised by ``load_model``)
            or if encoding fails.
    """
    if not model_name:
        return None
    # load_model already raises ValueError("Failed to load model ...").
    # Let it propagate instead of wrapping it a second time, which previously
    # produced a doubled "Failed to load model" message.
    encoder = load_model(model_name)
    try:
        return encoder.encode(document)
    except Exception as ex:
        # Report encode failures as such, not as load failures.
        raise ValueError(f"Failed to embed text with model '{model_name}': {ex}") from ex
@app.get("/models")
async def get_models():
    """List the embedding models this service accepts.

    Returns:
        A JSON response of the form ``{"models": [<model name>, ...]}``.
    """
    payload = {"models": MODELS}
    return JSONResponse(content=payload)
@app.post("/embed")
async def generate_embedding(data: Dict[str, Any]):
    """Embed ``data["text"]`` using the allowlisted model ``data["model"]``.

    Request body: a JSON object with ``"text"`` (text to embed) and
    ``"model"`` (one of MODELS).

    Returns:
        200 with ``{"embedding": [...], "dim": int, "model": str}`` on success;
        400 with ``{"error": ...}`` when the text is missing/empty or the model
        is not in MODELS; 500 with ``{"error": ...}`` if loading or encoding
        raises.
    """
    try:
        text = data.get("text", "")
        model_name = data.get("model", "")
        if not text:
            return JSONResponse(
                status_code=400,
                content={"error": "No text provided"}
            )
        if model_name not in MODELS:
            message = f"Only IBM Granite embedding models can be used: {MODELS}"
            return JSONResponse(
                status_code=400,
                content={"error": message}
            )
        # Model loading (possibly a download) and encoding are blocking calls.
        # Running them inline in this async handler would stall the event loop,
        # and every other request it serves, until they finish. MODELS contains
        # no empty names, so model_name is non-empty here and embed() returns
        # the encoded vector rather than None.
        vector_embedding = await run_in_threadpool(embed, text, model_name)
        return JSONResponse(
            content={
                "embedding": vector_embedding.tolist(),
                "dim": len(vector_embedding),
                "model": model_name,
            }
        )
    except Exception as ex:
        # Top-level request boundary: surface any failure to the client as a 500.
        return JSONResponse(
            status_code=500,
            content={"error": str(ex)}
        )
with gradio.Blocks(title="Aaron's Granite Text Embeddings service") as gradio_app:
    gradio.Markdown("Generate embeddings for your text using the IBM Granite embedding models.")

    # Model selector dropdown (allows custom input). The first model in MODELS
    # is pre-selected because embed() returns None for an empty model name.
    # An empty default would make an untouched form produce no embedding.
    # NOTE(review): allow_custom_value=True lets UI users request any model
    # name, bypassing the MODELS allowlist that the /embed endpoint enforces.
    # Confirm this is intended.
    model_dropdown = gradio.Dropdown(
        choices=MODELS,
        value=MODELS[0],
        label="Select Embedding Model",
        info="Choose any predefined model name",
        allow_custom_value=True,
    )

    # Input text box.
    text_input = gradio.Textbox(label="Enter text to embed", placeholder="Type or paste your text here...")

    # Output component that displays the value returned by embed().
    output = gradio.JSON(label="Text Embedding", elem_classes=["json-holder"])

    # Submit button; its click handler is exposed under the API name "predict".
    submit_btn = gradio.Button("Generate Embedding", variant="primary")

    # Run embed() on both button click and text-box submission.
    submit_btn.click(embed, inputs=[text_input, model_dropdown], outputs=output, api_name="predict")
    text_input.submit(embed, inputs=[text_input, model_dropdown], outputs=output)
if __name__ == "__main__":
    # Attach the Gradio UI to the FastAPI app at "/" so a single server hosts
    # both the REST endpoints and the UI; the result rebinds gradio_app.
    gradio_app = gradio.mount_gradio_app(app, gradio_app, path="/")

    # Serve the combined app with Uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(gradio_app, host="0.0.0.0", port=7860)