Shams03 commited on
Commit
72e17c4
·
verified ·
1 Parent(s): 3663aa0

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +37 -16
handler.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn as nn
6
  import re
7
 
8
- # 1. Define Architecture Patch
9
  class RMSNorm(nn.Module):
10
  def __init__(self, dim: int, eps: float = 1e-6):
11
  super().__init__()
@@ -25,36 +25,45 @@ def replace_layernorm_with_rmsnorm(module: nn.Module):
25
  else:
26
  replace_layernorm_with_rmsnorm(child)
27
 
28
- # 2. Define The Glue Logic (UPDATED)
29
  def fix_arabic_output(text):
30
  if not text: return text
31
- # Glue Prefixes (Next word)
 
32
  prefix_pattern = r'(^|\s)(ال|لل|وال|بال)\s+(?=\S)'
33
  text = re.sub(prefix_pattern, r'\1\2', text)
34
- text = re.sub(prefix_pattern, r'\1\2', text)
35
- # Glue Punctuation (Previous word)
36
- punctuation_pattern = r'\s+([،؟!.,])'
37
- text = re.sub(punctuation_pattern, r'\1', text)
 
 
38
  return text.strip()
39
 
40
  class EndpointHandler:
41
  def __init__(self, path=""):
 
42
  config = AutoConfig.from_pretrained(path)
43
  self.model = AutoModelForSeq2SeqLM.from_config(config)
44
  replace_layernorm_with_rmsnorm(self.model)
45
 
 
46
  try:
 
 
 
 
 
47
  from safetensors.torch import load_file
48
  import os
49
  w_path = os.path.join(path, "model.safetensors")
50
  if os.path.exists(w_path):
51
- self.model.load_state_dict(load_file(w_path), strict=False) #Must be false
52
  else:
53
- self.model.load_state_dict(torch.load(os.path.join(path, "pytorch_model.bin"), map_location="cpu"), strict=True)
54
- except:
55
- # Fallback
56
- self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
57
- replace_layernorm_with_rmsnorm(self.model)
58
 
59
  self.tokenizer = AutoTokenizer.from_pretrained(path)
60
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -64,12 +73,24 @@ class EndpointHandler:
64
  inputs = data.pop("inputs", data)
65
  if isinstance(inputs, str): inputs = [inputs]
66
 
 
67
  tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.device)
68
- if "token_type_ids" in tokenized_inputs: del tokenized_inputs["token_type_ids"]
 
 
 
69
 
 
70
  with torch.no_grad():
71
- generated_ids = self.model.generate(**tokenized_inputs, max_new_tokens=128, num_beams=5, early_stopping=True)
 
 
 
 
 
72
 
 
73
  decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
74
  final_outputs = [fix_arabic_output(text) for text in decoded_outputs]
75
- return [{"generated_text": text} for text in final_outputs]
 
 
5
  import torch.nn as nn
6
  import re
7
 
8
+ # 1. Architecture Patch (RMSNorm)
9
  class RMSNorm(nn.Module):
10
  def __init__(self, dim: int, eps: float = 1e-6):
11
  super().__init__()
 
25
  else:
26
  replace_layernorm_with_rmsnorm(child)
27
 
28
+ # 2. Glue Logic (Prefixes + Punctuation)
29
  def fix_arabic_output(text):
30
  if not text: return text
31
+
32
+ # A. Glue Prefixes (Al, Lil, Wa-Al, Bi-Al)
33
  prefix_pattern = r'(^|\s)(ال|لل|وال|بال)\s+(?=\S)'
34
  text = re.sub(prefix_pattern, r'\1\2', text)
35
+ text = re.sub(prefix_pattern, r'\1\2', text) # Twice for safety
36
+
37
+ # B. Glue Punctuation (Remove space before punctuation)
38
+ punctuation_marks = r'[،؟!.,]'
39
+ text = re.sub(r'\s+(' + punctuation_marks + ')', r'\1', text)
40
+
41
  return text.strip()
42
 
43
  class EndpointHandler:
44
  def __init__(self, path=""):
45
+ # Load Config & Skeleton
46
  config = AutoConfig.from_pretrained(path)
47
  self.model = AutoModelForSeq2SeqLM.from_config(config)
48
  replace_layernorm_with_rmsnorm(self.model)
49
 
50
+ # Load Weights Safely
51
  try:
52
+ # Try standard load first
53
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
54
+ replace_layernorm_with_rmsnorm(self.model)
55
+ except:
56
+ # Fallback: Load state dict manually with strict=False
57
  from safetensors.torch import load_file
58
  import os
59
  w_path = os.path.join(path, "model.safetensors")
60
  if os.path.exists(w_path):
61
+ state_dict = load_file(w_path)
62
  else:
63
+ state_dict = torch.load(os.path.join(path, "pytorch_model.bin"), map_location="cpu")
64
+
65
+ # --- SETTING STRICT=FALSE AS REQUESTED ---
66
+ self.model.load_state_dict(state_dict, strict=False)
 
67
 
68
  self.tokenizer = AutoTokenizer.from_pretrained(path)
69
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
73
  inputs = data.pop("inputs", data)
74
  if isinstance(inputs, str): inputs = [inputs]
75
 
76
+ # Tokenize
77
  tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.device)
78
+
79
+ # Remove harmful args
80
+ if "token_type_ids" in tokenized_inputs:
81
+ del tokenized_inputs["token_type_ids"]
82
 
83
+ # Generate
84
  with torch.no_grad():
85
+ generated_ids = self.model.generate(
86
+ **tokenized_inputs,
87
+ max_new_tokens=128,
88
+ num_beams=5,
89
+ early_stopping=True
90
+ )
91
 
92
+ # Decode & Fix
93
  decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
94
  final_outputs = [fix_arabic_output(text) for text in decoded_outputs]
95
+
96
+ return [{"generated_text": text} for text in final_outputs]