|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
|
|
|
|
|
|
|
# Hugging Face model id. NOTE(review): this repo is gated — loading it
# presumably requires an authenticated HF token with granted access; verify
# in the deployment environment.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the model with 8-bit weight quantization (needs the bitsandbytes
# package at runtime). NOTE(review): `load_in_8bit=True` as a direct kwarg is
# deprecated in recent transformers in favor of
# `quantization_config=BitsAndBytesConfig(load_in_8bit=True)` — confirm the
# installed transformers version still accepts it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # dtype for the non-quantized modules
    device_map="auto",  # let accelerate place layers on available devices
    load_in_8bit=True,
)

# Text-generation pipeline wrapping the loaded model + tokenizer; used by
# generate_response() below.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
|
|
|
|
|
def generate_response(prompt, max_length, temperature, top_p, top_k):
    """Generate a continuation of *prompt* with the loaded model.

    Parameters:
    - prompt: input text to continue
    - max_length: maximum number of newly generated tokens (max_new_tokens)
    - temperature: sampling temperature (higher = more creative);
      0 disables sampling and falls back to greedy decoding
    - top_p: nucleus-sampling probability mass
    - top_k: number of highest-probability tokens considered per step

    Returns the generated text with the echoed prompt stripped off.
    """
    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_length,
        "do_sample": do_sample,
        # Mistral defines no pad token; reuse EOS to silence the pipeline's
        # padding warning.
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Only pass the sampling knobs when sampling is actually enabled:
    # transformers rejects temperature == 0 and warns about unused
    # top_p/top_k when do_sample=False (the UI slider allows temperature 0.0).
    if do_sample:
        generation_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    outputs = generator(prompt, **generation_kwargs)
    generated_text = outputs[0]["generated_text"]

    # The pipeline returns prompt + continuation; return only the continuation.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]

    return generated_text
|
|
|
|
|
|
|
|
# --- Gradio UI --------------------------------------------------------------
# Layout: left column holds the prompt box, the generation-parameter sliders
# and the submit button; right column shows the model output. UI strings are
# intentionally Czech (user-facing) and must not be altered.
with gr.Blocks() as demo:
    gr.Markdown("# Mistral 7B Demo")
    gr.Markdown("Zadejte text a model vygeneruje pokračování.")

    with gr.Row():
        with gr.Column():
            # Free-form prompt the model will continue.
            prompt = gr.Textbox(
                label="Vstupní text",
                placeholder="Zadejte počáteční text...",
                lines=5
            )

            with gr.Row():
                with gr.Column():
                    # Upper bound on newly generated tokens
                    # (mapped to max_new_tokens in generate_response).
                    max_length = gr.Slider(
                        minimum=10,
                        maximum=1024,
                        value=256,
                        step=1,
                        label="Maximální délka (tokeny)"
                    )
                    # Sampling temperature; 0 means greedy decoding.
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.7,
                        step=0.01,
                        label="Teplota"
                    )
                with gr.Column():
                    # Nucleus-sampling probability cutoff.
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.01,
                        label="Top-p"
                    )
                    # Number of highest-probability tokens to sample from.
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-k"
                    )

            submit_btn = gr.Button("Generovat")

        with gr.Column():
            # Read-only box displaying the generated continuation.
            output = gr.Textbox(
                label="Vygenerovaný text",
                lines=10
            )

    # Wire the button to the generation function defined above.
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt, max_length, temperature, top_p, top_k],
        outputs=output
    )

    # Clickable example rows that pre-fill all five inputs.
    gr.Examples(
        examples=[
            ["Vítejte v Praze, hlavním městě České republiky.", 256, 0.7, 0.9, 50],
            ["Recept na tradiční český guláš:", 256, 0.7, 0.9, 50],
            ["Otázka: Jak funguje transformerový model?\nOdpověď:", 512, 0.7, 0.9, 50],
        ],
        inputs=[prompt, max_length, temperature, top_p, top_k],
    )

# Start the Gradio server (blocking call).
demo.launch()