translate-gemma / app.py
fantos's picture
Update app.py
9e1be59 verified
import os
os.environ["GRADIO_SSR_MODE"] = "false"
import torch
import spaces
import gradio as gr
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login
from fastapi import Request
from fastapi.responses import JSONResponse
# Authenticate with the Hugging Face Hub only when an HF_TOKEN value is set;
# a missing or empty variable skips the login call entirely.
if hf_token := os.environ.get("HF_TOKEN"):
    login(token=hf_token)
# Language codes
# Maps each language code passed to the translator to its English display
# name. Insertion order matters: the dropdown labels (LANG_CHOICES) are built
# by iterating this mapping, so this is also the on-screen order.
# NOTE(review): this table holds 54 entries while the UI header advertises
# 55 languages — confirm whether an entry is missing.
LANGUAGES = dict(
    [
        ("en", "English"), ("de", "German"), ("fr", "French"), ("es", "Spanish"),
        ("it", "Italian"), ("pt", "Portuguese"), ("nl", "Dutch"), ("pl", "Polish"),
        ("cs", "Czech"), ("ru", "Russian"), ("uk", "Ukrainian"), ("zh", "Chinese"),
        ("ja", "Japanese"), ("ko", "Korean"), ("ar", "Arabic"), ("hi", "Hindi"),
        ("bn", "Bengali"), ("tr", "Turkish"), ("vi", "Vietnamese"), ("th", "Thai"),
        ("id", "Indonesian"), ("ms", "Malay"), ("sv", "Swedish"), ("no", "Norwegian"),
        ("da", "Danish"), ("fi", "Finnish"), ("el", "Greek"), ("he", "Hebrew"),
        ("ro", "Romanian"), ("hu", "Hungarian"), ("bg", "Bulgarian"), ("hr", "Croatian"),
        ("sk", "Slovak"), ("sl", "Slovenian"), ("sr", "Serbian"), ("lt", "Lithuanian"),
        ("lv", "Latvian"), ("et", "Estonian"), ("sw", "Swahili"), ("ta", "Tamil"),
        ("te", "Telugu"), ("mr", "Marathi"), ("gu", "Gujarati"), ("kn", "Kannada"),
        ("ml", "Malayalam"), ("pa", "Punjabi"), ("ur", "Urdu"), ("fa", "Persian"),
        ("fil", "Filipino"), ("ca", "Catalan"), ("gl", "Galician"), ("eu", "Basque"),
        ("cy", "Welsh"), ("ga", "Irish"),
    ]
)
# Hugging Face Hub checkpoint used for both the processor and the model.
model_id = "google/translategemma-4b-it"
# Model weights are deferred: load_model() populates this on first use.
model = None
# The processor, by contrast, is loaded eagerly at import time.
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)
print("Processor loaded!")
def load_model():
    """Return the shared translation model, loading it on the first call.

    The loaded model is cached in the module-level ``model`` global, so later
    calls return the same object. On first load it is created in bfloat16,
    moved to CUDA when ``torch.cuda.is_available()`` is true (CPU otherwise)
    and switched to eval mode.
    """
    global model
    if model is not None:
        return model

    print("Loading model...")
    target = "cuda" if torch.cuda.is_available() else "cpu"
    loaded = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    )
    model = loaded.to(target).eval()
    print(f"Model loaded on {target.upper()}!")
    return model
@spaces.GPU(duration=120)
def translate(text: str, source_lang: str, target_lang: str) -> str:
    """Translate *text* from *source_lang* to *target_lang* with TranslateGemma.

    Args:
        text: Text to translate. Empty or whitespace-only input returns ``""``
            without loading or running the model.
        source_lang: Language code of *text* (e.g. ``"en"``); passed to the
            chat template as ``source_lang_code``.
        target_lang: Language code to translate into (e.g. ``"ko"``); passed
            to the chat template as ``target_lang_code``.

    Returns:
        The decoded translation with surrounding whitespace stripped.

    Decoding is greedy (``do_sample=False``) and capped at 1024 new tokens.
    NOTE(review): ``spaces.GPU(duration=120)`` presumably reserves a ZeroGPU
    slot for up to 120 s per call — confirm against the ``spaces`` docs.
    """
    if not text or not text.strip():
        return ""

    m = load_model()
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]
    # Send inputs to wherever the cached model actually lives instead of
    # re-deriving a device from torch.cuda.is_available(): the model is loaded
    # once and reused, so an independently recomputed device could disagree
    # with it and cause a device-mismatch error inside generate().
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(m.device)
    with torch.inference_mode():
        generation = m.generate(**inputs, max_new_tokens=1024, do_sample=False)
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    input_len = inputs["input_ids"].shape[1]
    output = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return output.strip()
# Bare language codes, plus "code (Name)" labels for the dropdowns; both
# follow the insertion order of LANGUAGES.
LANG_CODES = [*LANGUAGES]
LANG_CHOICES = [f"{code} ({label})" for code, label in LANGUAGES.items()]
def gradio_translate(text, source, target):
    """Gradio click handler: turn dropdown labels into codes and translate.

    Dropdown values look like ``"en (English)"``; the language code is the
    text before the first space.
    """
    codes = [choice.partition(" ")[0] for choice in (source, target)]
    return translate(text, *codes)
