WesanCZE commited on
Commit
7d9ab9e
·
verified ·
1 Parent(s): f2f7b4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+
5
+ # Název modelu na Hugging Face
6
+ MODEL_NAME = "mistralai/Mistral-7B-v0.1"
7
+
8
+ # Inicializace tokenizeru
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+
11
+ # Načtení modelu (s kvantizací pro snížení paměťových nároků)
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ MODEL_NAME,
14
+ torch_dtype=torch.float16,
15
+ device_map="auto",
16
+ load_in_8bit=True, # 8-bitová kvantizace pro úsporu paměti
17
+ )
18
+
19
+ # Vytvoření pipeline pro generování textu
20
+ generator = pipeline(
21
+ "text-generation",
22
+ model=model,
23
+ tokenizer=tokenizer,
24
+ )
25
+
26
+ def generate_response(prompt, max_length, temperature, top_p, top_k):
27
+ """
28
+ Generuje odpověď na základě zadaného promptu a parametrů.
29
+
30
+ Parametry:
31
+ - prompt: vstupní text
32
+ - max_length: maximální délka generovaného textu
33
+ - temperature: teplota pro sampling (vyšší = kreativnější)
34
+ - top_p: parametr nucleus samplingu
35
+ - top_k: kolik nejvyšších pravděpodobností uvažovat při samplingu
36
+ """
37
+ # Generování odpovědi
38
+ generation_kwargs = {
39
+ "max_new_tokens": max_length,
40
+ "temperature": temperature,
41
+ "top_p": top_p,
42
+ "top_k": top_k,
43
+ "do_sample": temperature > 0,
44
+ "pad_token_id": tokenizer.eos_token_id,
45
+ }
46
+
47
+ outputs = generator(prompt, **generation_kwargs)
48
+ generated_text = outputs[0]["generated_text"]
49
+
50
+ # Odstranění vstupního promptu z výstupu pro zobrazení pouze nového textu
51
+ if generated_text.startswith(prompt):
52
+ generated_text = generated_text[len(prompt):]
53
+
54
+ return generated_text
55
+
56
+ # Definice Gradio rozhraní
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown("# Mistral 7B Demo")
59
+ gr.Markdown("Zadejte text a model vygeneruje pokračování.")
60
+
61
+ with gr.Row():
62
+ with gr.Column():
63
+ prompt = gr.Textbox(
64
+ label="Vstupní text",
65
+ placeholder="Zadejte počáteční text...",
66
+ lines=5
67
+ )
68
+
69
+ with gr.Row():
70
+ with gr.Column():
71
+ max_length = gr.Slider(
72
+ minimum=10,
73
+ maximum=1024,
74
+ value=256,
75
+ step=1,
76
+ label="Maximální délka (tokeny)"
77
+ )
78
+ temperature = gr.Slider(
79
+ minimum=0.0,
80
+ maximum=2.0,
81
+ value=0.7,
82
+ step=0.01,
83
+ label="Teplota"
84
+ )
85
+ with gr.Column():
86
+ top_p = gr.Slider(
87
+ minimum=0.0,
88
+ maximum=1.0,
89
+ value=0.9,
90
+ step=0.01,
91
+ label="Top-p"
92
+ )
93
+ top_k = gr.Slider(
94
+ minimum=1,
95
+ maximum=100,
96
+ value=50,
97
+ step=1,
98
+ label="Top-k"
99
+ )
100
+
101
+ submit_btn = gr.Button("Generovat")
102
+
103
+ with gr.Column():
104
+ output = gr.Textbox(
105
+ label="Vygenerovaný text",
106
+ lines=10
107
+ )
108
+
109
+ # Propojení tlačítka s funkcí
110
+ submit_btn.click(
111
+ fn=generate_response,
112
+ inputs=[prompt, max_length, temperature, top_p, top_k],
113
+ outputs=output
114
+ )
115
+
116
+ # Přidat příklady
117
+ gr.Examples(
118
+ examples=[
119
+ ["Vítejte v Praze, hlavním městě České republiky.", 256, 0.7, 0.9, 50],
120
+ ["Recept na tradiční český guláš:", 256, 0.7, 0.9, 50],
121
+ ["Otázka: Jak funguje transformerový model?\nOdpověď:", 512, 0.7, 0.9, 50],
122
+ ],
123
+ inputs=[prompt, max_length, temperature, top_p, top_k],
124
+ )
125
+
126
+ # Spuštění Gradio aplikace
127
+ demo.launch()