| import os | |
| from dotenv import load_dotenv | |
| import gradio as gr | |
| from gradio.components import Textbox, Button, Slider, Checkbox | |
| from AinaTheme import theme | |
| from huggingface_hub import InferenceClient | |
| from urllib.error import HTTPError | |
# Load a local .env file (python-dotenv) into the process environment so the
# os.environ.get(...) configuration reads in this module can pick its values up.
load_dotenv()
def _http_status(err):
    """Return the HTTP status code carried by *err*, or ``None`` if it has none.

    Handles ``urllib.error.HTTPError`` (status in ``.code``) as well as
    requests/httpx-style errors that expose ``.response.status_code``; the HTTP
    errors raised by huggingface_hub (``HfHubHTTPError``) are of the latter kind.
    """
    code = getattr(err, "code", None)
    if isinstance(code, int):
        return code
    response = getattr(err, "response", None)
    return getattr(response, "status_code", None)


def generate(prompt, model_parameters):
    """Send *prompt* to the inference endpoint and return the generated text.

    Args:
        prompt: Text sent to the model.
        model_parameters: Keyword arguments forwarded to
            ``client.text_generation`` (max_new_tokens, top_k, top_p, ...).

    Returns:
        Whatever ``client.text_generation`` returns (called with
        ``return_full_text=True``), or ``None`` when the call fails. Failures
        are reported to the user through ``gr.Warning`` instead of raising.
    """
    try:
        return client.text_generation(prompt, **model_parameters, return_full_text=True)
    except Exception as err:
        # UI boundary: show a toast instead of crashing the event handler.
        # ``Exception`` (not a bare ``except:``) lets KeyboardInterrupt/SystemExit
        # through, and every failure now produces a warning instead of some
        # HTTP errors returning None silently.
        if _http_status(err) == 400:
            gr.Warning("The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET.")
        else:
            gr.Warning('Inference endpoint is not available right now. Please try again later.')
        return None
# Shared client for the text-generation inference endpoint. The endpoint URL and
# the access token are both read from the environment.
_endpoint_url = os.environ.get("HF_INFERENCE_ENDPOINT_URL")
_endpoint_token = os.environ.get("HF_INFERENCE_ENDPOINT_TOKEN")
client = InferenceClient(_endpoint_url, token=_endpoint_token)
# Default generation length used by the "Max tokens" slider and by clear().
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
# Maximum prompt length (characters) accepted by the UI.
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=100))
# Whether the "Model parameters" accordion is visible. The default must be the
# *string* "True": a bool default never equals "True", which would hide the
# accordion whenever the variable is unset. The comparison is case-insensitive.
SHOW_MODEL_PARAMETERS_IN_UI = (
    os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True").strip().lower() == "true"
)
def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    """Validate the prompt, then run generation with the UI's sampling parameters.

    Used by the Submit button (exposed as the "get-results" API endpoint) and
    by the gr.Examples widgets.

    Args:
        input_: Prompt text from the input Textbox (``None`` is treated as empty).
        max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature:
            Forwarded unchanged to ``generate`` as text-generation parameters.

    Returns:
        The generated text, or ``None`` when the prompt is rejected or the
        generation fails (a ``gr.Warning`` toast tells the user why).
    """
    text = input_ or ""
    stripped_length = len(text.strip())
    if stripped_length == 0:
        gr.Warning('Not possible to inference an empty input')
        return None
    # The UI disables Submit for over-long input, but API callers (and example
    # callbacks) bypass the button, so enforce the same limit here, measured on
    # the stripped text like the button-state check.
    if stripped_length > MAX_INPUT_CHARACTERS:
        gr.Warning(f'Input is too long: the maximum is {MAX_INPUT_CHARACTERS} characters.')
        return None
    model_parameters = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }
    return generate(text, model_parameters)
def change_interactive(text):
    """Recompute the Clear/Submit button states after the input text changes.

    Args:
        text: Current content of the input Textbox (``None`` is treated as empty).

    Returns:
        A pair of ``gr.update`` objects for ``(clear_btn, submit_btn)``. Clear
        is always interactive. Submit is interactive only when the stripped text
        is non-empty and at most ``MAX_INPUT_CHARACTERS`` characters long.
    """
    stripped = (text or "").strip()
    # Whitespace-only input counts as empty, because submit_input would reject it anyway.
    can_submit = 0 < len(stripped) <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=can_submit)
def clear():
    """Reset the prompt/answer boxes and every model-parameter control.

    Returns an 8-tuple in the order of the Clear button's outputs:
    (input, output, max_new_tokens, repetition_penalty, top_k, top_p,
    do_sample, temperature). Both text boxes are emptied with ``None``; each
    control gets a ``gr.update`` restoring its default value.
    """
    parameter_defaults = (MAX_NEW_TOKENS, 1.2, 50, 0.95, True, 0.5)
    emptied_boxes = (None, None)
    return emptied_boxes + tuple(gr.update(value=default) for default in parameter_defaults)
