Samuel4677 commited on
Commit
66d4184
·
verified ·
1 Parent(s): e995f26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -1,24 +1,20 @@
1
  import torch
2
- import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
- # Optymalizacja Torch
6
  torch.set_float32_matmul_precision('high')
7
 
8
- # Ustawienia modelu
9
- model_name = "mrm8488/distilgpt2-finetuned-text-generation" # Można zmienić na polski model, jeśli jest zoptymalizowany
10
-
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name)
13
 
14
- # Maksymalna długość historii w tokenach
15
- MAX_HISTORY = 800
16
 
17
- # Funkcja czatu
18
  def chatbot_response(user_input, history, top_k, top_p, temperature):
19
  history += f"Użytkownik: {user_input}\nAI:"
20
  input_ids = tokenizer.encode(history, return_tensors="pt", truncation=True, max_length=1024)
21
-
22
  if input_ids.shape[1] > MAX_HISTORY:
23
  input_ids = input_ids[:, -MAX_HISTORY:]
24
 
@@ -32,14 +28,13 @@ def chatbot_response(user_input, history, top_k, top_p, temperature):
32
  temperature=temperature
33
  )
34
 
35
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
36
- model_reply = output_text[len(history):].split("Użytkownik:")[0].strip()
37
- history += f" {model_reply}\n"
38
  return history, history
39
 
40
- # Gradio interfejs
41
- with gr.Blocks(title="Polski Chatbot AI") as demo:
42
- gr.Markdown("# 🤖 Polski Chatbot AI\nModel: distilgpt2-finetuned-text-generation")
43
 
44
  chat_output = gr.Textbox(label="Historia rozmowy", lines=15, interactive=False)
45
  user_input = gr.Textbox(label="Wpisz wiadomość")
@@ -51,7 +46,6 @@ with gr.Blocks(title="Polski Chatbot AI") as demo:
51
  history_state = gr.State("")
52
 
53
  send_btn = gr.Button("Wyślij")
54
-
55
  send_btn.click(
56
  chatbot_response,
57
  inputs=[user_input, history_state, top_k, top_p, temperature],
@@ -61,11 +55,12 @@ with gr.Blocks(title="Polski Chatbot AI") as demo:
61
  clear_btn = gr.Button("🧹 Wyczyść historię")
62
  clear_btn.click(lambda: ("", ""), outputs=[chat_output, history_state])
63
 
64
- gr.Markdown("\n## 🔄 Szybkie pytania:")
65
  with gr.Row():
66
- gr.Button("Jak się nazywasz?").click(fn=lambda: chatbot_response("Jak się nazywasz?", history_state.value, top_k.value, top_p.value, temperature.value), outputs=[chat_output, history_state])
67
- gr.Button("Czym się zajmujesz?").click(fn=lambda: chatbot_response("Czym się zajmujesz?", history_state.value, top_k.value, top_p.value, temperature.value), outputs=[chat_output, history_state])
 
 
68
 
69
- # Uruchom
70
  if __name__ == "__main__":
71
  demo.launch()
 
1
  import torch
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import gradio as gr
4
 
5
+ # Optymalizacja obliczeń
6
  torch.set_float32_matmul_precision('high')
7
 
8
+ # Nowy, lżejszy model
9
+ model_name = "distilbert/distilgpt2"
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForCausalLM.from_pretrained(model_name)
12
 
13
+ MAX_HISTORY = 800 # limit tokenów w historii
 
14
 
 
15
  def chatbot_response(user_input, history, top_k, top_p, temperature):
16
  history += f"Użytkownik: {user_input}\nAI:"
17
  input_ids = tokenizer.encode(history, return_tensors="pt", truncation=True, max_length=1024)
 
18
  if input_ids.shape[1] > MAX_HISTORY:
19
  input_ids = input_ids[:, -MAX_HISTORY:]
20
 
 
28
  temperature=temperature
29
  )
30
 
31
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
32
+ reply = decoded[len(history):].split("Użytkownik:")[0].strip()
33
+ history += f" {reply}\n"
34
  return history, history
35
 
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("# 🤖 Polski Chatbot AI (DistilGPT2)")
 
38
 
39
  chat_output = gr.Textbox(label="Historia rozmowy", lines=15, interactive=False)
40
  user_input = gr.Textbox(label="Wpisz wiadomość")
 
46
  history_state = gr.State("")
47
 
48
  send_btn = gr.Button("Wyślij")
 
49
  send_btn.click(
50
  chatbot_response,
51
  inputs=[user_input, history_state, top_k, top_p, temperature],
 
55
  clear_btn = gr.Button("🧹 Wyczyść historię")
56
  clear_btn.click(lambda: ("", ""), outputs=[chat_output, history_state])
57
 
58
+ gr.Markdown("## 🔄 Szybkie pytania:")
59
  with gr.Row():
60
+ gr.Button("Jak się nazywasz?").click(
61
+ lambda _: chatbot_response("Jak się nazywasz?", "", 50, 0.9, 0.7), outputs=[chat_output, history_state])
62
+ gr.Button("Czym się zajmujesz?").click(
63
+ lambda _: chatbot_response("Czym się zajmujesz?", "", 50, 0.9, 0.7), outputs=[chat_output, history_state])
64
 
 
65
  if __name__ == "__main__":
66
  demo.launch()