import spaces
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
import torch
from transformers.utils import logging
from example_queries import small_query, long_query

logging.set_verbosity_info()
logger = logging.get_logger("transformers")
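
# Baseline: the stock t5-small checkpoint, shown side by side with the fine-tuned model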
model_name = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
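
# Fine-tuned checkpoint: t5-small further trained for text-to-SQL generation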
ft_model_name = "daljeetsingh/sql_ft_t5small_kag"
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
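
# Keep both models resident on the GPU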
original_model.to('cuda')
ft_model.to('cuda')
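
# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call to this function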
@spaces.GPU
def translate_text(text):
    prompt = f"{text}"
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = inputs.to('cuda')

    try:
        output = tokenizer.decode(
            original_model.generate(
                inputs["input_ids"],
                max_new_tokens=200,
            )[0],
            skip_special_tokens=True
        )
        ft_output = tokenizer.decode(
            ft_model.generate(
                inputs["input_ids"],
                max_new_tokens=200,
            )[0],
            skip_special_tokens=True
        )
        return [output, ft_output]
    except Exception as e:
        # The click handler feeds two output boxes, so return the error message for both
        error_msg = f"Error: {str(e)}"
        return [error_msg, error_msg]
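
# Demo layout: prompt and Generate button on the left, the two model outputs on the right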
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                value=small_query,
                lines=8,
                placeholder="Enter prompt...",
                label="Prompt"
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            orig_output = gr.Textbox(label="OriginalModel", lines=2)
            ft_output = gr.Textbox(label="FTModel", lines=8)
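
    # Wire the button to translate_text; api_name=False keeps this event out of the API docs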
    submit_btn.click(
        translate_text, inputs=[prompt], outputs=[orig_output, ft_output], api_name=False
    )
    examples = gr.Examples(
        examples=[
            [small_query],
            [long_query],
        ],
        inputs=[prompt],
    )
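
# show_api=False hides the API page; share=True and debug=True matter mainly for local runs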
demo.launch(show_api=False, share=True, debug=True)