# Hugging Face Space (status when captured: Sleeping)
| import json | |
| import os | |
| import gradio | |
| from fastapi import FastAPI | |
| from fastapi.responses import JSONResponse | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Any | |
# Embedding checkpoints accepted by the JSON endpoint and offered in the UI.
MODELS = [
    "ibm-granite/granite-embedding-30m-english",
    "ibm-granite/granite-embedding-278m-multilingual",
]

# Most recently loaded SentenceTransformer instance; assigned by load_model().
current_model = None
# Not read or written by any code visible in this file.
model = None

# FastAPI application the Gradio UI is mounted onto at start-up.
app = FastAPI()
# Name of the checkpoint currently held in ``current_model``; None until the
# first successful load.
_current_model_name = None


def load_model(model_name: str):
    """Return a SentenceTransformer for *model_name*, reusing the cached one.

    The most recently loaded model is kept in the module-level
    ``current_model``. A repeated request for the same name returns that
    instance instead of loading the checkpoint again.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        ValueError: If the model cannot be loaded.
    """
    global current_model, _current_model_name

    # Compare against the remembered *name*: ``current_model`` holds the model
    # object itself, so comparing it with a string can never match.
    if current_model is not None and _current_model_name == model_name:
        return current_model

    try:
        loaded = SentenceTransformer(model_name)
    except Exception as ex:
        raise ValueError(f"Failed to load model '{model_name}': {str(ex)}") from ex

    # Publish only after a successful load so a failure keeps the previous
    # model and its name consistent with each other.
    current_model = loaded
    _current_model_name = model_name
    return current_model
def embed(document: str, model_name: str):
    """Encode *document* with the requested embedding model.

    Used both by the JSON endpoint handler and as the Gradio event handler.

    Args:
        document: Text to embed.
        model_name: Model identifier handed to ``load_model``. A falsy value
            (empty string or ``None``) means no model was selected.

    Returns:
        Whatever ``model.encode(document)`` returns, or ``None`` when no
        model name was given.

    Raises:
        ValueError: If the model cannot be loaded or encoding fails.
    """
    if not model_name:
        return None

    try:
        encoder = load_model(model_name)
    except ValueError:
        # load_model already reports "Failed to load model ..."; wrapping it
        # again would only duplicate that prefix in the message.
        raise
    except Exception as ex:
        raise ValueError(f"Failed to load model '{model_name}': {str(ex)}") from ex

    try:
        return encoder.encode(document)
    except Exception as ex:
        # Encoding problems are reported as such, not as load failures.
        raise ValueError(
            f"Failed to embed text with model '{model_name}': {str(ex)}"
        ) from ex
async def get_models():
    """Return the supported embedding model names as a JSON response."""
    payload = {"models": MODELS}
    return JSONResponse(content=payload)
async def generate_embedding(data: Dict[str, Any]):
    """Embed the text in *data* and return the vector as JSON.

    Args:
        data: Request payload. ``data["text"]`` is the text to embed and
            ``data["model"]`` must be one of ``MODELS``.

    Returns:
        A ``JSONResponse``:
        * 400 with ``{"error": ...}`` when the text is missing/empty or the
          model is not one of ``MODELS``;
        * 200 with ``{"embedding", "dim", "model"}`` on success;
        * 500 with ``{"error": ...}`` when anything raises.
    """
    try:
        text = data.get("text", "")
        model_name = data.get("model", "")

        if not text:
            return JSONResponse(
                status_code=400,
                content={"error": "No text provided"},
            )

        if model_name not in MODELS:
            message = f"Only IBM Granite embedding models can be used: {MODELS}"
            return JSONResponse(
                status_code=400,
                content={"error": message},
            )

        # model_name is a member of MODELS here, hence non-empty: no further
        # check is needed and every path returns a response.
        vector_embedding = embed(text, model_name)
        return JSONResponse(
            content={
                "embedding": vector_embedding.tolist(),
                # Size of the last axis: identical to len() for a single
                # string, and still the embedding width if "text" is a list
                # of strings (encode then yields one row per text).
                "dim": int(vector_embedding.shape[-1]),
                "model": model_name,
            }
        )
    except Exception as ex:
        # Last-resort boundary: report the failure to the client as a 500.
        return JSONResponse(
            status_code=500,
            content={"error": str(ex)},
        )
# Browser UI: pick a model, enter text, and view the value returned by embed().
with gradio.Blocks(title="Aaron's Granite Text Embeddings service") as gradio_app:
    gradio.Markdown("Generate embeddings for your text using the IBM Granite embedding models.")

    # Model selector. Custom values are allowed, so the UI does not restrict
    # the name to the entries of MODELS.
    model_dropdown = gradio.Dropdown(
        label="Select Embedding Model",
        info="Choose any predefined model name",
        choices=MODELS,
        value="",
        allow_custom_value=True,
    )

    # Free-text input holding the document to embed.
    text_input = gradio.Textbox(
        label="Enter text to embed",
        placeholder="Type or paste your text here...",
    )

    # JSON viewer that displays the handler's result.
    output = gradio.JSON(label="Text Embedding", elem_classes=["json-holder"])

    # Primary action button.
    submit_btn = gradio.Button("Generate Embedding", variant="primary")

    # Both triggers call embed(text, model); only the button click is given
    # an explicit API name.
    submit_btn.click(
        fn=embed,
        inputs=[text_input, model_dropdown],
        outputs=output,
        api_name="predict",
    )
    text_input.submit(
        fn=embed,
        inputs=[text_input, model_dropdown],
        outputs=output,
    )
if __name__ == "__main__":
    # Uvicorn is only needed when the module is executed as a script.
    import uvicorn

    # Attach the Gradio UI to the FastAPI app at the site root. The name
    # gradio_app is rebound to the value returned by mount_gradio_app.
    gradio_app = gradio.mount_gradio_app(app, gradio_app, path="/")

    # Serve on all interfaces, port 7860.
    uvicorn.run(gradio_app, host="0.0.0.0", port=7860)