Spaces:
Runtime error
Runtime error
import gc
import logging

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
# Root-level logging configuration for the app: INFO and above, with a
# short wall-clock timestamp and the file:line of the call site.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
)
# Module-scoped logger instead of the root logger; its records still
# propagate to the root handler configured above.
logger = logging.getLogger(__name__)
# UI label -> Hugging Face Hub repository id of the fine-tuned NLLB checkpoint.
# Insertion order matters: the first entry is the default selection in the UI
# and the model pre-loaded at startup.
AVAILABLE_MODELS = {
    "NLLB for transliteration": "kesha-humonen/tr-eng_checkpoint-8556",
    "NLLB for hieroglyphs": "kesha-humonen/hi-eng_dpo_checkpoint-3342",
}

# UI label -> NLLB language code used as the forced first decoder token.
LANGUAGES = {
    "English": "eng_Latn",
    "German": "deu_Latn",
}
# Target device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Currently loaded model, its tokenizer and its UI label; all start empty and
# are populated by load_model().
current_model = current_tokenizer = current_model_name = None
def unload_model():
    """Drop the currently loaded model and reset all cached model state.

    ``current_model``, ``current_tokenizer`` and ``current_model_name`` are
    cleared together, so they never describe different checkpoints. When a
    model was actually loaded, it is logged, unreferenced, and the freed
    memory is handed back via ``gc.collect()`` and ``torch.cuda.empty_cache()``.
    Calling this with nothing loaded is a harmless no-op.
    """
    global current_model, current_tokenizer, current_model_name
    model = current_model
    # Reset every global first so no module-level reference keeps the weights alive.
    current_model = None
    current_tokenizer = None
    current_model_name = None
    if model is None:
        return
    logger.info(f"Unloading current model: {model.name_or_path}")
    del model
    # Collect any reference cycles still holding tensors, then release the
    # CUDA caching allocator's unused blocks.
    gc.collect()
    torch.cuda.empty_cache()
def load_model(model_name: str):
    """Load the checkpoint registered under *model_name* and make it current.

    The previous model is unloaded first so two checkpoints never share the
    device. The new model and tokenizer are loaded into locals and published
    to the module globals only after both loads succeed. A failed download
    therefore cannot leave a model paired with another checkpoint's
    tokenizer or label.

    Args:
        model_name: A key of ``AVAILABLE_MODELS`` (the label shown in the UI).

    Returns:
        A status message string.

    Raises:
        KeyError: If *model_name* is not a key of ``AVAILABLE_MODELS``. This is
            raised before the current model is unloaded.
        Any error raised by ``from_pretrained`` propagates unchanged.
    """
    global current_model, current_tokenizer, current_model_name
    # Validate the name before discarding the working model.
    repo_id = AVAILABLE_MODELS[model_name]
    unload_model()
    logger.info(f"Loading model: {model_name}")
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
    tokenizer = NllbTokenizer.from_pretrained(repo_id)
    current_model = model
    current_tokenizer = tokenizer
    current_model_name = model_name
    return "The model has been uploaded successfully!"
# UI model label -> source-language code assigned to the tokenizer's src_lang.
# NOTE(review): these codes are not in the stock NLLB language list; presumably
# the fine-tuned tokenizers add them as tokens — confirm against the checkpoints.
_SOURCE_LANG_CODES = {
    "NLLB for transliteration": "egy_Tnt",
    "NLLB for hieroglyphs": "egy_Hiero",
}


def generate(
    input_texts: str,
    model_name: str,
    language: str,
    *,
    num_beams: int = 4,
    repetition_penalty: float = 3.0,
) -> str:
    """Translate *input_texts* with the currently loaded model.

    Args:
        input_texts: Source text, in transliteration or hieroglyphs.
        model_name: UI label of the model. It selects the tokenizer's source
            language code; unknown labels leave ``src_lang`` unchanged.
        language: Key of ``LANGUAGES`` naming the output language.
        num_beams: Beam width passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Returns:
        The decoded translation. Blank input returns ``""``. A hint string
        is returned when no model or tokenizer is loaded.

    Raises:
        KeyError: If *language* is not a key of ``LANGUAGES``.
    """
    if current_model is None or current_tokenizer is None:
        return "Please select and upload the model first."
    # Nothing to translate: skip beam search over an empty sequence.
    if not input_texts or input_texts.isspace():
        return ""

    src_lang = _SOURCE_LANG_CODES.get(model_name)
    if src_lang is not None:
        current_tokenizer.src_lang = src_lang
    target_code = LANGUAGES[language]

    encoded_inputs = current_tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(current_model.device)

    # Force the first decoder token to the code of the user-selected output language.
    forced_bos_token_id = current_tokenizer.convert_tokens_to_ids(target_code)
    with torch.no_grad():
        generated_tokens = current_model.generate(
            **encoded_inputs,
            forced_bos_token_id=forced_bos_token_id,
            num_beams=num_beams,
            early_stopping=True,
            repetition_penalty=repetition_penalty,
        )
    output_text = current_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    # Strip a "<lang_code> " prefix in case the code survives decoding as plain text.
    return output_text.replace(f'{target_code} ', '')
def predict(model_choice, message, language):
    """Gradio callback: make sure *model_choice* is loaded, then translate *message*.

    A (re)load happens when no model is loaded yet or when the user picked a
    different model from the one currently in memory.
    """
    needs_reload = current_model is None or current_model_name != model_choice
    if needs_reload:
        load_model(model_choice)
    return generate(message, model_choice, language)
# Dropdown choices are derived from the registries so the UI stays in sync with them.
_MODEL_CHOICES = list(AVAILABLE_MODELS)
_LANGUAGE_CHOICES = list(LANGUAGES)

# Teal buttons (same colour on hover, overriding the theme's hover colour)
# and a hidden Gradio footer.
_CUSTOM_CSS = """
button {
    background-color: rgb(15, 138, 129) !important;
    color: white !important;
}
button:hover {
    background-color: rgb(15, 138, 129) !important;
}
footer {
    display: none !important;
}
"""

# The translation UI: model picker, source text, output-language picker -> translation.
demo = gr.Interface(
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the Gradio version pinned for this Space.
    allow_flagging="never",
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=_MODEL_CHOICES,
            label="Select a model",
            value=_MODEL_CHOICES[0],
        ),
        gr.Textbox(
            label="Enter the sentence using transliteration or hieroglyphs.",
            placeholder="",
            lines=3,
        ),
        gr.Dropdown(
            choices=_LANGUAGE_CHOICES,
            label="Select output language",
            value="English",
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Translation",
            lines=10,
        ),
    ],
    title="",
    examples=[
        ["NLLB for transliteration", "wn sbte nb ẖn =f", "English"],
        ["NLLB for hieroglyphs", "𓅭 𓆑 𓉐𓉻𓌕𓏌 𓋴𓌉𓂖 𓊪𓏏𓎛𓊵𓏏𓊪", "English"],
        ["NLLB for transliteration", "m wṯs jb =k n z", "German"],
        ["NLLB for hieroglyphs", "𓍹𓅃𓇋𓂓𓅱𓍺𓌳𓍘𓃫𓌸𓊖", "German"],
    ],
    theme='base',
    css=_CUSTOM_CSS,
)
if __name__ == "__main__":
    # Pre-load the first registered model so the first request is served
    # without a load delay; predict() swaps models on demand afterwards.
    default_model = list(AVAILABLE_MODELS)[0]
    load_model(default_model)
    demo.launch(share=False)