|
|
import torch |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import os |
|
|
import re |
|
|
from polyglot.detect import Detector |
|
|
|
|
|
# Optional Hugging Face access token taken from the environment; None when the
# variable is unset. It is not referenced by any other visible code.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub identifier of the checkpoint that is loaded for translation.
MODEL = "LLaMAX/LLaMAX3-8B-Alpaca"
# NOTE(review): presumably the base checkpoint MODEL derives from; it is not
# referenced by any other visible code — confirm whether it is still needed.
RELATIVE_MODEL = "LLaMAX/LLaMAX3-8B"

# HTML heading rendered at the top of the page.
TITLE = "<h1><center>LLaMAX3-Translator</center></h1>"
|
|
|
|
|
|
|
|
# Load the causal language model once at import time, in half precision.
# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Tokenizer that matches the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
|
|
|
|
|
|
|
def lang_detector(text):
    """Return the name of the language detected in ``text``.

    Args:
        text: The text to inspect.

    Returns:
        ``"Input text too short"`` when ``text`` has fewer than 5 characters.
        Otherwise the first run of word characters that follows ``"name: "``
        in the string form of the detection result, or ``"ERROR:<message>"``
        when detection or parsing raises.
    """
    shortest_allowed = 5
    if len(text) < shortest_allowed:
        return "Input text too short"

    try:
        description = str(Detector(text).language)
        # The language name is parsed out of the textual description.
        name_match = re.search(r"name: (\w+)", description)
        return name_match.group(1)
    except Exception as err:
        # Failures are reported as text instead of being raised to the caller.
        return f"ERROR:{str(err)}"
|
|
|
|
|
def Prompt_template(inst, prompt, query, src_language, trg_language):
    """Assemble the full model prompt for one translation request.

    Args:
        inst: Instruction template; ``{src_language}`` and ``{trg_language}``
            are replaced with the corresponding arguments.
        prompt: Preamble text placed before the ``### Instruction:`` section.
        query: The text to translate, placed in the ``### Input:`` section.
        src_language: Name of the source language.
        trg_language: Name of the target language.

    Returns:
        The preamble followed by the instruction (wrapped in backticks), the
        input text and a trailing ``### Response:`` marker.
    """
    try:
        inst = inst.format(src_language=src_language, trg_language=trg_language)
    except (KeyError, IndexError, ValueError, AttributeError):
        # The instruction is free text typed by the user, so it may contain
        # stray braces or unknown placeholders that str.format rejects. Fill
        # in only the two known placeholders and keep everything else as-is.
        inst = inst.replace("{src_language}", str(src_language)).replace(
            "{trg_language}", str(trg_language)
        )

    instruction = f"`{inst}`"
    return (
        f'{prompt}'
        f'### Instruction:\n{instruction}\n'
        f'### Input:\n{query}\n### Response:'
    )
|
|
|
|
|
|
|
|
def chunk_text():
    """Placeholder with no behaviour yet; takes no arguments and returns None."""
    return None
|
|
|
|
|
@spaces.GPU(duration=60)
def translate(
    source_text: str,
    source_lang: str,
    target_lang: str,
    inst: str,
    prompt: str,
    max_length: int,
    temperature: float,
    top_p: float,
    rp: float,
):
    """Translate ``source_text`` with the loaded model and yield the result.

    Args:
        source_text: Text to translate.
        source_lang: Name of the source language.
        target_lang: Name of the target language.
        inst: Instruction template passed to ``Prompt_template``.
        prompt: Preamble passed to ``Prompt_template``.
        max_length: Upper bound passed to ``generate`` as ``max_length``.
        temperature: Sampling temperature; a value of 0 or less selects
            greedy decoding instead of sampling.
        top_p: Nucleus-sampling threshold, used only when sampling.
        rp: Repetition penalty; ignored when it is 0 or less.

    Yields:
        A single string: the text generated after the prompt.
    """
    print(f'Text is - {source_text}')

    prompt = Prompt_template(inst, prompt, source_text, source_lang, target_lang)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(model.device)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_length=int(max_length),
    )

    # Pass the attention mask explicitly when the tokenizer provides one.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(model.device)

    if temperature is not None and temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    else:
        # The UI slider allows temperature 0, which is not a valid sampling
        # temperature; treat it as a request for deterministic decoding.
        generate_kwargs["do_sample"] = False

    # The UI slider allows 0, but the penalty must be strictly positive.
    if rp is not None and rp > 0:
        generate_kwargs["repetition_penalty"] = rp

    outputs = model.generate(**generate_kwargs)

    # Drop the prompt by token count rather than by character count, so the
    # result does not depend on the prompt decoding back to the same string.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    resp = tokenizer.decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    yield resp
