ruhzi commited on
Commit
c0741fa
·
verified ·
1 Parent(s): 9aed480

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -27
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- # SPEED FIX 1: Maximize CPU core usage for Hugging Face Free Tier (2 vCPUs)
3
  os.environ["OMP_NUM_THREADS"] = "2"
4
 
5
  import gradio as gr
@@ -7,76 +6,90 @@ import torch
7
  import gc
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
  from huggingface_hub import hf_hub_download
10
- from threading import Thread
11
 
12
- # SPEED FIX 2: Explicitly tell PyTorch to use both CPU cores
13
  torch.set_num_threads(2)
14
 
15
- model_path = "ruhzi/Indian_History_SLM"
16
-
17
  tokenizer = AutoTokenizer.from_pretrained(model_path)
18
-
19
  template_file = hf_hub_download(repo_id=model_path, filename="chat_template.jinja")
20
  with open(template_file, "r", encoding="utf-8") as f:
21
  tokenizer.chat_template = f.read()
22
 
23
- # SPEED FIX 3: Removed device_map and used float32 (Native CPU math is faster)
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_path,
26
- torch_dtype=torch.float32,
27
  low_cpu_mem_usage=True
28
  )
29
 
 
 
 
30
  def chat_inference(message, history):
 
 
 
 
 
 
 
31
  messages = []
32
-
33
- # MEMORY PROTECTION: Only keep the last 3 conversational turns
34
  recent_history = history[-3:] if len(history) > 3 else history
35
-
36
  for user_msg, assistant_msg in recent_history:
37
  messages.append({"role": "user", "content": user_msg})
38
  messages.append({"role": "assistant", "content": assistant_msg})
39
  messages.append({"role": "user", "content": message})
40
-
41
  input_text = tokenizer.apply_chat_template(
42
  messages,
43
  tokenize=False,
44
  add_generation_prompt=True,
45
- enable_thinking=False
46
  )
47
-
48
- # Explicitly send to CPU
49
  inputs = tokenizer([input_text], return_tensors="pt").to("cpu")
50
 
51
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
52
 
53
  generate_kwargs = dict(
54
  **inputs,
55
  streamer=streamer,
56
- max_new_tokens=512, # SPEED FIX 4: Kept at 512 for faster, punchier demo responses
57
  do_sample=True,
58
  temperature=0.7,
59
  top_p=0.8,
60
  )
61
 
62
- t = Thread(target=model.generate, kwargs=generate_kwargs)
63
  t.start()
64
 
65
  partial_message = ""
66
- for new_token in streamer:
67
- partial_message += new_token
68
- yield partial_message
69
-
70
- # MEMORY PROTECTION: Cleanup after generation finishes
71
- del inputs
72
- gc.collect()
 
 
 
 
 
 
 
73
 
74
  demo = gr.ChatInterface(
75
  fn=chat_inference,
76
  title="Indian History SLM",
77
  description="Ask me anything about Indian History!",
78
- # CRASH PROTECTION: The strict queue. 1 user at a time.
79
- concurrency_limit=1
80
  )
81
 
82
  if __name__ == "__main__":
 
1
  import os
 
2
  os.environ["OMP_NUM_THREADS"] = "2"
3
 
4
  import gradio as gr
 
6
  import gc
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  from huggingface_hub import hf_hub_download
9
+ from threading import Thread, Event
10
 
 
11
  torch.set_num_threads(2)
12
 
13
+ model_path = "ruhzi/Indian_History_SLM"
 
14
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
15
  template_file = hf_hub_download(repo_id=model_path, filename="chat_template.jinja")
16
  with open(template_file, "r", encoding="utf-8") as f:
17
  tokenizer.chat_template = f.read()
18
 
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_path,
21
+ torch_dtype=torch.float32,
22
  low_cpu_mem_usage=True
23
  )
24
 
25
+ # Global stop event — shared across the current generation
26
+ stop_event = Event()
27
+
28
  def chat_inference(message, history):
29
+ global stop_event
30
+
31
+ # Signal any previous generation to stop, then reset for this run
32
+ stop_event.set()
33
+ stop_event = Event()
34
+ current_stop = stop_event # Capture a local reference for this generation
35
+
36
  messages = []
 
 
37
  recent_history = history[-3:] if len(history) > 3 else history
 
38
  for user_msg, assistant_msg in recent_history:
39
  messages.append({"role": "user", "content": user_msg})
40
  messages.append({"role": "assistant", "content": assistant_msg})
41
  messages.append({"role": "user", "content": message})
42
+
43
  input_text = tokenizer.apply_chat_template(
44
  messages,
45
  tokenize=False,
46
  add_generation_prompt=True,
47
+ enable_thinking=False
48
  )
49
+
 
50
  inputs = tokenizer([input_text], return_tensors="pt").to("cpu")
51
 
52
+ streamer = TextIteratorStreamer(
53
+ tokenizer,
54
+ timeout=60.0,
55
+ skip_prompt=True,
56
+ skip_special_tokens=True
57
+ )
58
 
59
  generate_kwargs = dict(
60
  **inputs,
61
  streamer=streamer,
62
+ max_new_tokens=512,
63
  do_sample=True,
64
  temperature=0.7,
65
  top_p=0.8,
66
  )
67
 
68
+ t = Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
69
  t.start()
70
 
71
  partial_message = ""
72
+ try:
73
+ for new_token in streamer:
74
+ if current_stop.is_set():
75
+ # Drain the streamer so the thread can exit cleanly
76
+ for _ in streamer:
77
+ pass
78
+ break
79
+ partial_message += new_token
80
+ yield partial_message
81
+ finally:
82
+ # Always clean up, whether generation finished or was stopped
83
+ del inputs
84
+ gc.collect()
85
+
86
 
87
  demo = gr.ChatInterface(
88
  fn=chat_inference,
89
  title="Indian History SLM",
90
  description="Ask me anything about Indian History!",
91
+ stop_btn="Stop", # Renders the Stop button
92
+ concurrency_limit=1,
93
  )
94
 
95
  if __name__ == "__main__":