Yong Liu committed on
Commit
fe3660d
·
1 Parent(s): d36359f

update handler

Browse files
Files changed (1) hide show
  1. handler.py +20 -15
handler.py CHANGED
@@ -3,30 +3,35 @@ import json
3
  import torch
4
  from transformers import pipeline, AutoTokenizer, AutoConfig
5
  from typing import Dict, List, Any, Optional, Union
 
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
  # Initialize model and tokenizer
10
  self.model_path = path if path else os.environ.get("MODEL_PATH", "")
11
 
12
- # Fix RoPE scaling configuration
13
  try:
14
- config = AutoConfig.from_pretrained(self.model_path)
 
15
 
16
- # Check if config has rope_scaling attribute and fix the short_factor length
17
- if hasattr(config, "rope_scaling") and "short_factor" in config.rope_scaling:
18
- short_factor = config.rope_scaling["short_factor"]
19
- if len(short_factor) == 48: # If we have the problematic length
20
- print("Fixing rope_scaling short_factor length from 48 to 64")
21
- # Pad to length 64
22
- padded_short_factor = list(short_factor) + [0.0] * (64 - len(short_factor))
23
- config.rope_scaling["short_factor"] = padded_short_factor
24
-
25
- # Save the fixed config
26
- config.save_pretrained(self.model_path)
27
- print("Fixed config saved")
 
 
 
28
  except Exception as e:
29
- print(f"Warning: Could not fix RoPE scaling configuration: {str(e)}")
30
 
31
  # Load tokenizer
32
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
 
3
  import torch
4
  from transformers import pipeline, AutoTokenizer, AutoConfig
5
  from typing import Dict, List, Any, Optional, Union
6
+ import functools
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
  # Initialize model and tokenizer
11
  self.model_path = path if path else os.environ.get("MODEL_PATH", "")
12
 
13
+ # Monkey patch the RoPE scaling validation to bypass the length check
14
  try:
15
+ from transformers.models.phi3.configuration_phi3 import Phi3Config
16
+ original_validation = Phi3Config._rope_scaling_validation
17
 
18
+ # Create a patched version that doesn't validate length
19
+ @functools.wraps(original_validation)
20
+ def patched_validation(self_config):
21
+ # Skip validation if short_factor length is 48
22
+ if (hasattr(self_config, "rope_scaling") and
23
+ "short_factor" in self_config.rope_scaling and
24
+ len(self_config.rope_scaling["short_factor"]) == 48):
25
+ print("Bypassing RoPE scaling validation for short_factor of length 48")
26
+ return
27
+ # Otherwise call the original validation
28
+ return original_validation(self_config)
29
+
30
+ # Apply the monkey patch
31
+ Phi3Config._rope_scaling_validation = patched_validation
32
+ print("Successfully patched RoPE scaling validation")
33
  except Exception as e:
34
+ print(f"Warning: Could not patch RoPE scaling validation: {str(e)}")
35
 
36
  # Load tokenizer
37
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)