|
|
|
|
|
# Stylesheet handed to gr.Blocks: centres the page heading and hides the footer.
CSS = """
h1 {
text-align: center;
display: block;
height: 10vh;
align-content: center;
}
footer {
visibility: hidden;
}
"""

# Markup shown at the bottom of the page, linking to the model.
LICENSE = """
Model: <a href="https://huggingface.co/LLaMAX/LLaMAX3-8B-Alpaca">LLaMAX3-8B-Alpaca</a>
"""

# Languages offered in the "Target Lang" dropdown.
LANG_LIST = [
    "Assamese",
    "Bengali",
    "Gujarati",
    "Hindi",
    "Kannada",
    "Kashmiri",
    "Konkani",
    "Malayalam",
    "Manipuri",
    "Marathi",
    "Nepali",
    "Oriya",
    "Punjabi",
    "Sanskrit",
    "Sindhi",
    "Tamil",
    "Telugu",
    "Urdu",
    "English",
]
|
|
|
|
|
|
|
|
# NOTE(review): created outside the `gr.Blocks` context and not referenced by
# any other visible code — confirm whether it is still needed.
chatbot = gr.Chatbot(height=600)
|
|
|
|
|
# Page layout: generation settings in a narrow left column, source and output
# text in a wide right column, then the action buttons and the model link.
with gr.Blocks(theme="soft", css=CSS) as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column(scale=1):
            # Overwritten by lang_detector whenever the source text changes.
            source_lang = gr.Textbox(
                label="Source Lang(Auto-Detect)",
                value="English",
            )
            target_lang = gr.Dropdown(
                label="Target Lang",
                # The default has to be one of the offered choices; the
                # previous default "Spanish" is not in LANG_LIST.
                value="Hindi",
                choices=LANG_LIST,
            )
            max_length = gr.Slider(
                label="Max Length",
                minimum=512,
                maximum=8192,
                value=4096,
                step=8,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0,
                maximum=1,
                value=0.3,
                step=0.1,
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
            )
            rp = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
            )
            with gr.Accordion("Advanced Options", open=False):
                inst = gr.Textbox(
                    label="Instruction",
                    value="Translate the following sentences from {src_language} to {trg_language}.",
                    lines=3,
                )
                # Plain preamble text. The previous default carried literal
                # quote characters and a leading space, which were sent to
                # the model as part of the prompt.
                prompt = gr.Textbox(
                    label="Prompt",
                    value=(
                        "Below is an instruction that describes a task, paired with an input that provides further context. "
                        "Write a response that appropriately completes the request.\n"
                    ),
                    lines=8,
                )

        with gr.Column(scale=4):
            source_text = gr.Textbox(
                label="Source Text",
                value=(
                    "LLaMAX is a language model with powerful multilingual capabilities without loss instruction-following capabilities. "
                    "LLaMAX supports translation between more than 100 languages, "
                    "surpassing the performance of similarly scaled LLMs."
                ),
                lines=10,
            )
            output_text = gr.Textbox(
                label="Output Text",
                lines=10,
                show_copy_button=True,
            )
    with gr.Row():
        submit = gr.Button(value="Submit")
        clear = gr.ClearButton([source_text, output_text])
    gr.Markdown(LICENSE)

    # Re-detect the source language every time the source text is edited.
    source_text.change(lang_detector, source_text, source_lang)
    # Run the translation with the current settings when Submit is clicked.
    submit.click(
        fn=translate,
        inputs=[
            source_text,
            source_lang,
            target_lang,
            inst,
            prompt,
            max_length,
            temperature,
            top_p,
            rp,
        ],
        outputs=[output_text],
    )
|
|
|
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()