| | import gradio as gr |
| | import torch |
| | import pytorch_lightning as pl |
| | from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
| | from transformers import ( |
| | MT5ForConditionalGeneration, |
| | MT5TokenizerFast, |
| | ) |
| |
|
| | model = MT5ForConditionalGeneration.from_pretrained( |
| | "minjibi/qa", |
| | return_dict=True, |
| | ) |
| | tokenizer = MT5TokenizerFast.from_pretrained( |
| | "minjibi/qa" |
| | ) |
| |
|
| | def predict(text): |
| | input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True) |
| | generated_ids = model.generate( |
| | input_ids=input_ids, |
| | num_beams=5, |
| | max_length=1000, |
| | repetition_penalty=3.0, |
| | length_penalty=1.0, |
| | early_stopping=True, |
| | top_p=50, |
| | top_k=20, |
| | num_return_sequences=3, |
| | ) |
| | |
| | preds = [ |
| | tokenizer.decode( |
| | g, |
| | skip_special_tokens=True, |
| | clean_up_tokenization_spaces=True, |
| | ) |
| | for g in generated_ids |
| | ] |
| | |
| | output = [text.replace('A', 'Answer') for text in preds] |
| | |
| | final_str = '\n'.join([f"{i+1}. Question: {s.split('Answer')[0].strip()}\n Answer{s.split('Answer')[1].strip()}" for i, s in enumerate(output)]) |
| | |
| | return final_str |
| | examples = [ |
| | ["ไพธอนเป็นภาษาการเขียนโปรแกรมที่มีการตีความระดับสูง ภาษาคอมพิวเตอร์นี้สร้างโดย Guido van Rossum และเปิดตัวครั้งแรกในปี 1991"], |
| | ["แมว ชื่อวิทยาศาสตร์ (Felis catus) เป็นสปีชีส์สัตว์เลี้ยงของสัตว์เลี้ยงลูกด้วยนมกินเนื้อขนาดเล็ก โดยเป็นแมวสปีชีส์เดียวในวงศ์ Felidae ที่ถูกปรับเป็นสัตว์เลี้ยง และมักเรียกเป็น แมวบ้าน เพื่อแยกมันจากสมาชิกที่อยู่ในป่า"], |
| | ] |
| |
|
| |
|
| | iface = gr.Interface(fn=predict, inputs="text", outputs="text", examples=examples) |
| | iface.launch() |