import os
import time
import threading
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
MODEL_REPO = "daniel-dona/gemma-3-270m-it"
LOCAL_DIR = os.path.join(os.getcwd(), "local_model")
# CPU optimizations
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))
os.environ.setdefault("MKL_NUM_THREADS", os.environ["OMP_NUM_THREADS"])
os.environ.setdefault("OMP_PROC_BIND", "TRUE")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(1)
torch.set_float32_matmul_precision("high")
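# Download the model snapshot into a local directory, retrying with exponential
# backoff (sleep_s, 2*sleep_s, 4*sleep_s, ...) on transient network errors.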
def ensure_local_model(repo_id: str, local_dir: str, tries: int = 3, sleep_s: float = 3.0) -> str:
    os.makedirs(local_dir, exist_ok=True)
    for i in range(tries):
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                allow_patterns=["*.json", "*.model", "*.safetensors", "*.bin", "*.txt", "*.py"]
            )
            return local_dir
        except Exception:
            if i == tries - 1:
                raise
            time.sleep(sleep_s * (2 ** i))
    return local_dir
model_path = ensure_local_model(MODEL_REPO, LOCAL_DIR)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
### CHANGE STARTS HERE ###
# Set the chat template manually for Gemma models.
# This avoids errors on older transformers versions or in misconfigured environments.
gemma_chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
    "{{ raise_exception('System messages are not supported by this model!') }}"
    "{% endif %}"
    "{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<start_of_turn>user\\n' + message['content'] | trim + '<end_of_turn>\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<start_of_turn>model\\n' + message['content'] | trim + '<end_of_turn>\\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<start_of_turn>model\\n' }}"
    "{% endif %}"
)
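# For reference, a single user turn rendered with this template (add_generation_prompt=True) looks like:
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model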
# Note: the model's own tokenizer_config.json states that system messages are not supported.
# That is why the template above raises an exception if a system message is passed.
# build_prompt below is adjusted accordingly.
if tokenizer.chat_template is None:
    print("Setting the chat template manually.")
    tokenizer.chat_template = gemma_chat_template
### CHANGE ENDS HERE ###
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,
    torch_dtype=torch.float32,
    device_map=None
)
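# Full float32 precision; device_map=None leaves the model on the default (CPU) device.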
model.eval()
# Very strict moderation system prompt (only "s" or "u")
MODERATION_SYSTEM_PROMPT = (
    "You are a multilingual content moderation classifier. "
    "You MUST respond with exactly one lowercase letter: 's' for safe, 'u' for unsafe. "
    "No explanations, no punctuation, no extra words. "
    "If the message contains hate speech, harassment, sexual content involving minors, "
    "extreme violence, self-harm encouragement, or other unsafe material, respond 'u'. "
    "Otherwise respond 's'."
)
def build_prompt(message, max_ctx_tokens=128):
    # ### CHANGE 2: system messages are no longer supported ###
    # Fold the system prompt into the first user message instead.
    # This fits Gemma's instruction-tuned format better.
    full_user_message = f"{MODERATION_SYSTEM_PROMPT}\n\nUser input: '{message}'"
    messages = [
        {"role": "user", "content": full_user_message}
    ]
    # Apply the chat template properly
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Truncate if the prompt exceeds the token budget (unlikely here, but good practice)
    while len(tokenizer(text, add_special_tokens=False).input_ids) > max_ctx_tokens and len(full_user_message) > 100:
        full_user_message = full_user_message[:-50]  # trim the message from the end
        messages[0]['content'] = full_user_message
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    return text
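# Illustrative shape of the string returned by build_prompt (exact token count varies):
#   <start_of_turn>user
#   You are a multilingual content moderation classifier. ... User input: '<message>'<end_of_turn>
#   <start_of_turn>model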
def enforce_s_u(text: str) -> str:
    """Constrain the model output to exactly 's' or 'u'."""
    text_lower = text.strip().lower()
    if "u" in text_lower and "s" not in text_lower:
        return "u"
    if "unsafe" in text_lower:
        return "u"
    return "s"
def respond_stream(message, history, max_tokens, temperature, top_p):
    text = build_prompt(message)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    do_sample = bool(temperature and temperature > 0.0)
    gen_kwargs = dict(
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        top_p=top_p,
        temperature=temperature if do_sample else None,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )
    try:
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    except TypeError:
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # torch.inference_mode() is thread-local, so enter it inside the generation thread itself.
    def _generate():
        with torch.inference_mode():
            model.generate(
                **inputs,
                **{k: v for k, v in gen_kwargs.items() if v is not None},
                streamer=streamer
            )

    thread = threading.Thread(target=_generate)
    partial_text = ""
    token_count = 0
    start_time = None
    thread.start()
    try:
        for chunk in streamer:
            if start_time is None:
                start_time = time.time()
            partial_text += chunk
            token_count += 1
    finally:
        thread.join()
    # Collapse the output to a strict s/u label
    final_label = enforce_s_u(partial_text)
    duration = max(1e-6, time.time() - start_time) if start_time else 1e-6
    tps = token_count / duration
    yield f"{final_label}\n\n⚡ Speed: {tps:.2f} token/s"
demo = gr.ChatInterface(
    respond_stream,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
    ],
    title="Strict Multilingual Moderation Classifier (s/u)",
    description="Enter any text in any language. The model will output only 's' (safe) or 'u' (unsafe)."
)
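# Warm-up: run a single one-token greedy generation at startup so the first real
# request does not pay the cost of lazy kernel initialization.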
if __name__ == "__main__":
    with torch.inference_mode():
        _ = model.generate(
            **tokenizer(["Hi"], return_tensors="pt").to(model.device),
            max_new_tokens=1, do_sample=False, use_cache=True
        )
    demo.queue(max_size=32).launch()