nixaut-codelabs committed
Commit 2793604 · verified · 1 Parent(s): 9de51df

Update app.py

Files changed (1):
  1. app.py (+48 -29)

app.py CHANGED
@@ -9,7 +9,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 MODEL_REPO = "daniel-dona/gemma-3-270m-it"
 LOCAL_DIR = os.path.join(os.getcwd(), "local_model")
 
-# CPU optimizations
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))
 os.environ.setdefault("MKL_NUM_THREADS", os.environ["OMP_NUM_THREADS"])
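These caps only bite if they are exported before the OpenMP/MKL runtimes initialize their thread pools, which is why the script sets them at import time, before the model loads. A minimal sketch of the equivalent torch-side knobs (standard torch APIs, shown for illustration; this commit relies on the environment variables instead):

    import os
    import torch

    # Intra-op parallelism: how many threads a single op (e.g. a matmul) may use.
    torch.set_num_threads(os.cpu_count() or 1)
    # Inter-op parallelism: threads for running independent ops concurrently.
    torch.set_num_interop_threads(1)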
@@ -41,9 +40,6 @@ model_path = ensure_local_model(MODEL_REPO, LOCAL_DIR)
 
 tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
 
-### CHANGE HERE: TEMPLATE SIMPLIFIED ###
-# A template without the 'raise_exception' directive, compatible with older transformers versions.
-# Since our code already feeds the template in the correct format, these checks can be dropped.
 gemma_chat_template_simplified = (
     "{% for message in messages %}"
     "{% if message['role'] == 'user' %}"
@@ -58,10 +54,7 @@ gemma_chat_template_simplified = (
 )
 
 if tokenizer.chat_template is None:
-    print("Setting the chat template manually (simplified version).")
     tokenizer.chat_template = gemma_chat_template_simplified
-### END OF CHANGE ###
-
 
 model = AutoModelForCausalLM.from_pretrained(
     model_path,
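Once assigned, the Jinja template turns a messages list into the model's prompt string. A minimal rendering sketch (the template body is truncated in this diff, so the turn markers below are the standard Gemma ones, assumed rather than copied from the file):

    messages = [{"role": "user", "content": "Is this text safe?"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return the raw string, not token ids
        add_generation_prompt=True,  # append the model-turn prefix
    )
    # Expected shape, assuming Gemma-style markers:
    # <start_of_turn>user
    # Is this text safe?<end_of_turn>
    # <start_of_turn>model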
@@ -71,7 +64,6 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# Very strict moderation system prompt
 MODERATION_SYSTEM_PROMPT = (
     "You are a multilingual content moderation classifier. "
     "You MUST respond with exactly one lowercase letter: 's' for safe, 'u' for unsafe. "
@@ -82,12 +74,8 @@ MODERATION_SYSTEM_PROMPT = (
 )
 
 def build_prompt(message, max_ctx_tokens=128):
-    # We fold the system message into the first user message.
     full_user_message = f"{MODERATION_SYSTEM_PROMPT}\n\nUser input: '{message}'"
-
-    messages = [
-        {"role": "user", "content": full_user_message}
-    ]
+    messages = [{"role": "user", "content": full_user_message}]
 
     text = tokenizer.apply_chat_template(
         messages,
@@ -96,7 +84,7 @@ def build_prompt(message, max_ctx_tokens=128):
     )
 
     while len(tokenizer(text, add_special_tokens=False).input_ids) > max_ctx_tokens and len(full_user_message) > 100:
-        full_user_message = full_user_message[:len(full_user_message)-50]
+        full_user_message = full_user_message[:-50]
        messages[0]['content'] = full_user_message
         text = tokenizer.apply_chat_template(
             messages,
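Because the loop retokenizes the full prompt after every 50-character chop, worst-case cost grows with how far the input overshoots the budget. A token-level alternative (a sketch under the same tokenizer, not part of this commit) truncates in one pass:

    def build_prompt_tokenwise(message, max_ctx_tokens=128):
        # Hypothetical variant: cut once at the token level instead of
        # chopping 50 characters at a time and retokenizing.
        full = f"{MODERATION_SYSTEM_PROMPT}\n\nUser input: '{message}'"
        ids = tokenizer(full, add_special_tokens=False).input_ids
        if len(ids) > max_ctx_tokens:
            # Ignores the handful of tokens the chat template itself adds.
            full = tokenizer.decode(ids[:max_ctx_tokens], skip_special_tokens=True)
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": full}],
            tokenize=False,
            add_generation_prompt=True,
        )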
@@ -106,15 +94,14 @@ def build_prompt(message, max_ctx_tokens=128):
     return text
 
 def enforce_s_u(text: str) -> str:
-    """Strictly constrain the model output to 's' or 'u'."""
     text_lower = text.strip().lower()
-    if "u" in text_lower and not "s" in text_lower:
+    if "u" in text_lower and "s" not in text_lower:
         return "u"
     if "unsafe" in text_lower:
         return "u"
     return "s"
 
-def respond_stream(message, history, max_tokens, temperature, top_p):
+def classify_text_stream(message, max_tokens, temperature, top_p):
     text = build_prompt(message)
     inputs = tokenizer([text], return_tensors="pt").to(model.device)
     do_sample = bool(temperature and temperature > 0.0)
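The guard in enforce_s_u is worth reading carefully: the literal string "unsafe" contains both letters, so it falls through the first branch and is caught by the substring check, while any other mixed output defaults to safe. A few illustrative calls:

    assert enforce_s_u("u") == "u"
    assert enforce_s_u(" U\n") == "u"
    assert enforce_s_u("unsafe") == "u"   # substring check catches it
    assert enforce_s_u("s") == "s"
    assert enforce_s_u("sure") == "s"     # contains 'u' and 's' -> defaults to safe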
@@ -148,6 +135,7 @@ def respond_stream(message, history, max_tokens, temperature, top_p):
                 start_time = time.time()
             partial_text += chunk
             token_count += 1
+            yield partial_text
     finally:
         thread.join()
 
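The surrounding generation code (elided by the diff) follows the usual transformers streaming pattern: model.generate runs in a worker thread while the main thread drains the TextIteratorStreamer, and the new `yield partial_text` forwards each chunk to the UI as it arrives. A self-contained sketch of that pattern, with hypothetical argument names:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(inputs, max_new_tokens=4):
        # skip_prompt drops the echoed input; only newly generated text is yielded.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        kwargs = dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
        thread = Thread(target=model.generate, kwargs=kwargs)
        thread.start()
        try:
            for chunk in streamer:   # blocks until the next decoded piece arrives
                yield chunk
        finally:
            thread.join()            # always reap the worker thread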
@@ -155,18 +143,49 @@ def respond_stream(message, history, max_tokens, temperature, top_p):
     end_time = time.time() if start_time else time.time()
     duration = max(1e-6, end_time - start_time)
     tps = token_count / duration if duration > 0 else 0.0
-    yield f"{final_label}\n\n⚡ Speed: {tps:.2f} token/s"
-
-demo = gr.ChatInterface(
-    respond_stream,
-    additional_inputs=[
-        gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
-    ],
-    title="Strict Multilingual Moderation Classifier (s/u)",
-    description="Enter any text in any language. The model will output only 's' (safe) or 'u' (unsafe)."
-)
+    yield f"{final_label}\n\n⚡ Speed: {tps:.2f} tokens/s"
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Multilingual Content Moderation Classifier")
+    gr.Markdown("Enter any text to classify it as safe ('s') or unsafe ('u').")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            text_input = gr.Textbox(
+                label="Text to Classify",
+                lines=5,
+                placeholder="Enter text in any language..."
+            )
+            submit_button = gr.Button("Classify", variant="primary")
+
+        with gr.Column(scale=1):
+            text_output = gr.Textbox(label="Classification Result", interactive=False)
+            with gr.Accordion("Advanced Settings", open=False):
+                max_tokens_slider = gr.Slider(
+                    minimum=1, maximum=4, value=1, step=1, label="Max New Tokens"
+                )
+                temp_slider = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
+                )
+                top_p_slider = gr.Slider(
+                    minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
+                )
+
+    gr.Examples(
+        examples=[
+            ["Hello, how are you today?"],
+            ["I will find you and hurt you."],
+            ["C'est une belle journée pour apprendre le codage."],
+            ["I want to die."],
+        ],
+        inputs=text_input
+    )
+
+    submit_button.click(
+        fn=classify_text_stream,
+        inputs=[text_input, max_tokens_slider, temp_slider, top_p_slider],
+        outputs=text_output
+    )
 
 if __name__ == "__main__":
     with torch.inference_mode():
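Since classify_text_stream is a generator, each yield overwrites the Classification Result textbox, ending with the final label and the speed line. Generator handlers require Gradio's queue in many versions; the file's tail is cut off in this diff, but a typical launch sequence would look like this (hypothetical, not shown in the commit):

    if __name__ == "__main__":
        with torch.inference_mode():
            demo.queue()   # needed for streaming/generator handlers in many Gradio versions
            demo.launch()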
 