Spaces:
Runtime error
Runtime error
| # app.py | |
import json
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Application object. The CORS policy below accepts any origin, method and
# header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy -- confirm this is intended for a public endpoint.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Use the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are loaded from the same checkpoint name;
# trust_remote_code=True allows the custom code shipped with it to run.
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B",
    trust_remote_code=True,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B",
    trust_remote_code=True,
).to(DEVICE)

# Pre-/post-processing helper applied around tokenization in translate_text.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], target_lang: str):
    """Translate a batch of English sentences into ``target_lang``.

    The batch is pre-processed with the module-level ``ip`` processor,
    tokenized, run through ``model.generate`` with beam search, decoded and
    finally post-processed with ``ip`` again.

    Args:
        sentences: Input sentences; the source language is fixed to
            ``"eng_Latn"``.
        target_lang: Target language tag handed to the processor (the
            commented-out UI in this file uses tags such as ``"hin_Deva"``).

    Returns:
        A dict with the keys ``"translations"`` (the output of
        ``ip.postprocess_batch``, or an empty list when ``sentences`` is
        empty), ``"source_language"`` and ``"target_language"``.

    Raises:
        Exception: With the message ``"Translation failed: ..."`` when any
            step of the pipeline fails; the original error is chained as
            ``__cause__``.
    """
    src_lang = "eng_Latn"

    # An empty batch has nothing to tokenize or generate; answer directly
    # instead of pushing it through the model pipeline.
    if not sentences:
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only: gradient tracking is switched off for generation.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode inside the tokenizer's target-side context.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        # Keep the generic exception type callers already catch, but chain
        # the root cause explicitly so the traceback stays useful.
        raise Exception(f"Translation failed: {e}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }
| # FastAPI routes | |
async def health_check():
    """Report that the service is up.

    Returns:
        The constant payload ``{"status": "healthy"}``.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm it is actually registered on `app`.
    status_payload = {"status": "healthy"}
    return status_payload
async def translate_endpoint(sentences: List[str], target_lang: str):
    """Translate ``sentences`` and turn any failure into an HTTP 500.

    Args:
        sentences: Sentences forwarded unchanged to ``translate_text``.
        target_lang: Target language tag forwarded to ``translate_text``.

    Returns:
        The dict produced by ``translate_text``.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            ``translate_text`` raises.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm it is actually registered on `app`.
    try:
        return translate_text(sentences=sentences, target_lang=target_lang)
    except Exception as e:
        # HTTPException comes from the module-level fastapi import; without
        # that import this line fails with NameError instead of returning 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
| # # Streamlit interface | |
| # def main(): | |
| # st.title("Indic Language Translator") | |
| # # Input text | |
| # text_input = st.text_area("Enter text to translate:", "Hello, how are you?") | |
| # # Language selection | |
| # target_languages = { | |
| # "Hindi": "hin_Deva", | |
| # "Bengali": "ben_Beng", | |
| # "Tamil": "tam_Taml", | |
| # "Telugu": "tel_Telu", | |
| # "Marathi": "mar_Deva", | |
| # "Gujarati": "guj_Gujr", | |
| # "Kannada": "kan_Knda", | |
| # "Malayalam": "mal_Mlym", | |
| # "Punjabi": "pan_Guru", | |
| # "Odia": "ori_Orya", | |
| # } | |
| # target_lang = st.selectbox( | |
| # "Select target language:", options=list(target_languages.keys()) | |
| # ) | |
| # if st.button("Translate"): | |
| # try: | |
| # result = translate_text( | |
| # sentences=[text_input], target_lang=target_languages[target_lang] | |
| # ) | |
| # st.success("Translation:") | |
| # st.write(result["translations"][0]) | |
| # except Exception as e: | |
| # st.error(f"Translation failed: {str(e)}") | |
| # # Add API documentation | |
| # st.markdown("---") | |
| # st.header("API Documentation") | |
| # st.markdown( | |
| # """ | |
| # To use the translation API, send POST requests to: | |
| # ``` | |
| # https://darshankr-trans-en-indic.hf.space/translate | |
| # ``` | |
| # Request body format: | |
| # ```json | |
| # { | |
| # "sentences": ["Your text here"], | |
| # "target_lang": "hin_Deva" | |
| # } | |
| # ``` | |
| # """ | |
| # ) | |
| # st.markdown("Available target languages:") | |
| # for lang, code in target_languages.items(): | |
| # st.markdown(f"- {lang}: `{code}`") | |
| # if __name__ == "__main__": | |
| # # Run both Streamlit and FastAPI | |
| # import threading | |
| # def run_fastapi(): | |
| # uvicorn.run(app, host="0.0.0.0", port=8000) | |
| # # Start FastAPI in a separate thread | |
| # api_thread = threading.Thread(target=run_fastapi) | |
| # api_thread.start() | |
| # # Run Streamlit | |
| # main() | |