File size: 3,661 Bytes
ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 ea91575 72e17c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
import torch
import torch.nn as nn
import re
# 1. Architecture Patch (RMSNorm)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The input is divided by ``sqrt(mean(x**2, axis=-1) + eps)`` and then
    multiplied element-wise by ``scale``. No mean is subtracted and there is
    no bias term.

    Args:
        dim: Length of the ``scale`` vector (one gain per feature).
        eps: Constant added to the mean square before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        # Gain starts at one, so a fresh layer only normalises.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised by its RMS along the last axis, times ``scale``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        root = torch.sqrt(mean_square + self.eps)
        normalised = x / root
        return normalised * self.scale
def replace_layernorm_with_rmsnorm(module: nn.Module):
    """Swap every ``nn.LayerNorm`` below ``module`` for a new ``RMSNorm``, in place.

    The module tree is walked recursively. Each ``nn.LayerNorm`` child is
    replaced under the same attribute name by an ``RMSNorm`` sized from the
    first entry of the layer's ``normalized_shape`` and built with
    ``eps=1e-6``. The replaced layer's own parameters and eps are not carried
    over. Returns ``None``.

    Args:
        module: Root of the module tree to patch.
    """
    # Snapshot the children first: the loop re-assigns attributes on `module`.
    for attr_name, submodule in list(module.named_children()):
        if not isinstance(submodule, nn.LayerNorm):
            # Not a LayerNorm itself, but one may be nested deeper.
            replace_layernorm_with_rmsnorm(submodule)
            continue

        shape = submodule.normalized_shape
        width = shape[0] if isinstance(shape, (tuple, list)) else shape
        setattr(module, attr_name, RMSNorm(dim=width, eps=1e-6))
# 2. Glue Logic (Prefixes + Punctuation)
def fix_arabic_output(text):
    """Re-attach detached Arabic prefixes and pull punctuation back onto its word.

    Two clean-ups are applied in order:

    A. A standalone prefix token (Al, Lil, Wa-Al, Bi-Al) at the start of the
       text or after whitespace is joined to the word that follows it. The
       substitution is run twice, as in the original implementation.
    B. Whitespace directly before one of the marks ``، ؟ ! . ,`` is removed.

    Args:
        text: Decoded model output. Falsy values (``None``, ``""``) are
            returned unchanged.

    Returns:
        The cleaned text with leading and trailing whitespace stripped.
    """
    if not text:
        return text

    # A. Glue prefixes onto the following word.
    prefix_re = re.compile(r'(^|\s)(ال|لل|وال|بال)\s+(?=\S)')
    glued = text
    for _ in range(2):
        glued = prefix_re.sub(r'\1\2', glued)

    # B. Drop the space in front of punctuation.
    punctuation_marks = r'[،؟!.,]'
    space_before_mark = re.compile(r'\s+(' + punctuation_marks + ')')
    return space_before_mark.sub(r'\1', glued).strip()
class EndpointHandler:
    """Callable that loads an RMSNorm-patched seq2seq model once and serves generation requests.

    ``__init__`` sets three attributes: ``model`` (in eval mode, on
    ``device``), ``tokenizer`` and ``device`` (``"cuda"`` when available,
    else ``"cpu"``).
    NOTE(review): the class name and call signature look like a Hugging Face
    custom inference handler — confirm against the deployment target.
    """

    def __init__(self, path=""):
        """Load model and tokenizer from ``path`` and move the model to the device.

        Loading happens in three steps so that RMSNorm weights stored in the
        checkpoint actually reach the model:

        1. Build the model with ``from_pretrained`` (falling back to a bare
           ``from_config`` skeleton if that fails).
        2. Replace every ``nn.LayerNorm`` with ``RMSNorm``.
        3. Overlay the checkpoint's state dict with ``strict=False`` so keys
           that only exist after the patch (the RMSNorm ``scale``
           parameters) are loaded instead of staying at their initial ones.

        Args:
            path: Model directory handed to the ``from_pretrained`` loaders
                and searched for ``model.safetensors`` / ``pytorch_model.bin``.

        Raises:
            FileNotFoundError: If ``from_pretrained`` failed and no
                single-file checkpoint exists under ``path``.
        """
        import logging

        log = logging.getLogger(__name__)

        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(path)
            used_from_pretrained = True
        except Exception:
            # Deliberate best-effort: fall back to a config-only skeleton and
            # load the weights manually below. Logged rather than swallowed.
            log.warning(
                "from_pretrained(%r) failed; building the model from its config instead",
                path,
                exc_info=True,
            )
            model = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(path))
            used_from_pretrained = False

        # Patch the architecture BEFORE the overlay so the module tree has the
        # RMSNorm parameter names the checkpoint may contain.
        replace_layernorm_with_rmsnorm(model)

        state_dict = self._read_checkpoint(path)
        if state_dict is None:
            if not used_from_pretrained:
                raise FileNotFoundError(
                    f"No model.safetensors or pytorch_model.bin found under {path!r}"
                )
            log.warning(
                "No single-file checkpoint under %r; RMSNorm scales keep their initial values",
                path,
            )
        else:
            # strict=False: keys of the replaced LayerNorm modules (or tied
            # weights omitted from the file) must not abort the load.
            outcome = model.load_state_dict(state_dict, strict=False)
            if outcome.missing_keys or outcome.unexpected_keys:
                log.warning(
                    "Checkpoint overlay: missing keys=%s, unexpected keys=%s",
                    list(outcome.missing_keys),
                    list(outcome.unexpected_keys),
                )

        self.model = model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()

    @staticmethod
    def _read_checkpoint(path):
        """Return the single-file checkpoint under ``path`` as a state dict, or ``None`` if absent."""
        import os

        safetensors_file = os.path.join(path, "model.safetensors")
        if os.path.exists(safetensors_file):
            from safetensors.torch import load_file

            return load_file(safetensors_file)

        pickle_file = os.path.join(path, "pytorch_model.bin")
        if os.path.exists(pickle_file):
            # NOTE(security): torch.load unpickles the file; only point this
            # handler at checkpoints from a trusted source.
            return torch.load(pickle_file, map_location="cpu")

        return None

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Generate text for one request.

        Args:
            data: Either a dict with an ``"inputs"`` entry (a string or a
                list of strings) or the inputs themselves. The payload is
                not modified.

        Returns:
            One ``{"generated_text": ...}`` dict per input, in input order.
            An empty list of inputs yields an empty list.
        """
        # Read with .get (not .pop) so the caller's payload stays intact.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        # Tokenize; truncation keeps over-long inputs within the tokenizer's
        # configured maximum length instead of overflowing the model.
        tokenized_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        # generate() is called without token_type_ids; drop them if present.
        tokenized_inputs.pop("token_type_ids", None)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **tokenized_inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True,
            )

        decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return [{"generated_text": fix_arabic_output(text)} for text in decoded_outputs]