Utiric committed · verified
Commit 5cecbb0 · 1 Parent(s): 54ba978

Update app.py

Files changed (1)
  1. app.py +42 -11
app.py CHANGED
@@ -3,10 +3,13 @@ import time
 import threading
 import torch
 import gradio as gr
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
-MODEL_NAME = "daniel-dona/gemma-3-270m-it"
+MODEL_REPO = "daniel-dona/gemma-3-270m-it"
+LOCAL_DIR = os.path.join(os.getcwd(), "local_model")
 
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))
 os.environ.setdefault("MKL_NUM_THREADS", os.environ["OMP_NUM_THREADS"])
 os.environ.setdefault("OMP_PROC_BIND", "TRUE")
@@ -15,9 +18,30 @@ torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
 torch.set_num_interop_threads(1)
 torch.set_float32_matmul_precision("high")
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+def ensure_local_model(repo_id: str, local_dir: str, tries: int = 3, sleep_s: float = 3.0) -> str:
+    os.makedirs(local_dir, exist_ok=True)
+    for i in range(tries):
+        try:
+            snapshot_download(
+                repo_id=repo_id,
+                local_dir=local_dir,
+                local_dir_use_symlinks=False,
+                resume_download=True,
+                allow_patterns=["*.json", "*.model", "*.safetensors", "*.bin", "*.txt", "*.py"]
+            )
+            return local_dir
+        except Exception:
+            if i == tries - 1:
+                raise
+            time.sleep(sleep_s * (2 ** i))
+    return local_dir
+
+model_path = ensure_local_model(MODEL_REPO, LOCAL_DIR)
+
+tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
+    model_path,
+    local_files_only=True,
     torch_dtype=torch.float32,
     device_map=None
 )
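The new ensure_local_model helper retries snapshot_download with exponential backoff, after which tokenizer and model are loaded strictly from disk via local_files_only=True. A minimal sketch of what the defaults imply (names taken from the diff; the timings are just the computed backoff values):

    # With tries=3 and sleep_s=3.0, a failed download is retried after
    # 3 s and again after 6 s; a third failure re-raises the exception.
    path = ensure_local_model(MODEL_REPO, LOCAL_DIR)   # e.g. ./local_model
    tok = AutoTokenizer.from_pretrained(path, local_files_only=True)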
@@ -51,14 +75,18 @@ def respond_stream(message, history, system_message, max_tokens, temperature, to
         temperature=temperature if do_sample else None,
         use_cache=True,
         eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id
     )
     try:
         streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
     except TypeError:
         streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-    thread = threading.Thread(target=model.generate, kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer})
+    thread = threading.Thread(
+        target=model.generate,
+        kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer}
+    )
     partial_text = ""
+    token_count = 0
     start_time = None
     with torch.inference_mode():
         thread.start()
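For reference, the Thread plus TextIteratorStreamer pairing used here is the usual transformers streaming pattern: model.generate() blocks until it finishes, so it runs on a worker thread while the caller iterates the streamer. A minimal sketch, assuming tokenizer, model and inputs as defined earlier in app.py:

    # The worker thread pushes decoded text into the streamer; iterating
    # the streamer on the main thread yields chunks as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    worker = threading.Thread(target=model.generate,
                              kwargs={**inputs, "max_new_tokens": 64, "streamer": streamer})
    worker.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    worker.join()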
@@ -67,13 +95,13 @@ def respond_stream(message, history, system_message, max_tokens, temperature, to
                 if start_time is None:
                     start_time = time.time()
                 partial_text += chunk
+                token_count += 1
                 yield partial_text
         finally:
             thread.join()
             end_time = time.time() if start_time is not None else time.time()
             duration = max(1e-6, end_time - start_time) if start_time else 0.0
-            gen_token_count = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
-            tps = (gen_token_count / duration) if duration > 0 else 0.0
+            tps = (token_count / duration) if duration > 0 else 0.0
             yield partial_text + f"\n\n⚡ Hız: {tps:.2f} token/sn"
 
 demo = gr.ChatInterface(
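The removed lines measured throughput by re-tokenizing the accumulated text; the replacement counts streamer chunks instead, which skips the extra tokenizer pass but is an approximation, since a chunk can occasionally carry more than one token. A small sketch of the two counts, assuming tokenizer, partial_text, token_count and duration as above:

    # Old approach: exact token count via a second pass over the output text.
    exact = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
    # New approach: one increment per chunk yielded by TextIteratorStreamer.
    approx = token_count
    tps = approx / duration   # value reported as "⚡ Hız: ... token/sn"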
@@ -82,11 +110,14 @@ demo = gr.ChatInterface(
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.0, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
-    ],
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+    ]
 )
 
 if __name__ == "__main__":
     with torch.inference_mode():
-        _ = model.generate(**tokenizer(["Hi"], return_tensors="pt").to(model.device), max_new_tokens=1, do_sample=False, use_cache=True)
-    demo.queue().launch()
+        _ = model.generate(
+            **tokenizer(["Hi"], return_tensors="pt").to(model.device),
+            max_new_tokens=1, do_sample=False, use_cache=True
+        )
+    demo.queue(concurrency_count=1, max_size=32).launch()
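One caveat on the last line: queue(concurrency_count=1, max_size=32) matches the Gradio 3.x signature. If the Space runs Gradio 4.x (an assumption, not stated in this commit), that parameter was renamed, and a roughly equivalent call would be:

    # Gradio 4.x sketch: concurrency_count became default_concurrency_limit.
    demo.queue(default_concurrency_limit=1, max_size=32).launch()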
 
 
 
 