# app.py — TranslateGemma-4B-IT Hugging Face Space: Gradio UI + MCP HTTP server.
# (Page-status header from the web scrape removed.)
import os

import torch
import spaces
import gradio as gr
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login

# Authenticate against the Hugging Face Hub when a token is configured —
# presumably the model repo is gated and requires it; harmless otherwise.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# ISO 639 language codes supported by the model, mapped to display names.
# Insertion order is preserved and drives the dropdown ordering in the UI.
LANGUAGES = {
    "en": "English",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
    "it": "Italian",
    "pt": "Portuguese",
    "nl": "Dutch",
    "pl": "Polish",
    "cs": "Czech",
    "ru": "Russian",
    "uk": "Ukrainian",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "ar": "Arabic",
    "hi": "Hindi",
    "bn": "Bengali",
    "tr": "Turkish",
    "vi": "Vietnamese",
    "th": "Thai",
    "id": "Indonesian",
    "ms": "Malay",
    "sv": "Swedish",
    "no": "Norwegian",
    "da": "Danish",
    "fi": "Finnish",
    "el": "Greek",
    "he": "Hebrew",
    "ro": "Romanian",
    "hu": "Hungarian",
    "bg": "Bulgarian",
    "hr": "Croatian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "sr": "Serbian",
    "lt": "Lithuanian",
    "lv": "Latvian",
    "et": "Estonian",
    "sw": "Swahili",
    "ta": "Tamil",
    "te": "Telugu",
    "mr": "Marathi",
    "gu": "Gujarati",
    "kn": "Kannada",
    "ml": "Malayalam",
    "pa": "Punjabi",
    "ur": "Urdu",
    "fa": "Persian",
    "fil": "Filipino",
    "ca": "Catalan",
    "gl": "Galician",
    "eu": "Basque",
    "cy": "Welsh",
    "ga": "Irish",
}
model_id = "google/translategemma-4b-it"

# The processor is cheap to fetch, so load it eagerly at import time; the
# multi-GB model itself is deferred until the first translation request.
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)
print("Processor loaded!")

