Utiric commited on
Commit
692a239
·
verified ·
1 Parent(s): 92ee0f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -18
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import time
 
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  MODEL_NAME = "daniel-dona/gemma-3-270m-it"
7
 
@@ -39,16 +40,14 @@ def build_prompt(message, history, system_message, max_ctx_tokens=1024):
39
  del msgs[i:i+2]
40
  break
41
 
42
- def respond(message, history, system_message, max_tokens, temperature, top_p):
43
  user_id = "default" # API bağlarsan burada kullanıcı ID'si ile değiştir
44
  past = sessions.get(user_id)
45
 
46
  if past is None:
47
- # İlk mesaj → tüm prompt
48
  text = build_prompt(message, history, system_message)
49
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
50
  else:
51
- # Sadece yeni mesajı encode et
52
  inputs = tokenizer([message], return_tensors="pt").to(model.device)
53
 
54
  do_sample = temperature > 0
@@ -61,28 +60,35 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
61
  past_key_values=past
62
  )
63
 
 
 
 
 
 
 
64
  start_time = time.time()
 
 
65
  with torch.inference_mode():
66
- outputs = model.generate(**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None},
67
- return_dict_in_generate=True, output_scores=False)
 
 
 
 
68
  end_time = time.time()
 
69
 
70
  # KV cache güncelle
71
- sessions[user_id] = outputs.past_key_values
72
-
73
- # Yanıtı decode et
74
- new_tokens = outputs.sequences[0][inputs["input_ids"].shape[1]:]
75
- content = tokenizer.decode(new_tokens, skip_special_tokens=True).strip("\n")
76
-
77
- # T/S hesapla
78
- token_count = len(new_tokens)
79
- elapsed = end_time - start_time
80
- tps = token_count / elapsed if elapsed > 0 else 0
81
 
82
- return f"{content}\n\n⚡ **Hız:** {tps:.2f} token/sn"
83
 
84
  demo = gr.ChatInterface(
85
- respond,
86
  additional_inputs=[
87
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
88
  gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
 
1
  import time
2
+ import threading
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
 
7
  MODEL_NAME = "daniel-dona/gemma-3-270m-it"
8
 
 
40
  del msgs[i:i+2]
41
  break
42
 
43
+ def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
44
  user_id = "default" # API bağlarsan burada kullanıcı ID'si ile değiştir
45
  past = sessions.get(user_id)
46
 
47
  if past is None:
 
48
  text = build_prompt(message, history, system_message)
49
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
50
  else:
 
51
  inputs = tokenizer([message], return_tensors="pt").to(model.device)
52
 
53
  do_sample = temperature > 0
 
60
  past_key_values=past
61
  )
62
 
63
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
64
+ thread = threading.Thread(
65
+ target=model.generate,
66
+ kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer}
67
+ )
68
+
69
  start_time = time.time()
70
+ token_count = 0
71
+
72
  with torch.inference_mode():
73
+ thread.start()
74
+ for token_text in streamer:
75
+ token_count += 1
76
+ yield token_text # Token anında kullanıcıya akar
77
+ thread.join()
78
+
79
  end_time = time.time()
80
+ tps = token_count / (end_time - start_time) if (end_time - start_time) > 0 else 0
81
 
82
  # KV cache güncelle
83
+ # generate() ile streamer kullanıldığında past_key_values doğrudan dönmez,
84
+ # bu yüzden cache'i burada güncellemek için model.forward tabanlı bir yapı kurmak gerekir.
85
+ # Basitlik için bu örnekte cache ilk turdan sonra sıfırlanıyor.
86
+ sessions[user_id] = None # İstersen burayı ileri seviye cache yönetimi ile değiştirebilirsin.
 
 
 
 
 
 
87
 
88
+ yield f"\n\n⚡ **Hız:** {tps:.2f} token/sn"
89
 
90
  demo = gr.ChatInterface(
91
+ respond_stream,
92
  additional_inputs=[
93
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
94
  gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),