sidd1311 committed on
Commit
f4e9d54
·
verified ·
1 Parent(s): 086ac0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py CHANGED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import spaces
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import os
6
+ import re
7
+ from polyglot.detect import Detector
8
+
9
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
+ MODEL = "LLaMAX/LLaMAX3-8B-Alpaca"
11
+ RELATIVE_MODEL="LLaMAX/LLaMAX3-8B"
12
+
13
+ TITLE = "<h1><center>LLaMAX3-Translator</center></h1>"
14
+
15
+
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ MODEL,
18
+ torch_dtype=torch.float16,
19
+ device_map="auto")
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
21
+
22
+
23
+ def lang_detector(text):
24
+ min_chars = 5
25
+ if len(text) < min_chars:
26
+ return "Input text too short"
27
+ try:
28
+ detector = Detector(text).language
29
+ lang_info = str(detector)
30
+ code = re.search(r"name: (\w+)", lang_info).group(1)
31
+ return code
32
+ except Exception as e:
33
+ return f"ERROR:{str(e)}"
34
+
35
+ def Prompt_template(inst, prompt, query, src_language, trg_language):
36
+ inst = inst.format(src_language=src_language, trg_language=trg_language)
37
+ instruction = f"`{inst}`"
38
+ prompt = (
39
+ f'{prompt}'
40
+ f'### Instruction:\n{instruction}\n'
41
+ f'### Input:\n{query}\n### Response:'
42
+ )
43
+ return prompt
44
+
45
+ # Unfinished
46
+ def chunk_text():
47
+ pass
48
+
49
+ @spaces.GPU(duration=60)
50
+ def translate(
51
+ source_text: str,
52
+ source_lang: str,
53
+ target_lang: str,
54
+ inst: str,
55
+ prompt: str,
56
+ max_length: int,
57
+ temperature: float,
58
+ top_p: float,
59
+ rp: float):
60
+
61
+ print(f'Text is - {source_text}')
62
+
63
+ prompt = Prompt_template(inst, prompt, source_text, source_lang, target_lang)
64
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
65
+
66
+ generate_kwargs = dict(
67
+ input_ids=input_ids,
68
+ max_length=max_length,
69
+ do_sample=True,
70
+ temperature=temperature,
71
+ top_p=top_p,
72
+ repetition_penalty=rp,
73
+ )
74
+
75
+ outputs = model.generate(**generate_kwargs)
76
+
77
+ resp = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
78
+
79
+ yield resp[len(prompt):]
80
+
81
+ CSS = """
82
+ h1 {
83
+ text-align: center;
84
+ display: block;
85
+ height: 10vh;
86
+ align-content: center;
87
+ }
88
+ footer {
89
+ visibility: hidden;
90
+ }
91
+ """
92
+
93
+ LICENSE = """
94
+ Model: <a href="https://huggingface.co/LLaMAX/LLaMAX3-8B-Alpaca">LLaMAX3-8B-Alpaca</a>
95
+ """
96
+
97
+ LANG_LIST = ['Akrikaans', 'Amharic', 'Arabic', 'Armenian', 'Assamese', 'Asturian', 'Azerbaijani', \
98
+ 'Belarusian', 'Bengali', 'Bosnian', 'Bulgarian', 'Burmese', \
99
+ 'Catalan', 'Cebuano', 'Simplified Chinese', 'Traditional Chinese', 'Croatian', 'Czech', \
100
+ 'Danish', 'Dutch', 'English', 'Estonian', 'Filipino', 'Finnish', 'French', 'Fulah', \
101
+ 'Galician', 'Ganda', 'Georgian', 'German', 'Greek', 'Gujarati', \
102
+ 'Hausa', 'Hebrew', 'Hindi', 'Hungarian', \
103
+ 'Icelandic', 'Igbo', 'Indonesian', 'Irish', 'Italian', \
104
+ 'Japanese', 'Javanese', \
105
+ 'Kabuverdianu', 'Kamba', 'Kannada', 'Kazakh', 'Khmer', 'Korean', 'Kyrgyz', \
106
+ 'Lao', 'Latvian', 'Lingala', 'Lithuanian', 'Luo', 'Luxembourgish', \
107
+ 'Macedonian', 'Malay', 'Malayalam', 'Maltese', 'Maori', 'Marathi', 'Mongolian', \
108
+ 'Nepali', 'Northern', 'Norwegian', 'Nyanja', \
109
+ 'Occitan', 'Oriya', 'Oromo', \
110
+ 'Pashto', 'Persian', 'Polish', 'Portuguese', 'Punjabi', \
111
+ 'Romanian', 'Russian', \
112
+ 'Serbian', 'Shona', 'Sindhi', 'Slovak', 'Slovenian', 'Somali', 'Sorani', 'Spanish', 'Swahili', 'Swedish', \
113
+ 'Tajik', 'Tamil', 'Telugu', 'Thai', 'Turkish', \
114
+ 'Ukrainian', 'Umbundu', 'Urdu', 'Uzbek', \
115
+ 'Vietnamese', 'Welsh', 'Wolof', 'Xhosa', 'Yoruba', 'Zulu']
116
+
117
+ chatbot = gr.Chatbot(height=600)
118
+
119
+ with gr.Blocks(theme="soft", css=CSS) as demo:
120
+ gr.Markdown(TITLE)
121
+ with gr.Row():
122
+ with gr.Column(scale=1):
123
+ source_lang = gr.Textbox(
124
+ label="Source Lang(Auto-Detect)",
125
+ value="English",
126
+ )
127
+ target_lang = gr.Dropdown(
128
+ label="Target Lang",
129
+ value="Spanish",
130
+ choices=LANG_LIST,
131
+ )
132
+ max_length = gr.Slider(
133
+ label="Max Length",
134
+ minimum=512,
135
+ maximum=8192,
136
+ value=4096,
137
+ step=8,
138
+ )
139
+ temperature = gr.Slider(
140
+ label="Temperature",
141
+ minimum=0,
142
+ maximum=1,
143
+ value=0.3,
144
+ step=0.1,
145
+ )
146
+ top_p = gr.Slider(
147
+ minimum=0.0,
148
+ maximum=1.0,
149
+ step=0.1,
150
+ value=1.0,
151
+ label="top_p",
152
+ )
153
+ rp = gr.Slider(
154
+ minimum=0.0,
155
+ maximum=2.0,
156
+ step=0.1,
157
+ value=1.2,
158
+ label="Repetition penalty",
159
+ )
160
+ with gr.Accordion("Advanced Options", open=False):
161
+ inst = gr.Textbox(
162
+ label="Instruction",
163
+ value="Translate the following sentences from {src_language} to {trg_language}.",
164
+ lines=3,
165
+ )
166
+ prompt = gr.Textbox(
167
+ label="Prompt",
168
+ value=""" 'Below is an instruction that describes a task, paired with an input that provides further context. '
169
+ 'Write a response that appropriately completes the request.\n' """,
170
+ lines=8,
171
+ )
172
+
173
+ with gr.Column(scale=4):
174
+ source_text = gr.Textbox(
175
+ label="Source Text",
176
+ value="LLaMAX is a language model with powerful multilingual capabilities without loss instruction-following capabilities. "+\
177
+ "LLaMAX supports translation between more than 100 languages, "+\
178
+ "surpassing the performance of similarly scaled LLMs.",
179
+ lines=10,
180
+ )
181
+ output_text = gr.Textbox(
182
+ label="Output Text",
183
+ lines=10,
184
+ show_copy_button=True,
185
+ )
186
+ with gr.Row():
187
+ submit = gr.Button(value="Submit")
188
+ clear = gr.ClearButton([source_text, output_text])
189
+ gr.Markdown(LICENSE)
190
+
191
+ source_text.change(lang_detector, source_text, source_lang)
192
+ submit.click(fn=translate, inputs=[source_text, source_lang, target_lang, inst, prompt, max_length, temperature, top_p, rp], outputs=[output_text])
193
+
194
+
195
+ if __name__ == "__main__":
196
+ demo.launch()