Files changed (1) hide show
  1. app.py +14 -123
app.py CHANGED
@@ -1,128 +1,19 @@
1
- import os
2
- from collections.abc import Iterator
3
- from threading import Thread
4
-
5
  import gradio as gr
6
- import spaces
7
- import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
-
10
- DESCRIPTION = """\
11
- # Llama 3.2 3B Instruct
12
- Llama 3.2 3B is Meta's latest iteration of open LLMs.
13
- This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
14
- For more details, please check [our post](https://huggingface.co/blog/llama32).
15
- """
16
-
17
- MAX_MAX_NEW_TOKENS = 2048
18
- DEFAULT_MAX_NEW_TOKENS = 1024
19
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
20
-
21
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
-
23
- model_id = "meta-llama/Llama-3.2-3B-Instruct"
24
- tokenizer = AutoTokenizer.from_pretrained(model_id)
25
- model = AutoModelForCausalLM.from_pretrained(
26
- model_id,
27
- device_map="auto",
28
- torch_dtype=torch.bfloat16,
29
- )
30
- model.eval()
31
-
32
-
33
- @spaces.GPU(duration=90)
34
- def generate(
35
- message: str,
36
- chat_history: list[dict],
37
- max_new_tokens: int = 1024,
38
- temperature: float = 0.6,
39
- top_p: float = 0.9,
40
- top_k: int = 50,
41
- repetition_penalty: float = 1.2,
42
- ) -> Iterator[str]:
43
- conversation = [*chat_history, {"role": "user", "content": message}]
44
-
45
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
46
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
47
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
48
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
49
- input_ids = input_ids.to(model.device)
50
-
51
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
52
- generate_kwargs = dict(
53
- {"input_ids": input_ids},
54
- streamer=streamer,
55
- max_new_tokens=max_new_tokens,
56
- do_sample=True,
57
- top_p=top_p,
58
- top_k=top_k,
59
- temperature=temperature,
60
- num_beams=1,
61
- repetition_penalty=repetition_penalty,
62
- )
63
- t = Thread(target=model.generate, kwargs=generate_kwargs)
64
- t.start()
65
-
66
- outputs = []
67
- for text in streamer:
68
- outputs.append(text)
69
- yield "".join(outputs)
70
 
 
 
71
 
72
- demo = gr.ChatInterface(
73
- fn=generate,
74
- additional_inputs=[
75
- gr.Slider(
76
- label="Max new tokens",
77
- minimum=1,
78
- maximum=MAX_MAX_NEW_TOKENS,
79
- step=1,
80
- value=DEFAULT_MAX_NEW_TOKENS,
81
- ),
82
- gr.Slider(
83
- label="Temperature",
84
- minimum=0.1,
85
- maximum=4.0,
86
- step=0.1,
87
- value=0.6,
88
- ),
89
- gr.Slider(
90
- label="Top-p (nucleus sampling)",
91
- minimum=0.05,
92
- maximum=1.0,
93
- step=0.05,
94
- value=0.9,
95
- ),
96
- gr.Slider(
97
- label="Top-k",
98
- minimum=1,
99
- maximum=1000,
100
- step=1,
101
- value=50,
102
- ),
103
- gr.Slider(
104
- label="Repetition penalty",
105
- minimum=1.0,
106
- maximum=2.0,
107
- step=0.05,
108
- value=1.2,
109
- ),
110
- ],
111
- stop_btn=None,
112
- examples=[
113
- ["Hello there! How are you doing?"],
114
- ["Can you explain briefly to me what is the Python programming language?"],
115
- ["Explain the plot of Cinderella in a sentence."],
116
- ["How many hours does it take a man to eat a Helicopter?"],
117
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
118
- ],
119
- cache_examples=False,
120
- type="messages",
121
- description=DESCRIPTION,
122
- css_paths="style.css",
123
- fill_height=True,
124
- )
125
 
 
 
 
 
 
 
126
 
127
- if __name__ == "__main__":
128
- demo.queue(max_size=20).launch()
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
# Load a small instruction-tuned model that will answer the user's questions.
MODEL_ID = "microsoft/phi-2"
chatbot = pipeline("text-generation", model=MODEL_ID)
6
 
7
def chat(user_input):
    """Generate a reply to *user_input* using the local text-generation model.

    Args:
        user_input: The user's message as plain text (expected in English).

    Returns:
        The model's answer as plain text, with the echoed prompt removed.
    """
    # Use max_new_tokens, not max_length: max_length counts the prompt tokens
    # too, so a long question could leave no room for an answer at all.
    outputs = chatbot(
        user_input,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
    )
    text = outputs[0]["generated_text"]
    # Text-generation pipelines return prompt + completion; strip the prompt
    # so the user does not see their own question echoed back.
    return text.removeprefix(user_input).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
# Build a minimal web UI: one text box in, one text box out.
interface = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Tvůj AI učitel angličtiny",
    description="Napiš otázku anglicky a AI ti odpoví.",
)

# Launch only when run as a script, so importing this module
# (e.g. from a test) does not start the web server as a side effect.
if __name__ == "__main__":
    interface.launch()