File size: 3,661 Bytes
ea91575
 
 
 
 
 
 
72e17c4
ea91575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e17c4
ea91575
 
72e17c4
 
ea91575
 
72e17c4
 
 
 
 
 
ea91575
 
 
 
72e17c4
ea91575
 
 
 
72e17c4
ea91575
72e17c4
 
 
 
 
ea91575
 
 
 
72e17c4
ea91575
72e17c4
 
 
 
ea91575
 
 
 
 
 
 
 
 
72e17c4
ea91575
72e17c4
 
 
 
ea91575
72e17c4
ea91575
72e17c4
 
 
 
 
 
ea91575
72e17c4
ea91575
 
72e17c4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
import torch
import torch.nn as nn
import re

# 1. Architecture Patch (RMSNorm)
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Normalizes the last dimension by its RMS (no mean-centering, no bias)
    and applies a learnable per-channel scale initialized to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps  # added inside sqrt for numerical stability
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last (feature) dimension, kept for broadcasting.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        norm = torch.sqrt(mean_square + self.eps)
        return (x / norm) * self.scale

def replace_layernorm_with_rmsnorm(module: nn.Module) -> None:
    """Recursively replace every ``nn.LayerNorm`` in ``module`` with ``RMSNorm``.

    The replacement is done in place via ``setattr``. Fixes over the naive
    version:
      * the LayerNorm's own ``eps`` is preserved (was hard-coded to 1e-6);
      * when the LayerNorm is affine, its learned ``weight`` is copied into
        the RMSNorm ``scale`` — otherwise it is silently reset to ones,
        because the parameter is renamed and the caller loads the state
        dict with ``strict=False``. The LayerNorm ``bias`` (if any) has no
        counterpart in RMSNorm and is necessarily dropped.

    Args:
        module: Root module to patch; modified in place.
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.LayerNorm):
            shape = child.normalized_shape
            # normalized_shape may be an int or a tuple; RMSNorm normalizes
            # the last dimension only, so take the leading entry.
            dim = shape[0] if isinstance(shape, (tuple, list)) else shape
            rms = RMSNorm(dim=dim, eps=child.eps)
            if child.elementwise_affine:
                with torch.no_grad():
                    rms.scale.copy_(child.weight)
            setattr(module, name, rms)
        else:
            # nn.LayerNorm has no children, so no need to recurse into it.
            replace_layernorm_with_rmsnorm(child)

# 2. Glue Logic (Prefixes + Punctuation)
def fix_arabic_output(text: str) -> str:
    """Post-process decoded Arabic text: reattach prefixes, tidy punctuation.

    Tokenizers sometimes emit Arabic clitic prefixes (al-, lil-, wa-al-,
    bi-al-) as separate tokens; this glues them back onto the following
    word and removes whitespace before punctuation marks.

    Args:
        text: Decoded model output (may be empty or None).

    Returns:
        The cleaned, stripped string; falsy input is returned unchanged.
    """
    if not text:
        return text

    # A. Glue detached prefixes onto the following word. Iterate to a fixed
    # point so cascaded detachments collapse fully (the original ran the
    # substitution exactly twice "for safety", which misses deeper chains).
    prefix_pattern = r'(^|\s)(ال|لل|وال|بال)\s+(?=\S)'
    while True:
        glued = re.sub(prefix_pattern, r'\1\2', text)
        if glued == text:
            break
        text = glued

    # B. Remove whitespace before Arabic/Latin punctuation marks.
    text = re.sub(r'\s+([،؟!.,])', r'\1', text)

    return text.strip()

class EndpointHandler:
    """Inference Endpoints handler for a seq2seq model whose LayerNorms are
    patched to RMSNorm, with Arabic-specific output cleanup.

    Load strategy: try the standard ``from_pretrained`` path first; if that
    fails (e.g. checkpoint keys don't match because of the RMSNorm rename),
    fall back to building a skeleton from the config, patching it, and
    loading the state dict manually with ``strict=False``.
    """

    def __init__(self, path: str = ""):
        """Load model, tokenizer, and move the model to the best device.

        Args:
            path: Local directory containing the model artifacts.
        """
        try:
            # Preferred: standard load, then patch norms in place.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
            replace_layernorm_with_rmsnorm(self.model)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; this remains a deliberate best-effort fallback.
            # Build the skeleton only now — the original constructed (and
            # patched) a throwaway model from config unconditionally, doing
            # that work twice on the happy path.
            import os

            config = AutoConfig.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_config(config)
            replace_layernorm_with_rmsnorm(self.model)

            w_path = os.path.join(path, "model.safetensors")
            if os.path.exists(w_path):
                from safetensors.torch import load_file
                state_dict = load_file(w_path)
            else:
                state_dict = torch.load(
                    os.path.join(path, "pytorch_model.bin"), map_location="cpu"
                )
            # strict=False: RMSNorm renames LayerNorm `weight` -> `scale`,
            # so some checkpoint keys will not match by design.
            self.model.load_state_dict(state_dict, strict=False)

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Run beam-search generation on the request payload.

        Args:
            data: Either a dict with an ``"inputs"`` key, a bare string,
                or a list of strings. (The original crashed with
                AttributeError on non-dict payloads via ``data.pop``.)

        Returns:
            One ``{"generated_text": str}`` dict per input, post-processed
            by :func:`fix_arabic_output`.
        """
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs, str):
            inputs = [inputs]

        tokenized_inputs = self.tokenizer(
            inputs, return_tensors="pt", padding=True
        ).to(self.device)

        # Some tokenizers emit token_type_ids, which seq2seq generate() rejects.
        if "token_type_ids" in tokenized_inputs:
            del tokenized_inputs["token_type_ids"]

        with torch.no_grad():
            generated_ids = self.model.generate(
                **tokenized_inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True,
            )

        decoded_outputs = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return [{"generated_text": fix_arabic_output(t)} for t in decoded_outputs]