Samuel4677 commited on
Commit
e995f26
·
verified ·
1 Parent(s): bc2c277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -27
app.py CHANGED
@@ -1,45 +1,71 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
 
 
 
 
 
 
 
4
 
5
- # Wczytanie modelu
6
- model_name = "radlab/polish-gpt2-medium-v2" # lepszy model GPT2
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- # Funkcja odpowiedzi bota
11
- def chatbot(user_input, history=""):
 
 
 
12
  history += f"Użytkownik: {user_input}\nAI:"
13
  input_ids = tokenizer.encode(history, return_tensors="pt", truncation=True, max_length=1024)
14
 
15
- output_ids = model.generate(
 
 
 
16
  input_ids,
17
- max_length=input_ids.shape[1] + 100,
18
  pad_token_id=tokenizer.eos_token_id,
19
  do_sample=True,
20
- top_k=50,
21
- top_p=0.92,
22
- temperature=0.6,
23
- repetition_penalty=1.1
24
  )
25
- output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
26
- bot_reply = output_text[len(history):].split("Użytkownik:")[0].strip()
27
- new_history = history + f" {bot_reply}\n"
28
- return bot_reply, new_history
29
-
30
- # Interfejs
31
- with gr.Blocks() as demo:
32
- gr.Markdown("# 🤖 Polski Chatbot AI")
33
- chatbot_output = gr.Textbox(label="Odpowiedź AI")
 
 
34
  user_input = gr.Textbox(label="Wpisz wiadomość")
35
- state = gr.State("")
36
 
37
- def respond(user_input, state):
38
- reply, updated = chatbot(user_input, state)
39
- return reply, updated
 
 
40
 
41
  send_btn = gr.Button("Wyślij")
42
- send_btn.click(respond, inputs=[user_input, state], outputs=[chatbot_output, state])
43
 
44
- # Uruchom serwer
45
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # Optymalizacja Torch
6
+ torch.set_float32_matmul_precision('high')
7
+
8
+ # Ustawienia modelu
9
+ model_name = "mrm8488/distilgpt2-finetuned-text-generation" # Można zmienić na polski model, jeśli jest zoptymalizowany
10
 
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name)
13
 
14
+ # Maksymalna długość historii w tokenach
15
+ MAX_HISTORY = 800
16
+
17
+ # Funkcja czatu
18
+ def chatbot_response(user_input, history, top_k, top_p, temperature):
19
  history += f"Użytkownik: {user_input}\nAI:"
20
  input_ids = tokenizer.encode(history, return_tensors="pt", truncation=True, max_length=1024)
21
 
22
+ if input_ids.shape[1] > MAX_HISTORY:
23
+ input_ids = input_ids[:, -MAX_HISTORY:]
24
+
25
+ output = model.generate(
26
  input_ids,
27
+ max_length=input_ids.shape[1] + 80,
28
  pad_token_id=tokenizer.eos_token_id,
29
  do_sample=True,
30
+ top_k=int(top_k),
31
+ top_p=top_p,
32
+ temperature=temperature
 
33
  )
34
+
35
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
36
+ model_reply = output_text[len(history):].split("Użytkownik:")[0].strip()
37
+ history += f" {model_reply}\n"
38
+ return history, history
39
+
40
+ # Gradio interfejs
41
+ with gr.Blocks(title="Polski Chatbot AI") as demo:
42
+ gr.Markdown("# 🤖 Polski Chatbot AI\nModel: distilgpt2-finetuned-text-generation")
43
+
44
+ chat_output = gr.Textbox(label="Historia rozmowy", lines=15, interactive=False)
45
  user_input = gr.Textbox(label="Wpisz wiadomość")
 
46
 
47
+ top_k = gr.Slider(0, 100, value=50, step=1, label="Top-k")
48
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
49
+ temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
50
+
51
+ history_state = gr.State("")
52
 
53
  send_btn = gr.Button("Wyślij")
 
54
 
55
+ send_btn.click(
56
+ chatbot_response,
57
+ inputs=[user_input, history_state, top_k, top_p, temperature],
58
+ outputs=[chat_output, history_state]
59
+ )
60
+
61
+ clear_btn = gr.Button("🧹 Wyczyść historię")
62
+ clear_btn.click(lambda: ("", ""), outputs=[chat_output, history_state])
63
+
64
+ gr.Markdown("\n## 🔄 Szybkie pytania:")
65
+ with gr.Row():
66
+ gr.Button("Jak się nazywasz?").click(fn=lambda: chatbot_response("Jak się nazywasz?", history_state.value, top_k.value, top_p.value, temperature.value), outputs=[chat_output, history_state])
67
+ gr.Button("Czym się zajmujesz?").click(fn=lambda: chatbot_response("Czym się zajmujesz?", history_state.value, top_k.value, top_p.value, temperature.value), outputs=[chat_output, history_state])
68
+
69
+ # Uruchom
70
+ if __name__ == "__main__":
71
+ demo.launch()