# YugoGPT / handler.py
# Deniss8686's picture
# Update handler.py
# 1da0626 verified
# Raw
# History Blame Contribute Delete
# 9.32 kB
import time
import traceback
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Serve a causal language model as a Bosnian-language telesales chat agent.

    The handler is callable: it takes a request dict and returns an
    OpenAI-style ``chat.completion`` dict. The conversation is rendered into
    a plain-text transcript ("Kupac:" for the customer, "Agent:" for the
    model) that always starts with a fixed system prompt.

    NOTE(review): the class name and call signature look like a Hugging Face
    Inference Endpoints custom handler -- confirm against the deployment.
    """

    # Fixed system prompt with the language requirement. It always comes
    # first; caller-supplied system text is appended, never substituted.
    BASE_PROMPT = (
        "Ti si profesionalan i brz Telesales agent za Bosnu.\n"
        "Govori ISKLJUČIVO bosanski/srpski/hrvatski jezik. Nikad ne mijenjaj jezik.\n"
        "Budi direktan, ljubazan i prodajno jak.\n"
        "Ne pričaj duge priče. Odmah idi na stvar i pokušaj zatvoriti prodaju."
    )
    # Speaker labels; if the model emits one it has started inventing a new turn.
    STOP_MARKERS = ("Kupac:", "Agent:")
    # Upper bound on how many of the most recent turns are rendered.
    MAX_HISTORY_MESSAGES = 10

    def __init__(self, model_dir: str):
        """Load tokenizer and model from ``model_dir`` and set generation defaults."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Generation needs a pad token; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # NOTE(review): trust_remote_code=True runs Python code shipped with
        # the model files; only point model_dir at a trusted source.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()
        # Inputs are moved to the device that holds the first model parameter.
        self.device = next(self.model.parameters()).device
        self.default_max_new_tokens = 200
        self.context_window = 4096
        # Tokens kept free between prompt + completion and the context window.
        self.safety_buffer = 50
        self.default_generation_config = {
            "temperature": 0.68,
            "top_p": 0.87,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "repetition_penalty": 1.15,
        }
        if self.tokenizer.bos_token_id is not None:
            self.default_generation_config["bos_token_id"] = self.tokenizer.bos_token_id

    def _build_prompt(self, messages: List[Dict[str, str]], user_system: Optional[str] = None) -> str:
        """Render the transcript prompt.

        Args:
            messages: Chat turns with ``role`` and ``content``. Only the last
                ``MAX_HISTORY_MESSAGES`` are used. Role ``"user"`` is labelled
                "Kupac", every other role "Agent".
            user_system: Optional extra instructions, appended after the
                fixed base prompt.

        Returns:
            The prompt text, ending with an open "Agent:" cue.
        """
        header = self.BASE_PROMPT
        if user_system:
            header += f"\n\nDodatne instrukcije: {user_system}"
        turns = []
        for msg in messages[-self.MAX_HISTORY_MESSAGES:]:
            speaker = "Kupac" if msg.get("role", "user") == "user" else "Agent"
            turns.append(f"{speaker}: {msg.get('content', '')}\n")
        return header + "\n\n" + "".join(turns) + "Agent:"

    def _clean_response(self, text: str) -> str:
        """Strip an echoed "Agent:" prefix and cut hallucinated follow-up turns.

        The text is cut at the earliest position where any speaker label
        occurs, regardless of which label it is.
        """
        text = text.strip()
        # The model may echo the cue (possibly more than once); drop it.
        while text.lower().startswith("agent:"):
            text = text[len("agent:"):].strip()
        cut_positions = [
            pos for pos in (text.find(marker) for marker in self.STOP_MARKERS) if pos != -1
        ]
        if cut_positions:
            text = text[:min(cut_positions)].strip()
        return text

    @staticmethod
    def _coerce_content(content: Any) -> str:
        """Return message content as a string; ``None`` becomes an empty string."""
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _float_param(data: Dict[str, Any], key: str, default: float) -> float:
        """Read ``data[key]`` as a float, falling back to ``default`` if unusable."""
        try:
            return float(data.get(key, default))
        except (TypeError, ValueError):
            return float(default)

    def _parse_request(self, data: Dict[str, Any]) -> Tuple[List[Dict[str, str]], Optional[str]]:
        """Extract chat turns and the optional system text from a request.

        Accepts either ``data["messages"]`` (a list of role/content dicts;
        non-dict items are skipped) or ``data["inputs"]`` (a string, any
        other scalar, or a list of dicts/scalars). System messages are
        handled the same way on both paths: they are removed from the turns
        and the last one is returned as the system text.
        """
        if isinstance(data.get("messages"), list):
            items = [item for item in data["messages"] if isinstance(item, dict)]
        elif "inputs" in data:
            raw = data["inputs"]
            raw_items = raw if isinstance(raw, list) else [raw]
            items = [
                item if isinstance(item, dict) else {"role": "user", "content": item}
                for item in raw_items
            ]
        else:
            items = []

        messages: List[Dict[str, str]] = []
        user_system = None
        for item in items:
            role = item.get("role", "user")
            content = self._coerce_content(item.get("content", ""))
            if role == "system":
                user_system = content
            else:
                messages.append({"role": role, "content": content})
        return messages, user_system

    def _resolve_generation_config(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
        """Build ``generate`` keyword arguments and the prompt token budget.

        Returns:
            ``(gen_config, max_input_length)``.
        """
        defaults = self.default_generation_config
        max_new_tokens = data.get(
            "max_tokens", data.get("max_new_tokens", self.default_max_new_tokens)
        )
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(max_new_tokens, bool) or not isinstance(max_new_tokens, int) or max_new_tokens <= 0:
            max_new_tokens = self.default_max_new_tokens

        max_input_length = self.context_window - max_new_tokens - self.safety_buffer
        if max_input_length <= 0:
            # The requested completion alone would fill the window: give the
            # prompt half of it and shrink the completion so both still fit.
            max_input_length = self.context_window // 2
            max_new_tokens = max(1, self.context_window - max_input_length - self.safety_buffer)

        temperature = self._float_param(data, "temperature", defaults["temperature"])
        top_p = self._float_param(data, "top_p", defaults["top_p"])
        if not 0.0 < top_p <= 1.0:
            top_p = defaults["top_p"]
        repetition_penalty = self._float_param(data, "repetition_penalty", defaults["repetition_penalty"])
        if not repetition_penalty > 0.0:
            repetition_penalty = defaults["repetition_penalty"]

        gen_config = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": defaults["pad_token_id"],
            "eos_token_id": defaults["eos_token_id"],
            "repetition_penalty": repetition_penalty,
        }
        if temperature > 0.0:
            gen_config.update({"do_sample": True, "temperature": temperature, "top_p": top_p})
        else:
            # Sampling needs a strictly positive temperature; treat
            # temperature <= 0 as a request for greedy decoding.
            gen_config["do_sample"] = False
        if "bos_token_id" in defaults:
            gen_config["bos_token_id"] = defaults["bos_token_id"]
        return gen_config, max_input_length

    def _count_tokens(self, prompt: str) -> int:
        """Return the token count of ``prompt`` including special tokens."""
        return len(self.tokenizer.encode(prompt, add_special_tokens=True))

    def _fit_prompt(
        self,
        messages: List[Dict[str, str]],
        user_system: Optional[str],
        max_input_length: int,
    ) -> Tuple[str, int]:
        """Build a prompt that fits the budget by dropping the oldest turns.

        The base prompt, the newest turn and the closing "Agent:" cue are
        always kept.

        Returns:
            ``(prompt, raw_token_count)`` where ``raw_token_count`` is the
            size of the prompt before any turn was dropped.
        """
        window = messages[-self.MAX_HISTORY_MESSAGES:]
        prompt = self._build_prompt(window, user_system)
        raw_token_count = self._count_tokens(prompt)
        token_count = raw_token_count
        while token_count > max_input_length and len(window) > 1:
            window = window[1:]
            prompt = self._build_prompt(window, user_system)
            token_count = self._count_tokens(prompt)
        return prompt, raw_token_count

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one assistant reply for the request ``data``.

        Args:
            data: Request dict. Reads ``messages`` or ``inputs`` for the
                conversation, and optionally ``max_tokens`` /
                ``max_new_tokens``, ``temperature``, ``top_p``,
                ``repetition_penalty`` and ``model``.

        Returns:
            An OpenAI-style ``chat.completion`` dict. A ``warning`` key is
            added when the conversation had to be shortened. Any exception
            is printed and answered with a generic error completion instead
            of being raised.
        """
        try:
            messages, user_system = self._parse_request(data)
            if not messages or all(not m["content"].strip() for m in messages):
                return self._build_error_response("Žao mi je, nisam razumio vašu poruku.")

            gen_config, max_input_length = self._resolve_generation_config(data)

            prompt, raw_token_count = self._fit_prompt(messages, user_system, max_input_length)
            was_truncated = raw_token_count > max_input_length

            # Tokenizer truncation remains as a last resort for the case
            # where even a single turn exceeds the budget.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_input_length,
                return_attention_mask=True,
            ).to(self.device)
            input_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **gen_config,
                )

            # generate() returns prompt + completion; keep only the new part.
            new_tokens = outputs[0][input_length:]
            completion_length = len(new_tokens)
            raw_response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = self._clean_response(raw_response)

            result = {
                # Nanosecond timestamp keeps ids distinct within one second.
                "id": f"chatcmpl-{time.time_ns()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": data.get("model", "yugogpt-eng"),
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": response},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": input_length,
                    "completion_tokens": completion_length,
                    "total_tokens": input_length + completion_length,
                },
            }
            if was_truncated:
                result["warning"] = (
                    f"Kontekst je skracen sa {raw_token_count} na {input_length} tokena. "
                    "Rani dio razgovora mozda nedostaje."
                )
            return result
        except Exception:
            # Top-level boundary: log the traceback and answer politely.
            print(traceback.format_exc())
            return self._build_error_response("Žao mi je, došlo je do greške. Možete li ponoviti?")

    def _build_error_response(self, message: str) -> Dict[str, Any]:
        """Return an OpenAI-style completion carrying ``message`` with zero usage."""
        return {
            "id": "chatcmpl-error",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "yugogpt-eng",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": message},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }