DINGDINGBELLS commited on
Commit
20736ba
·
verified ·
1 Parent(s): 0845cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -42
app.py CHANGED
@@ -1,24 +1,34 @@
1
  import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  from threading import Thread
4
  import gradio as gr
5
 
6
  model_path = "."
7
 
8
- # 1. Загрузка модели и токенизатора
9
- print("🍌 Загрузка BananaGPT...")
10
  tokenizer = AutoTokenizer.from_pretrained(model_path)
11
  if tokenizer.pad_token is None:
12
  tokenizer.pad_token = tokenizer.eos_token
13
 
14
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True)
15
- model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
16
- print("✅ Готово!")
 
 
 
 
17
 
18
- def predict(message, history, temperature, top_p, rep_penalty, max_tokens):
19
- # Собираем промпт
 
 
 
 
 
 
 
20
  prompt = ""
21
- for user_msg, bot_msg in history:
22
  prompt += f"Юзер: {user_msg}\nБот: {bot_msg}\n"
23
  prompt += f"Юзер: {message}\nБот:"
24
 
@@ -26,50 +36,52 @@ def predict(message, history, temperature, top_p, rep_penalty, max_tokens):
26
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
27
 
28
  generation_kwargs = dict(
29
- **inputs,
 
30
  streamer=streamer,
31
- max_new_tokens=max_tokens,
32
  do_sample=True,
33
  top_p=top_p,
 
34
  temperature=temperature,
35
  repetition_penalty=rep_penalty,
36
- eos_token_id=tokenizer.eos_token_id,
 
37
  pad_token_id=tokenizer.pad_token_id,
 
38
  )
39
 
40
- # Запуск в отдельном потоке
41
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
42
- thread.start()
 
43
 
44
- generated_text = ""
45
- for new_text in streamer:
46
- generated_text += new_text
47
- # Наш стоп-кран для ролевика внутри
48
- if "Юзер:" in generated_text:
49
- generated_text = generated_text.split("Юзер:")[0].strip()
 
 
 
 
 
50
  yield generated_text
51
- break
52
- yield generated_text
53
 
54
- # Настройка интерфейса с ползунками
55
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="slate")) as demo:
56
- gr.Markdown("# 🍌 BananaGPT Lab")
57
-
58
- with gr.Row():
59
- with gr.Column(scale=4):
60
- # Чат интерфейс
61
- chat = gr.ChatInterface(
62
- fn=predict,
63
- additional_inputs=[
64
- gr.Slider(0.1, 1.5, value=0.34, label="Температура (Хаос)"),
65
- gr.Slider(0.1, 1.0, value=0.9, label="Top-p (Плотность)"),
66
- gr.Slider(1.0, 2.0, value=1.3, label="Repetition Penalty"),
67
- gr.Slider(16, 1024, value=512, step=16, label="Макс. токенов"),
68
- ]
69
- )
70
-
71
- gr.Markdown("---")
72
- gr.Markdown("ℹ️ *Если бот начал ролить за тебя, просто уменьши температуру или нажми Очистить.*")
73
 
74
  if __name__ == "__main__":
75
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
3
  from threading import Thread
4
  import gradio as gr
5
 
6
  model_path = "."
7
 
8
+ # 1. Загрузка модели (без квантования для стабильной скорости на vCPU)
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_path)
10
  if tokenizer.pad_token is None:
11
  tokenizer.pad_token = tokenizer.eos_token
12
 
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ model_path,
15
+ low_cpu_mem_usage=True,
16
+ use_cache=True,
17
+ device_map="cpu"
18
+ )
19
+ model.eval()
20
 
21
+ # 2. Фильтр-стоппер на "Юзер" (чтобы не роллила за тебя)
22
+ class StopOnUser(StoppingCriteria):
23
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
24
+ stop_words = ["Юзер", "User"]
25
+ decoded_tail = tokenizer.decode(input_ids[0][-5:]) # Проверяем последние 5 токенов
26
+ return any(sw in decoded_tail for sw in stop_words)
27
+
28
+ def predict(message, history, temperature, top_p, top_k, rep_penalty, no_repeat_ngram):
29
+ # Формируем контекст (последние 4 пары сообщений)
30
  prompt = ""
31
+ for user_msg, bot_msg in history[-4:]:
32
  prompt += f"Юзер: {user_msg}\nБот: {bot_msg}\n"
33
  prompt += f"Юзер: {message}\nБот:"
34
 
 
36
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
 
38
  generation_kwargs = dict(
39
+ input_ids=inputs["input_ids"],
40
+ attention_mask=inputs["attention_mask"],
41
  streamer=streamer,
42
+ max_new_tokens=512,
43
  do_sample=True,
44
  top_p=top_p,
45
+ top_k=int(top_k), # Твой фильтр на 70
46
  temperature=temperature,
47
  repetition_penalty=rep_penalty,
48
+ no_repeat_ngram_size=int(no_repeat_ngram), # Против шизы и циклов
49
+ stopping_criteria=StoppingCriteriaList([StopOnUser()]),
50
  pad_token_id=tokenizer.pad_token_id,
51
+ eos_token_id=tokenizer.eos_token_id,
52
  )
53
 
54
+ # Запуск генерации в потоке
55
+ with torch.inference_mode():
56
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
57
+ thread.start()
58
 
59
+ generated_text = ""
60
+ for new_text in streamer:
61
+ generated_text += new_text
62
+
63
+ # Если проскочил "Юзер", обрезаем и выходим
64
+ if "Юзер:" in generated_text or "User:" in generated_text:
65
+ for stop in ["Юзер:", "User:"]:
66
+ generated_text = generated_text.split(stop)[0]
67
+ yield generated_text.strip()
68
+ break
69
+
70
  yield generated_text
 
 
71
 
72
+ # 3. Интерфейс
73
+ with gr.Blocks() as demo:
74
+ gr.Markdown("## 🍌 BananaGPT: Режим Анти-Шиза")
75
+ chat = gr.ChatInterface(
76
+ fn=predict,
77
+ additional_inputs=[
78
+ gr.Slider(0.1, 1.0, value=0.34, label="Temperature"),
79
+ gr.Slider(0.1, 1.0, value=0.9, label="Top-P"),
80
+ gr.Slider(1, 100, value=70, step=1, label="Top-K (Отсечка мусора)"),
81
+ gr.Slider(1.0, 2.0, value=1.2, label="Repetition Penalty"),
82
+ gr.Slider(0, 10, value=3, step=1, label="No Repeat N-Gram (Запрет циклов)"),
83
+ ]
84
+ )
 
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
  demo.launch(server_name="0.0.0.0", server_port=7860)