Spaces:
Runtime error
Runtime error
import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from IndicTransToolkit import IndicProcessor
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Point the Hugging Face caches at a writable directory.
# NOTE(review): the Hugging Face libraries appear to resolve their cache
# locations when they are first imported, so assigning these variables after
# `transformers` has already been imported may be too late to take effect.
# They are kept for any code that reads them lazily; prefer setting them in
# the container environment as well — confirm against the HF documentation.
CACHE_DIR = "/app/cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

# Pass the cache directory explicitly so the checkpoint download does not
# depend on when the environment variables above were set.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B",
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
)

# Pre/post-processor used around tokenization and decoding.
ip = IndicProcessor(inference=True)

app = FastAPI()
# Define request body with Pydantic
class InputData(BaseModel):
    """Request body accepted by the prediction endpoint."""

    # Sentences to translate; `predict` uses "eng_Latn" as their source language.
    sentences: List[str]
    # Target language tag, forwarded unchanged to the IndicProcessor calls.
    target_lang: str
# Checkpoint whose tokenizer pairs with the module-level `model`; keep the two
# in sync if the model checkpoint is ever changed.
_TOKENIZER_CHECKPOINT = "ai4bharat/indictrans2-en-indic-1B"
_tokenizer = None


def _get_tokenizer():
    """Return the shared tokenizer, loading it on first use."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(
            _TOKENIZER_CHECKPOINT,
            trust_remote_code=True,
        )
    return _tokenizer


# API endpoint to receive input and return predictions
# NOTE(review): the original function carried no route decorator, so it was
# never registered with the app; the "/predict" path is taken from the
# function name — adjust if clients expect a different URL.
@app.post("/predict")
async def predict(input_data: InputData):
    """Translate the submitted sentences into the requested target language.

    Args:
        input_data: Request body holding the sentences to translate and the
            target language tag.

    Returns:
        A dict of the form ``{"output": translations}`` where ``translations``
        is the post-processed result for the batch. An empty sentence list
        yields ``{"output": []}`` without invoking the model.

    Raises:
        HTTPException: With status 500 (and the underlying error message as
            detail) when any stage of the pipeline fails.
    """
    sentences = input_data.sentences
    if not sentences:
        # Nothing to translate; skip tokenization and generation entirely.
        return {"output": []}

    try:
        src_lang, tgt_lang = "eng_Latn", input_data.target_lang
        tokenizer = _get_tokenizer()

        batch = ip.preprocess_batch(
            sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize the sentences and generate input encodings. The tensors are
        # placed on whatever device the model currently lives on, so inputs
        # and weights can never end up on different devices.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(model.device)

        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Postprocess the translations, including entity replacement
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)
        return {"output": translations}
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e