# ✅ Gradio UI definition
# Builds the single-page translator UI: a <style> element with the custom CSS,
# a header, source/target language dropdowns, input/output textboxes, a
# translate button and an info box. Components render in the order they are
# created inside each ``with`` context.
with gr.Blocks(title="TranslateGemma") as demo:
    # Custom CSS injected as a raw <style> element. Its selectors
    # (#col-container, .pow-btn, .header-box, ...) are used by the elem_id /
    # elem_classes / HTML markup of the components created below.
    gr.HTML("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Bangers&family=Comic+Neue:wght@400;700&display=swap');
* { font-family: 'Comic Neue', cursive !important; }
body, .gradio-container {
background: linear-gradient(135deg, #ffeb3b 0%, #fff176 100%) !important;
min-height: 100vh;
}
#col-container {
margin: 0 auto;
max-width: 1100px;
padding: 25px 20px;
}
.header-box {
background: #ffffff;
border: 5px solid #000000;
border-radius: 30px;
padding: 25px 30px;
text-align: center;
margin-bottom: 25px;
box-shadow: 8px 8px 0 #000000;
}
.badge-row {
display: flex;
justify-content: center;
margin-bottom: 12px;
}
.title-text {
font-family: 'Bangers', cursive !important;
font-size: 3rem !important;
color: #ff1744 !important;
margin: 8px 0 !important;
text-shadow: 3px 3px 0 #000000;
letter-spacing: 3px;
}
.subtitle-text {
color: #1565c0 !important;
font-size: 1.1rem !important;
font-weight: 700 !important;
text-transform: uppercase;
}
.model-badge {
display: inline-block;
background: #2196f3;
color: white;
padding: 5px 15px;
border: 3px solid #000;
border-radius: 20px;
font-weight: 700;
box-shadow: 3px 3px 0 #000;
margin-top: 8px;
}
textarea {
background: #fffde7 !important;
border: 3px solid #000000 !important;
border-radius: 15px !important;
color: #000000 !important;
font-size: 1.05rem !important;
font-weight: 600 !important;
padding: 15px !important;
min-height: 150px !important;
}
textarea:focus {
border-color: #ff1744 !important;
box-shadow: 0 0 0 3px rgba(255, 23, 68, 0.3) !important;
}
.pow-btn {
background: linear-gradient(180deg, #ff1744 0%, #d50000 100%) !important;
border: 4px solid #000000 !important;
border-radius: 15px !important;
color: #ffffff !important;
font-family: 'Bangers', cursive !important;
font-size: 1.5rem !important;
padding: 18px 40px !important;
box-shadow: 5px 5px 0 #000000 !important;
text-shadow: 2px 2px 0 #000000;
width: 100%;
margin-top: 15px;
}
.pow-btn:hover {
transform: translate(-3px, -3px) !important;
box-shadow: 8px 8px 0 #000000 !important;
}
label {
color: #000000 !important;
font-weight: 700 !important;
text-transform: uppercase;
}
.info-box {
background: #e3f2fd;
border: 3px solid #000;
border-radius: 15px;
padding: 12px 20px;
margin-top: 20px;
box-shadow: 4px 4px 0 #000;
}
.info-box p {
margin: 5px 0;
font-weight: 600;
color: #1565c0 !important;
}
.swap-icon {
font-size: 2rem;
color: #ff1744;
text-shadow: 2px 2px 0 #000;
display: flex;
align-items: center;
justify-content: center;
}
</style>
""")
    with gr.Column(elem_id="col-container"):
        # Header: badge link, title, subtitle and model name.
        # NOTE(review): the subtitle advertises 55 languages, but LANGUAGES
        # (and therefore the dropdowns) lists 54 entries — confirm which
        # figure is intended.
        gr.HTML("""
<div class="header-box">
<div class="badge-row">
<a href="https://www.humangen.ai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Free%20AI%20HUB&message=humangen.ai&color=%23ff1744&labelColor=%23000000&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge">
</a>
</div>
<h1 class="title-text">🌍 TRANSLATE GEMMA</h1>
<p class="subtitle-text">⚡ AI-Powered Translation for 55 Languages ⚡</p>
<span class="model-badge">🤖 MODEL: translategemma-4b-it</span>
</div>
""")
        # Language pickers. Values are "code (Name)" labels, which
        # gradio_translate splits back into bare codes.
        with gr.Row():
            source_lang = gr.Dropdown(
                choices=LANG_CHOICES,
                value="en (English)",
                label="📤 SOURCE LANGUAGE"
            )
            gr.HTML('<div class="swap-icon">⚡➡️⚡</div>')
            target_lang = gr.Dropdown(
                choices=LANG_CHOICES,
                value="ko (Korean)",
                label="📥 TARGET LANGUAGE"
            )
        # Side-by-side input textbox and non-interactive output textbox.
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="💬 INPUT",
                    lines=8,
                    placeholder="Type or paste your text here..."
                )
            with gr.Column():
                output_text = gr.Textbox(
                    label="✨ OUTPUT",
                    lines=8,
                    interactive=False,
                    placeholder="Translation will appear here..."
                )
        translate_btn = gr.Button("💥 POW! TRANSLATE 💥", variant="primary", elem_classes="pow-btn")
        # Static info box mentioning the POST /mcp endpoint.
        gr.HTML("""
<div class="info-box">
<p>🔗 <strong>MCP Endpoint:</strong> POST /mcp</p>
<p>⏱️ <strong>Note:</strong> First request loads model (~60s), then fast (~5s)</p>
</div>
""")
    # Button click: (input text, source label, target label) -> output box.
    translate_btn.click(
        fn=gradio_translate,
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text
    )
# ✅ Register extra FastAPI routes on the FastAPI app that backs the Gradio UI.
# NOTE(review): these routes are attached to ``demo.app`` before
# ``demo.launch()`` runs; confirm the installed Gradio version does not
# rebuild ``demo.app`` inside ``launch()``, or /health and /mcp would be lost.
app = demo.app


@app.get("/health")
async def health_check():
    """Liveness probe that always answers ``{"status": "ok"}``."""
    payload = dict(status="ok")
    return payload
def _jsonrpc_result(msg_id, result):
    """Build a JSON-RPC 2.0 success response carrying *result*."""
    return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "result": result})


def _jsonrpc_error(msg_id, code, message):
    """Build a JSON-RPC 2.0 error response with the given *code* and *message*."""
    return JSONResponse(
        {"jsonrpc": "2.0", "id": msg_id, "error": {"code": code, "message": message}}
    )


@app.post("/mcp")
async def mcp_handler(request: Request):
    """Minimal MCP endpoint (JSON-RPC 2.0 over HTTP POST) exposing ``translate``.

    Handled methods:
      * ``initialize``     - protocol version, capabilities and server info.
      * ``ping``           - empty result.
      * ``tools/list``     - schema of the single ``translate`` tool.
      * ``tools/call``     - runs ``translate`` with the supplied arguments.
      * ``notifications/*``- acknowledged with HTTP 202 and no body, because
        JSON-RPC notifications must not receive a response.

    Error responses:
      * -32700 when the body is not valid JSON,
      * -32600 when the body is not a JSON object (batches are unsupported),
      * -32602 for an unknown tool name,
      * -32601 for any other method,
      * -32000 with the exception text when ``translate`` itself fails.
    """
    import asyncio

    from fastapi.responses import Response

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError subclass it
        return _jsonrpc_error(None, -32700, "Parse error")
    if not isinstance(body, dict):
        return _jsonrpc_error(None, -32600, "Invalid Request")

    method = body.get("method", "")
    msg_id = body.get("id")
    params = body.get("params")
    if not isinstance(params, dict):  # tolerate missing or null "params"
        params = {}

    if isinstance(method, str) and method.startswith("notifications/"):
        # e.g. "notifications/initialized": acknowledge without a JSON-RPC reply.
        return Response(status_code=202)

    if method == "initialize":
        return _jsonrpc_result(msg_id, {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "serverInfo": {
                "name": "translategemma-mcp",
                "version": "1.0.0"
            }
        })

    if method == "ping":
        return _jsonrpc_result(msg_id, {})

    if method == "tools/list":
        return _jsonrpc_result(msg_id, {
            "tools": [
                {
                    "name": "translate",
                    # Derive the count from the advertised code list so the two
                    # cannot drift apart.
                    "description": f"Translate text between {len(LANG_CODES)} languages using TranslateGemma-4B-IT",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "text": {"type": "string", "description": "The text to translate"},
                            "source_lang": {"type": "string", "description": f"Source language code: {', '.join(LANG_CODES)}"},
                            "target_lang": {"type": "string", "description": f"Target language code: {', '.join(LANG_CODES)}"}
                        },
                        "required": ["text", "source_lang", "target_lang"]
                    }
                }
            ]
        })

    if method == "tools/call":
        tool_name = params.get("name")
        if tool_name != "translate":
            return _jsonrpc_error(msg_id, -32602, f"Unknown tool: {tool_name}")
        arguments = params.get("arguments")
        if not isinstance(arguments, dict):  # tolerate missing or null "arguments"
            arguments = {}
        try:
            # translate() is synchronous and GPU-bound; run it in a worker
            # thread so the event loop keeps serving other requests meanwhile.
            result = await asyncio.to_thread(
                translate,
                arguments.get("text", ""),
                arguments.get("source_lang", "en"),
                arguments.get("target_lang", "en"),
            )
        except Exception as e:  # request boundary: report instead of HTTP 500
            return _jsonrpc_error(msg_id, -32000, str(e))
        return _jsonrpc_result(msg_id, {"content": [{"type": "text", "text": result}]})

    return _jsonrpc_error(msg_id, -32601, f"Method not found: {method}")
# ✅ Start the server with Gradio's own launcher (instead of mounting the
# Blocks into a separately-run FastAPI app), listening on all interfaces.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)