Spaces:
Sleeping
Sleeping
| import torch | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from IndicTransToolkit.processor import IndicProcessor | |
| import os | |
| import traceback | |
# Application object shared by the handlers defined in this module.
app = FastAPI()

# Configuration
# Use the GPU whenever torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Running on device: {DEVICE}")

# Load models
# Registry of loaded model bundles, keyed by translation direction.
MODELS = {}
# Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
def load_model(name, repo_id):
    """Load the tokenizer and seq2seq model stored in one Hugging Face repo.

    Args:
        name: Label used only in the progress and failure log lines.
        repo_id: Repository identifier handed to ``from_pretrained``.

    Returns:
        A dict with the keys ``"tokenizer"`` and ``"model"``. The model has
        been moved to ``DEVICE`` and put into eval mode.

    Raises:
        Exception: Any error raised while loading, re-raised after logging.
    """
    print(f"Loading {name} from {repo_id}...")

    # float16 is requested only on CUDA; every other device gets float32.
    weights_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    # Arguments common to the tokenizer and the model download.
    hub_kwargs = {"trust_remote_code": True, "token": HF_TOKEN}

    try:
        tok = AutoTokenizer.from_pretrained(repo_id, **hub_kwargs)
        # Removed flash_attention_2 to fix 'NoneType' shape error on T4
        net = AutoModelForSeq2SeqLM.from_pretrained(
            repo_id,
            torch_dtype=weights_dtype,
            **hub_kwargs,
        )
        net = net.to(DEVICE)
        net.eval()
    except Exception as e:
        print(f"Failed to load {name}: {e}")
        raise e

    return {"tokenizer": tok, "model": net}
| # Load on startup | |
@app.on_event("startup")
async def startup_event():
    """Load both translation models and the IndicProcessor at app startup.

    Registered as a FastAPI startup handler so it runs once before the first
    request is served. It fills ``MODELS`` with the ``"en-indic"`` and
    ``"indic-en"`` bundles and binds the module-level ``ip`` processor.

    Raises:
        Exception: Anything ``load_model`` or ``IndicProcessor`` raises; the
            error is not swallowed here.
    """
    # Only ``ip`` is rebound here; MODELS is mutated in place and therefore
    # needs no ``global`` declaration.
    global ip

    if not HF_TOKEN:
        print("WARNING: HF_TOKEN environment variable is not set. Gated models may fail to load.")

    # 1. English to Indic
    MODELS["en-indic"] = load_model("en-indic", "ai4bharat/indictrans2-en-indic-dist-200M")

    # 2. Indic to English
    MODELS["indic-en"] = load_model("indic-en", "ai4bharat/indictrans2-indic-en-dist-200M")

    # Processor
    ip = IndicProcessor(inference=True)

    print("All models loaded successfully.")
# Request body accepted by the translate handler.
class TranslationRequest(BaseModel):
    # Text to translate; an empty string short-circuits to an empty result.
    text: str
    # Source language code (the page sends values such as "eng_Latn");
    # a value starting with "eng" selects the English-to-Indic model.
    source_lang: str
    # Target language code; a value starting with "eng" selects the
    # Indic-to-English model when the source is not English.
    target_lang: str
