import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel

from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang

DIRECTION = "en-indic"
QUANTIZATION = None  # set to "4-bit" or "8-bit" to enable bitsandbytes quantization

IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = torch.cuda.is_available()
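# The `config` module imported above is not part of this file. A minimal
# sketch of what it is assumed to expose (all values below are illustrative
# placeholders, not the Space's actual settings):
#
#   # config.py
#   model_repo_id = "ai4bharat/indictrans2-en-indic-1B"  # assumed base model
#   lora_repo_id = "your-username/your-lora-adapter"     # hypothetical adapter
#   batch_size = 4
#   src_lang = "eng_Latn"  # FLORES-style codes expected by IndicProcessor
#   tgt_lang = "hin_Deva"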
def initialize_model_and_tokenizer():
    # Build an optional bitsandbytes quantization config.
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        # 8-bit loading takes no double-quant or compute-dtype options;
        # the bnb_8bit_* kwargs in the original were ignored.
        qconfig = BitsAndBytesConfig(load_in_8bit=True)
    else:
        qconfig = None
    tokenizer = IndicTransTokenizer(direction=DIRECTION)

    # Base model, used as-is for reference translations.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Second copy of the base model, to be wrapped with the LoRA adapter so
    # adapted and unadapted outputs can be compared.
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are already placed on device by bitsandbytes;
    # only move full-precision models explicitly.
    if qconfig is None:
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)

    model.eval()
    model2.eval()

    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)
    return tokenizer, model, lora_model
def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]

        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )

        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs

    return translations
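
# Example driver (not in the original Space; the sentences below are
# illustrative). It translates the same batch with the base model and the
# LoRA-adapted model so the two outputs can be compared side by side.
if __name__ == "__main__":
    tokenizer, model, lora_model = initialize_model_and_tokenizer()
    sentences = [
        "The weather in Mumbai is pleasant today.",
        "Machine translation quality has improved rapidly.",
    ]
    base_out = batch_translate(sentences, model, tokenizer)
    lora_out = batch_translate(sentences, lora_model, tokenizer)
    for src, base, lora in zip(sentences, base_out, lora_out):
        print(f"src:  {src}\nbase: {base}\nlora: {lora}\n")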