Utiric committed (verified)
Commit 4bdd945 · Parent(s): 363f32b

Update app.py

Files changed (1):
  1. app.py +43 -31
app.py CHANGED
@@ -7,8 +7,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 MODEL_NAME = "daniel-dona/gemma-3-270m-it"
 
 # CPU optimizations
-torch.set_num_threads(torch.get_num_threads())  # use all cores
-torch.set_float32_matmul_precision("high")  # raise matmul speed
+torch.set_num_threads(torch.get_num_threads())  # Use all cores
+torch.set_float32_matmul_precision("high")  # Raise matmul speed
 
 # Load the model/tokenizer globally
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
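
Note that torch.set_num_threads(torch.get_num_threads()) re-applies the current value, so it keeps PyTorch's default thread count rather than forcing anything. A minimal sketch of a more explicit setup (the environment-variable fallback is an assumption, not something this commit does):

import os

import torch

# Illustrative only: choose the thread count explicitly, honoring OMP_NUM_THREADS if set.
num_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 1))
torch.set_num_threads(num_threads)          # intra-op CPU parallelism
torch.set_float32_matmul_precision("high")  # allow faster fp32 matmul paths where available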
@@ -19,36 +19,34 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# Per-user KV cache
-sessions = {}  # {user_id: past_key_values}
-
 def build_prompt(message, history, system_message, max_ctx_tokens=1024):
     msgs = [{"role": "system", "content": system_message}]
     for u, a in history:
-        if u: msgs.append({"role": "user", "content": u})
-        if a: msgs.append({"role": "assistant", "content": a})
+        if u:
+            msgs.append({"role": "user", "content": u})
+        if a:
+            msgs.append({"role": "assistant", "content": a})
     msgs.append({"role": "user", "content": message})
 
     # Trim to the token budget
     while True:
-        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+        text = tokenizer.apply_chat_template(
+            msgs, tokenize=False, add_generation_prompt=True
+        )
         if len(tokenizer(text, add_special_tokens=False).input_ids) <= max_ctx_tokens:
             return text
-        # Drop the oldest user+assistant pair
+        # Drop the oldest user+assistant pair (keep the system message)
         for i in range(1, len(msgs)):
             if msgs[i]["role"] != "system":
                 del msgs[i:i+2]
                 break
 
 def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
-    user_id = "default"
-    past = sessions.get(user_id)
-
-    if past is None:
-        text = build_prompt(message, history, system_message)
-        inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    else:
-        inputs = tokenizer([message], return_tensors="pt").to(model.device)
+    # We pass the full prompt on the first message; later turns keep the same flow to keep this example simple.
+    # (past_key_values is not exposed after generate() when using HF's TextIteratorStreamer,
+    # so this version does not carry the KV cache across turns; it relies on streaming + context trimming for speed.)
+    text = build_prompt(message, history, system_message)
+    inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
     do_sample = temperature > 0
     gen_kwargs = dict(
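
build_prompt re-renders the chat template and drops the oldest non-system pair until the prompt fits the token budget. A quick illustrative call (the history contents are made up):

history = [
    ("hi", "Hello! How can I help?"),
    ("tell me a story", "Once upon a time..."),
]
# With a tight budget the oldest user/assistant pair is dropped first; the
# system message at index 0 is protected by the role check inside the loop.
prompt = build_prompt("and then?", history, "You are a friendly Chatbot.", max_ctx_tokens=64)
print(prompt)

One edge case: once only the system message is left, a prompt that still exceeds max_ctx_tokens would make the while True loop spin forever, and the trimming can also delete the current user message on the way there; the 1024-token default leaves enough headroom that this should not happen in normal use.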
@@ -56,30 +54,43 @@ def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
         do_sample=do_sample,
         top_p=top_p,
         temperature=temperature if do_sample else None,
-        use_cache=True,
-        past_key_values=past
+        use_cache=True,  # keep the KV cache active during decoding
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
     )
 
-    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+    # skip_prompt=True keeps the prompt out of the streamed output (some older Transformers versions may not accept it)
+    try:
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_special_tokens=True, skip_prompt=True
+        )
+    except TypeError:
+        # Backwards compatibility: still works without skip_prompt, but prompt fragments may appear in the first chunks
+        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
     thread = threading.Thread(
         target=model.generate,
         kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer}
     )
 
-    input_len = inputs["input_ids"].shape[1]
     partial_text = ""
-    token_count = 0
-    start_time = time.time()
-
+    start_time = None  # set when the first token arrives
     with torch.inference_mode():
         thread.start()
-        for new_text in streamer:
-            partial_text += new_text
-            yield partial_text
+        for chunk in streamer:
+            if start_time is None:
+                start_time = time.time()
+            partial_text += chunk
+            yield partial_text  # append-style streaming: previous text plus the new chunk
         thread.join()
 
     end_time = time.time()
-    tps = token_count / (end_time - start_time) if (end_time - start_time) > 0 else 0
+
+    # Compute the generated-token count from the final text
+    gen_token_count = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
+    duration = max(1e-6, end_time - start_time) if start_time is not None else 0.0
+    tps = (gen_token_count / duration) if duration > 0 else 0.0
+
     yield partial_text + f"\n\n⚡ **Speed:** {tps:.2f} tokens/sec"
 
 demo = gr.ChatInterface(
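
The threaded generate()-plus-streamer pattern above is the standard Transformers recipe: generate() blocks, so it runs on a worker thread and pushes decoded text into the streamer's queue, which the main thread iterates. A standalone sketch of the same pattern (the prompt string is made up for illustration):

import threading

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("daniel-dona/gemma-3-270m-it")
lm = AutoModelForCausalLM.from_pretrained("daniel-dona/gemma-3-270m-it")

inputs = tok(["Write one friendly sentence."], return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_special_tokens=True, skip_prompt=True)

# generate() runs in the worker thread; iterating the streamer blocks until
# new decoded text arrives, then yields it chunk by chunk.
thread = threading.Thread(
    target=lm.generate,
    kwargs={**inputs, "max_new_tokens": 32, "streamer": streamer},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

Two measurement caveats, both visible in the hunk: start_time is set on the first chunk, so the reported tokens/sec excludes prompt prefill; and the token count comes from re-tokenizing the final text, which can differ slightly from the number of tokens actually generated.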
@@ -87,10 +98,11 @@ demo = gr.ChatInterface(
     additional_inputs=[
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.0, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
     ],
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    # queue() works well for reducing Gradio stream-buffer errors
+    demo.queue().launch()
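
Lowering the Temperature slider's minimum to 0.0 lines up with do_sample = temperature > 0 in respond_stream: a value of 0 now selects greedy decoding instead of passing an invalid temperature. On the launch side, queue() can also bound pending work; a hedged sketch (the parameter values are assumptions, and queue() options vary across Gradio versions):

if __name__ == "__main__":
    # max_size caps queued requests so slow CPU generation cannot pile up without bound.
    demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)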