# Populated lazily by load_model() on the first call.
model = None
def load_model():
    """Lazily instantiate the translation model (module-level singleton).

    The first call loads the checkpoint in bfloat16 onto CUDA when
    available (CPU otherwise) and puts it in eval mode; subsequent calls
    return the cached instance.
    """
    global model
    if model is not None:
        return model
    print("Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    )
    model = loaded.to(device).eval()
    print(f"Model loaded on {device.upper()}!")
    return model
def translate(text: str, source_lang: str, target_lang: str) -> str:
    """Translate text from source language to target language.

    Args:
        text: Text to translate; empty or whitespace-only input short-circuits.
        source_lang: Source language code (a key of ``LANGUAGES``).
        target_lang: Target language code (a key of ``LANGUAGES``).

    Returns:
        The translated text with surrounding whitespace stripped, or ""
        for empty input.
    """
    # NOTE(review): `spaces` is imported at module level but @spaces.GPU is
    # never applied; on ZeroGPU hardware this function likely needs that
    # decorator — confirm against the Space's hardware configuration.
    if not text or not text.strip():
        return ""
    m = load_model()
    # TranslateGemma's chat template takes the language pair alongside the
    # text inside a single user turn.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]
    # Place inputs on the model's actual device rather than re-deriving it
    # from torch.cuda.is_available() — keeps input and model placement in
    # sync even if the loader's device choice ever changes.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(m.device)
    with torch.inference_mode():
        generation = m.generate(**inputs, max_new_tokens=1024, do_sample=False)
    # Decode only the newly generated tokens, skipping the prompt prefix.
    input_len = inputs["input_ids"].shape[1]
    output = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return output.strip()
# ---- MCP HTTP endpoint ----
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


# Fix: the handler was defined but never registered with the app, making it
# unreachable; expose it as a liveness probe. (Path "/health" chosen by
# convention — confirm against any external monitoring configuration.)
@app.get("/health")
async def health_check():
    """Liveness probe for the Space."""
    return {"status": "ok"}


# Stable list of supported codes, reused in tool schemas below.
LANG_CODES = list(LANGUAGES.keys())
# Fix: the handler was defined but never registered with FastAPI, so the MCP
# endpoint advertised in the UI never actually existed.
@app.post("/mcp")
async def mcp_handler(request: Request):
    """Handle MCP JSON-RPC messages via HTTP POST.

    Supports the minimal MCP lifecycle: ``initialize``, ``tools/list`` and
    ``tools/call`` (single ``translate`` tool). Unknown methods get a
    JSON-RPC -32601 error; unknown tools a -32602 error.
    """
    body = await request.json()
    method = body.get("method", "")
    params = body.get("params", {})
    msg_id = body.get("id")

    # JSON-RPC notifications (e.g. "notifications/initialized") carry no id
    # and must not receive a response payload.
    if method.startswith("notifications/"):
        return JSONResponse(content=None, status_code=202)

    if method == "initialize":
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {
                    "name": "translategemma-mcp",
                    "version": "1.0.0"
                }
            }
        })
    elif method == "tools/list":
        return JSONResponse({
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {
                "tools": [
                    {
                        "name": "translate",
                        # Fix: the count was hard-coded as "55" but LANGUAGES
                        # holds 54 entries — derive it from the data instead.
                        "description": f"Translate text between {len(LANG_CODES)} languages using TranslateGemma-4B-IT",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "text": {
                                    "type": "string",
                                    "description": "The text to translate"
                                },
                                "source_lang": {
                                    "type": "string",
                                    "description": f"Source language code: {', '.join(LANG_CODES)}"
                                },
                                "target_lang": {
                                    "type": "string",
                                    "description": f"Target language code: {', '.join(LANG_CODES)}"
                                }
                            },
                            "required": ["text", "source_lang", "target_lang"]
                        }
                    }
                ]
            }
        })
    elif method == "tools/call":
        tool_name = params.get("name")
        arguments = params.get("arguments", {})
        # Fix: an unknown tool previously fell through to a misleading
        # "Method not found: tools/call" error; report the real problem.
        if tool_name != "translate":
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": msg_id,
                "error": {"code": -32602, "message": f"Unknown tool: {tool_name}"}
            })
        try:
            result = translate(
                arguments.get("text", ""),
                arguments.get("source_lang", "en"),
                arguments.get("target_lang", "en")
            )
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": msg_id,
                "result": {
                    "content": [{"type": "text", "text": result}]
                }
            })
        except Exception as e:
            # Surface model/runtime failures to the client as JSON-RPC errors.
            return JSONResponse({
                "jsonrpc": "2.0",
                "id": msg_id,
                "error": {"code": -32000, "message": str(e)}
            })
    return JSONResponse({
        "jsonrpc": "2.0",
        "id": msg_id,
        "error": {"code": -32601, "message": f"Method not found: {method}"}
    })
# ---- Gradio interface ----
# Dropdown entries look like "en (English)"; the leading code is parsed back
# out in gradio_translate below.
LANG_CHOICES = [f"{code} ({name})" for code, name in LANGUAGES.items()]


def gradio_translate(text, source, target):
    """Adapt dropdown labels ("en (English)") to translate()'s bare codes."""
    src_code, _, _ = source.partition(" ")
    tgt_code, _, _ = target.partition(" ")
    return translate(text, src_code, tgt_code)
with gr.Blocks(title="TranslateGemma MCP Server") as demo:
    # Fix: the language count was hard-coded as "55" while LANGUAGES holds
    # 54 entries — derive it from the data. Blank lines restore proper
    # markdown paragraph separation.
    gr.Markdown(
        f"""
# TranslateGemma-4B-IT MCP Server

Translation using Google's TranslateGemma-4B-IT model.

**MCP Endpoint:** `POST https://gnumanth-translategemma-4b-it.hf.space/mcp`

**{len(LANGUAGES)} languages** | First request loads model (~60s), then fast (~5s)
"""
    )
    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(choices=LANG_CHOICES, value="en (English)", label="Source Language")
            input_text = gr.Textbox(label="Text to Translate", lines=5, placeholder="Enter text...")
        with gr.Column():
            target_lang = gr.Dropdown(choices=LANG_CHOICES, value="de (German)", label="Target Language")
            output_text = gr.Textbox(label="Translation", lines=5, interactive=False)
    translate_btn = gr.Button("Translate", variant="primary")
    translate_btn.click(fn=gradio_translate, inputs=[input_text, source_lang, target_lang], outputs=output_text)

# Mount the Gradio UI onto the FastAPI app; API routes registered earlier
# take precedence over the catch-all UI mount at "/".
app = gr.mount_gradio_app(app, demo, path="/", app_kwargs={"root_path": "/"})
if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app on the HF Spaces default port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)