def gradio_app():
    """Build the AIMestre Gradio UI, wire its event handlers and launch the server.

    Layout, top to bottom:
    - a banner row (image plus title);
    - a row holding the input column (prompt box, character counter, optional
      "Model parameters" accordion) and the output column (answer box,
      Clear/Submit buttons);
    - a row of example prompts.

    The Submit handler is exposed through the API as "get-results".
    """
    # Gradio derives the page layout from the order and nesting of these
    # ``with`` blocks, so component creation order matters.
    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            # NOTE(review): recent Gradio versions expect an integer ``scale``;
            # fractional values such as 0.1 / 0.5 may trigger warnings — confirm.
            with gr.Column(scale=0.1):
                gr.Image("ginesta_small.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button=False, show_share_button=False)
            with gr.Column():
                gr.Markdown(
                    """# AIMestre
Basat en el model [Flor](https://huggingface.co/projecte-aina/FLOR-6.3B) del projecte AINA.
"""
                )
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                # Hidden component whose only purpose is to pass MAX_INPUT_CHARACTERS
                # to the character-counter JS callback below as its second argument.
                placeholder_max_token = Textbox(
                    visible=False,
                    interactive=False,
                    value=MAX_INPUT_CHARACTERS,
                )
                input_ = Textbox(
                    lines=11,
                    label="Posa aquí el teu escrit en català.",
                    placeholder="e.g. El mercat del barri és fantàstic hi pots trobar.",
                )
                # Character counter: warning text on the left, "n / max" on the right.
                # Both spans are updated client-side by the js callbacks below.
                with gr.Row(variant="panel", equal_height=True):
                    gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                    gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span> / {MAX_INPUT_CHARACTERS}</span>""")
                with gr.Row(variant="panel"):
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        # The initial values below must stay in sync with clear().
                        max_new_tokens = Slider(
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens",
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=10,
                            step=0.1,
                            value=1.2,
                            label="Repetition penalty",
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k",
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.95,
                            label="Top p",
                        )
                        do_sample = Checkbox(
                            value=True,
                            label="Do sample",
                        )
                        temperature = Slider(
                            # Strictly positive minimum: text-generation backends such
                            # as TGI reject temperature <= 0, and that error would be
                            # shown to the user as "endpoint not available".
                            minimum=0.01,
                            maximum=1,
                            step=0.01,
                            value=0.5,
                            label="Temperature",
                        )
            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=11,
                    label="El mestre diu...",
                    interactive=False,
                    show_copy_button=True,
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button(
                        "Clear",
                    )
                    # Disabled until change_interactive sees valid input.
                    submit_btn = Button(
                        "Submit",
                        variant="primary",
                        interactive=False,
                    )
        # Inputs shared by the examples and the Submit handler; the order must
        # match submit_input's parameters.
        model_inputs = [input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
        with gr.Row():
            with gr.Column(scale=0.5):
                gr.Examples(
                    label="Short prompts:",
                    examples=[
                        ["""La capital de Suècia"""],
                    ],
                    inputs=model_inputs,
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Zero-shot prompts",
                    examples=[
                        ["Tradueix del Castellà al Català la següent frase: \"Eso es pan comido.\" \nTraducció:"],
                    ],
                    inputs=model_inputs,
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Few-Shot prompts:",
                    examples=[
                        ["""Oració: Els sons melòdics produeixen una sensació de calma i benestar en l'individu. \nParàfrasi: La música és molt relaxant i reconfortant.\n----\nOració: L'animal domèstic mostra una gran alegria i satisfacció. \nParàfrasi: El gos és molt feliç. \n----\nOració: El vehicle es va trencar i vaig haver de contactar amb el servei de remolc perquè el transportés. \nParàfrasi: El cotxe es va trencar i vaig haver de trucar la grua. \n----\nOració: El professor va explicar els conceptes de manera clara i concisa. \nParàfrasi:"""],
                    ],
                    inputs=model_inputs,
                    outputs=output,
                    fn=submit_input,
                )
        # Server side: enable/disable the Clear/Submit buttons from the input text.
        input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn], api_name=False)
        # Client side only (fn=None): show the "Max length" warning when the text is too long.
        input_.change(fn=None, inputs=[input_], api_name=False, js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """)
        # Client side only: update the "n / max" counter and turn it red past the limit.
        input_.change(fn=None, inputs=[input_, placeholder_max_token], api_name=False, js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""")
        clear_btn.click(fn=clear, inputs=[], outputs=[input_, output, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], queue=False, api_name=False)
        submit_btn.click(fn=submit_input, inputs=model_inputs, outputs=[output], api_name="get-results")
    demo.launch(show_api=True)
# Script entry point: build the Gradio UI and launch the server.
if __name__ == "__main__":
    gradio_app()