# Hugging Face Space: gnumanth/translategemma-4b-it
# Commit 348cd23: Add health check endpoint
import os
import torch
import spaces
import gradio as gr
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login
# Login with HF token if available so the Hub client can fetch the model
# weights (presumably needed for gated-model access — verify on the model card).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Language codes supported by the translator, mapped to display names.
# Codes feed the MCP tool schema and the "code (Name)" dropdown labels below.
# NOTE(review): 54 codes are listed here, while UI text elsewhere in this
# file says "55 languages" — confirm the intended set against the model card.
LANGUAGES = {
    "en": "English", "de": "German", "fr": "French", "es": "Spanish",
    "it": "Italian", "pt": "Portuguese", "nl": "Dutch", "pl": "Polish",
    "cs": "Czech", "ru": "Russian", "uk": "Ukrainian", "zh": "Chinese",
    "ja": "Japanese", "ko": "Korean", "ar": "Arabic", "hi": "Hindi",
    "bn": "Bengali", "tr": "Turkish", "vi": "Vietnamese", "th": "Thai",
    "id": "Indonesian", "ms": "Malay", "sv": "Swedish", "no": "Norwegian",
    "da": "Danish", "fi": "Finnish", "el": "Greek", "he": "Hebrew",
    "ro": "Romanian", "hu": "Hungarian", "bg": "Bulgarian", "hr": "Croatian",
    "sk": "Slovak", "sl": "Slovenian", "sr": "Serbian", "lt": "Lithuanian",
    "lv": "Latvian", "et": "Estonian", "sw": "Swahili", "ta": "Tamil",
    "te": "Telugu", "mr": "Marathi", "gu": "Gujarati", "kn": "Kannada",
    "ml": "Malayalam", "pa": "Punjabi", "ur": "Urdu", "fa": "Persian",
    "fil": "Filipino", "ca": "Catalan", "gl": "Galician", "eu": "Basque",
    "cy": "Welsh", "ga": "Irish",
}
# Hugging Face repository id of the translation checkpoint.
model_id = "google/translategemma-4b-it"
# Load processor at startup (lightweight)
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)
print("Processor loaded!")
# Model will be loaded on first GPU call — load_model() lazily populates
# this module-level cache so startup stays fast.
model = None
def load_model():
    """Return the translation model, loading it lazily on first call.

    The loaded model is cached in the module-level ``model`` global so
    subsequent calls are free; it is placed on CUDA when available,
    otherwise on CPU, in bfloat16 and eval mode.
    """
    global model
    if model is not None:
        return model
    print("Loading model...")
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        AutoModelForImageTextToText
        .from_pretrained(model_id, torch_dtype=torch.bfloat16)
        .to(target)
        .eval()
    )
    print(f"Model loaded on {target.upper()}!")
    return model