@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate one piece of text between English and an Indic language.

    Registered as ``POST /translate``, the URL the HTML page in this module
    posts its JSON body to.

    Args:
        request: Body carrying ``text``, ``source_lang`` and ``target_lang``.

    Returns:
        ``{"translated_text": <str>}``; the string is empty when ``text`` is
        empty.

    Raises:
        HTTPException: 400 when neither language code starts with ``"eng"``;
            500 when the required model is missing from ``MODELS`` or when
            any step of the translation pipeline raises.
    """
    src = request.source_lang
    tgt = request.target_lang
    text = request.text

    if not text:
        return {"translated_text": ""}

    # Request validation lives outside the try block below so that these
    # deliberate HTTP errors reach the client with their own status code
    # instead of being re-wrapped as a generic 500.
    if src.startswith("eng"):
        model_key = "en-indic"
    elif tgt.startswith("eng"):
        model_key = "indic-en"
    else:
        raise HTTPException(status_code=400, detail="Direct Indic-to-Indic translation not supported.")

    if model_key not in MODELS:
        raise HTTPException(status_code=500, detail=f"Model {model_key} failed to load on startup.")

    print(f"Translating {model_key}: {src} -> {tgt} (len: {len(text)})")

    bundle = MODELS[model_key]
    tokenizer = bundle["tokenizer"]
    model = bundle["model"]

    try:
        # Preprocess
        batch = ip.preprocess_batch([text], src_lang=src, tgt_lang=tgt)

        # Tokenize
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=2048,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode
        decoded_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Postprocess
        translations = ip.postprocess_batch(decoded_tokens, lang=tgt)
        return {"translated_text": translations[0]}
    except Exception as e:
        # Log the full traceback server-side, then report the failure as a 500.
        traceback.print_exc()
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the single-page translator UI.

    Registered as ``GET /`` with ``HTMLResponse`` so the returned string is
    sent as an HTML document. The page's script posts the chosen languages
    and the input text as JSON to ``/translate`` and shows either the
    ``translated_text`` field or the error ``detail`` of the reply.

    Returns:
        The complete HTML document as a string.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Noa AI Translator</title>
        <script src="https://cdn.tailwindcss.com"></script>
    </head>
    <body class="bg-gray-50 min-h-screen p-8">
        <div class="max-w-2xl mx-auto bg-white rounded-xl shadow-md p-6">
            <h1 class="text-2xl font-bold mb-6 text-gray-800">Noa AI Translator</h1>
            <div class="space-y-4">
                <div class="grid grid-cols-2 gap-4">
                    <div>
                        <label class="block text-sm font-medium text-gray-700 mb-1">Source Language</label>
                        <select id="sourceLang" class="w-full border rounded-md p-2">
                            <option value="eng_Latn">English</option>
                            <option value="hin_Deva">Hindi</option>
                            <option value="tam_Taml">Tamil</option>
                            <option value="tel_Telu">Telugu</option>
                            <option value="kan_Knda">Kannada</option>
                            <option value="mal_Mlym">Malayalam</option>
                            <option value="mar_Deva">Marathi</option>
                            <option value="guj_Gujr">Gujarati</option>
                            <option value="ben_Beng">Bengali</option>
                            <option value="asm_Beng">Assamese</option>
                            <option value="pan_Guru">Punjabi</option>
                        </select>
                    </div>
                    <div>
                        <label class="block text-sm font-medium text-gray-700 mb-1">Target Language</label>
                        <select id="targetLang" class="w-full border rounded-md p-2">
                            <option value="hin_Deva">Hindi</option>
                            <option value="eng_Latn">English</option>
                            <option value="tam_Taml">Tamil</option>
                            <option value="tel_Telu">Telugu</option>
                            <option value="kan_Knda">Kannada</option>
                            <option value="mal_Mlym">Malayalam</option>
                            <option value="mar_Deva">Marathi</option>
                            <option value="guj_Gujr">Gujarati</option>
                            <option value="ben_Beng">Bengali</option>
                            <option value="asm_Beng">Assamese</option>
                            <option value="pan_Guru">Punjabi</option>
                        </select>
                    </div>
                </div>
                <div>
                    <label class="block text-sm font-medium text-gray-700 mb-1">Input Text</label>
                    <textarea id="inputText" rows="6" class="w-full border rounded-md p-2" placeholder="Enter text here..."></textarea>
                </div>
                <button onclick="translateText()" id="translateBtn" class="w-full bg-blue-600 text-white py-2 px-4 rounded-md hover:bg-blue-700 transition-colors font-medium">
                    Translate
                </button>
                <div>
                    <label class="block text-sm font-medium text-gray-700 mb-1">Translation</label>
                    <div id="outputText" class="w-full border rounded-md p-4 min-h-[150px] bg-gray-50 whitespace-pre-wrap"></div>
                </div>
            </div>
        </div>
        <script>
            async function translateText() {
                const btn = document.getElementById('translateBtn');
                const output = document.getElementById('outputText');
                const text = document.getElementById('inputText').value;
                const sourceLang = document.getElementById('sourceLang').value;
                const targetLang = document.getElementById('targetLang').value;
                if (!text) return;
                btn.disabled = true;
                btn.textContent = 'Translating...';
                output.textContent = '';
                try {
                    const response = await fetch('/translate', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({
                            text: text,
                            source_lang: sourceLang,
                            target_lang: targetLang
                        })
                    });
                    const data = await response.json();
                    if (response.ok) {
                        output.textContent = data.translated_text;
                        output.classList.remove('text-red-500');
                    } else {
                        output.textContent = 'Error: ' + (data.detail || 'Translation failed');
                        output.classList.add('text-red-500');
                    }
                } catch (e) {
                    output.textContent = 'Error: ' + e.message;
                    output.classList.add('text-red-500');
                } finally {
                    btn.disabled = false;
                    btn.textContent = 'Translate';
                }
            }
        </script>
    </body>
    </html>
    """