Hej / app.py
WesanCZE's picture
Create app.py
7d9ab9e verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Název modelu na Hugging Face
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
# Inicializace tokenizeru
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Načtení modelu (s kvantizací pro snížení paměťových nároků)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
device_map="auto",
load_in_8bit=True, # 8-bitová kvantizace pro úsporu paměti
)
# Vytvoření pipeline pro generování textu
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)
def generate_response(prompt, max_length, temperature, top_p, top_k):
"""
Generuje odpověď na základě zadaného promptu a parametrů.
Parametry:
- prompt: vstupní text
- max_length: maximální délka generovaného textu
- temperature: teplota pro sampling (vyšší = kreativnější)
- top_p: parametr nucleus samplingu
- top_k: kolik nejvyšších pravděpodobností uvažovat při samplingu
"""
# Generování odpovědi
generation_kwargs = {
"max_new_tokens": max_length,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"do_sample": temperature > 0,
"pad_token_id": tokenizer.eos_token_id,
}
outputs = generator(prompt, **generation_kwargs)
generated_text = outputs[0]["generated_text"]
# Odstranění vstupního promptu z výstupu pro zobrazení pouze nového textu
if generated_text.startswith(prompt):
generated_text = generated_text[len(prompt):]
return generated_text
# Definice Gradio rozhraní
with gr.Blocks() as demo:
gr.Markdown("# Mistral 7B Demo")
gr.Markdown("Zadejte text a model vygeneruje pokračování.")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Vstupní text",
placeholder="Zadejte počáteční text...",
lines=5
)
with gr.Row():
with gr.Column():
max_length = gr.Slider(
minimum=10,
maximum=1024,
value=256,
step=1,
label="Maximální délka (tokeny)"
)
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.01,
label="Teplota"
)
with gr.Column():
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.01,
label="Top-p"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
value=50,
step=1,
label="Top-k"
)
submit_btn = gr.Button("Generovat")
with gr.Column():
output = gr.Textbox(
label="Vygenerovaný text",
lines=10
)
# Propojení tlačítka s funkcí
submit_btn.click(
fn=generate_response,
inputs=[prompt, max_length, temperature, top_p, top_k],
outputs=output
)
# Přidat příklady
gr.Examples(
examples=[
["Vítejte v Praze, hlavním městě České republiky.", 256, 0.7, 0.9, 50],
["Recept na tradiční český guláš:", 256, 0.7, 0.9, 50],
["Otázka: Jak funguje transformerový model?\nOdpověď:", 512, 0.7, 0.9, 50],
],
inputs=[prompt, max_length, temperature, top_p, top_k],
)
# Spuštění Gradio aplikace
demo.launch()