@spaces.GPU(duration=120)
def translate(text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Blank input short-circuits to an empty string. Otherwise the request is
    wrapped in the chat-template message format the processor expects
    (language codes ride alongside the text entry), greedy-decoded, and the
    prompt tokens are stripped from the generation before decoding.
    """
    if not text or not text.strip():
        return ""

    mdl = load_model()

    # Single-turn user message; the processor's chat template reads the
    # source/target language codes directly off the content entry.
    chat = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    model_inputs = processor.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(run_device)
    prompt_len = model_inputs["input_ids"].shape[1]

    # Greedy decoding (do_sample=False) for deterministic translations.
    with torch.inference_mode():
        tokens = mdl.generate(**model_inputs, max_new_tokens=1024, do_sample=False)

    # Drop the echoed prompt, keep only newly generated tokens.
    return processor.decode(tokens[0][prompt_len:], skip_special_tokens=True).strip()
# MCP HTTP endpoint
# FastAPI application serving the MCP JSON-RPC endpoint; the Gradio UI is
# mounted onto this same app near the bottom of the file.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
app = FastAPI()
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    return {"status": "ok"}
# Valid language codes, embedded in the MCP tool schema descriptions.
LANG_CODES = list(LANGUAGES.keys())
@app.post("/mcp")
async def mcp_handler(request: Request):
    """Handle MCP JSON-RPC messages via HTTP POST.

    Implements the minimal method set an MCP client needs: ``initialize``,
    ``tools/list``, and ``tools/call`` (a single ``translate`` tool).
    Every response is a JSON-RPC 2.0 envelope echoing the request id;
    unrecognized methods yield a -32601 "Method not found" error.
    """
    # A malformed body previously escaped as an unhandled exception
    # (HTTP 500); answer with the standard JSON-RPC parse error instead.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": None,
            "error": {"code": -32700, "message": "Parse error: request body is not valid JSON"}
        })

    method = body.get("method", "")
    params = body.get("params", {})
    msg_id = body.get("id")

    if method == "initialize":
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {
                    "name": "translategemma-mcp",
                    "version": "1.0.0"
                }
            }
        })
    elif method == "tools/list":
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {
                "tools": [
                    {
                        "name": "translate",
                        # Derive the language count from the actual table so the
                        # description cannot drift out of sync with LANGUAGES
                        # (the previous hard-coded "55" did not match the 54
                        # codes actually defined).
                        "description": f"Translate text between {len(LANG_CODES)} languages using TranslateGemma-4B-IT",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "text": {
                                    "type": "string",
                                    "description": "The text to translate"
                                },
                                "source_lang": {
                                    "type": "string",
                                    "description": f"Source language code: {', '.join(LANG_CODES)}"
                                },
                                "target_lang": {
                                    "type": "string",
                                    "description": f"Target language code: {', '.join(LANG_CODES)}"
                                }
                            },
                            "required": ["text", "source_lang", "target_lang"]
                        }
                    }
                ]
            }
        })
    elif method == "tools/call":
        tool_name = params.get("name")
        arguments = params.get("arguments", {})
        if tool_name == "translate":
            try:
                result = translate(
                    arguments.get("text", ""),
                    arguments.get("source_lang", "en"),
                    arguments.get("target_lang", "en")
                )
                return JSONResponse({
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "result": {
                        "content": [{"type": "text", "text": result}]
                    }
                })
            except Exception as e:
                # Surface model/runtime failures as a JSON-RPC server error.
                return JSONResponse({
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "error": {"code": -32000, "message": str(e)}
                })
        # An unknown tool used to fall through to the generic -32601 below,
        # misreporting a valid method as missing; report invalid params instead.
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": msg_id,
            "error": {"code": -32602, "message": f"Unknown tool: {tool_name}"}
        })
    return JSONResponse({
        "jsonrpc": "2.0",
        "id": msg_id,
        "error": {"code": -32601, "message": f"Method not found: {method}"}
    })
# Gradio interface
# Dropdown labels of the form "de (German)"; the leading code is parsed back out.
LANG_CHOICES = [f"{code} ({name})" for code, name in LANGUAGES.items()]
def gradio_translate(text, source, target):
    """UI adapter: map 'code (Name)' dropdown labels to bare codes, then translate."""
    src, tgt = (choice.split(" ")[0] for choice in (source, target))
    return translate(text, src, tgt)
# Two-column translation UI: source language + input on the left,
# target language + read-only output on the right.
with gr.Blocks(title="TranslateGemma MCP Server") as demo:
    gr.Markdown(
        """
    # TranslateGemma-4B-IT MCP Server
    Translation using Google's TranslateGemma-4B-IT model.
    **MCP Endpoint:** `POST https://gnumanth-translategemma-4b-it.hf.space/mcp`
    **55 languages** | First request loads model (~60s), then fast (~5s)
    """
    )
    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(choices=LANG_CHOICES, value="en (English)", label="Source Language")
            input_text = gr.Textbox(label="Text to Translate", lines=5, placeholder="Enter text...")
        with gr.Column():
            target_lang = gr.Dropdown(choices=LANG_CHOICES, value="de (German)", label="Target Language")
            output_text = gr.Textbox(label="Translation", lines=5, interactive=False)
    translate_btn = gr.Button("Translate", variant="primary")
    translate_btn.click(fn=gradio_translate, inputs=[input_text, source_lang, target_lang], outputs=output_text)
# Mount FastAPI to Gradio
# The Gradio UI is served at "/" on the same FastAPI app that exposes
# /health and /mcp; mount_gradio_app returns the combined parent app.
app = gr.mount_gradio_app(app, demo, path="/", app_kwargs={"root_path": "/"})
if __name__ == "__main__":
    import uvicorn
    # Serve the combined FastAPI + Gradio app on 7860, the standard HF Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)