thebajajra committed
Commit dec3b0c · verified · 1 parent: 3e283cb

Upload folder using huggingface_hub

Files changed (3):
  1. config.json (+7 -6)
  2. model.safetensors (+2 -2)
  3. modeling_gemma3_biencoder.py (+29 -49)
config.json CHANGED
@@ -1,19 +1,19 @@
 {
   "_sliding_window_pattern": 6,
   "architectures": [
-    "Gemma3EncoderForMaskedLM"
+    "Gemma3EncoderModel"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
   "attn_logit_softcapping": null,
   "auto_map": {
-    "AutoModel": "modeling_gemma3_biencoder.Gemma3EncoderForMaskedLM",
+    "AutoModel": "modeling_gemma3_biencoder.Gemma3EncoderModel",
     "AutoModelForMaskedLM": "modeling_gemma3_biencoder.Gemma3EncoderForMaskedLM",
     "AutoModelForSequenceClassification": "modeling_gemma3_biencoder.Gemma3EncoderForSequenceClassification",
     "AutoModelForTokenClassification": "modeling_gemma3_biencoder.Gemma3EncoderForTokenClassification"
   },
   "bos_token_id": 2,
-  "dtype": "bfloat16",
+  "dtype": "float32",
   "eos_token_id": 1,
   "final_logit_softcapping": null,
   "head_dim": 256,
@@ -41,7 +41,7 @@
     "sliding_attention",
     "full_attention"
   ],
-  "max_position_embeddings": 2048,
+  "max_position_embeddings": 32768,
   "model_type": "gemma3_text",
   "num_attention_heads": 4,
   "num_hidden_layers": 18,
@@ -56,5 +56,6 @@
   "transformers_version": "4.57.3",
   "use_bidirectional_attention": true,
   "use_cache": false,
-  "vocab_size": 262145
-}
+  "vocab_size": 262145,
+  "attn_implementation": null
+}
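
Note: with the updated auto_map, AutoModel now resolves to the encoder-only Gemma3EncoderModel, while AutoModelForMaskedLM keeps pointing at the class with the MLM head. A minimal loading sketch ("<repo-id>" is a placeholder for this model's actual Hub id, not something this commit names):

    from transformers import AutoModel, AutoTokenizer

    # "<repo-id>" is a placeholder for this model's Hub id.
    # trust_remote_code=True is required because Gemma3EncoderModel is defined
    # in modeling_gemma3_biencoder.py inside the repo, not in transformers.
    tokenizer = AutoTokenizer.from_pretrained("<repo-id>")
    model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
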
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b85687659049cd09b6da529da4b4190ba7c00528c17d7065ce8e40ac850a33da
-size 536224808
+oid sha256:e9dc0cc7558bc128b11280fbdbacf630a260a637110ad69d3de2f03ca9650093
+size 1072422288
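
The checkpoint roughly doubles in size (536,224,808 to 1,072,422,288 bytes), consistent with the config's dtype change from bfloat16 (2 bytes per parameter) to float32 (4 bytes per parameter) for the same ~268M parameters. A quick sketch to verify this locally, assuming the file has been downloaded as model.safetensors:

    from safetensors.torch import load_file

    # Load the new checkpoint and confirm every tensor is stored in float32;
    # the ~2x size increase matches going from 2 to 4 bytes per parameter.
    state = load_file("model.safetensors")
    print({tensor.dtype for tensor in state.values()})        # expected: {torch.float32}
    print(sum(tensor.numel() for tensor in state.values()))   # ~268M parameters
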
modeling_gemma3_biencoder.py CHANGED
@@ -11,61 +11,41 @@ from transformers.models.gemma3.modeling_gemma3 import (
     Gemma3TextModel,
 )
 
+class Gemma3EncoderModel(Gemma3PreTrainedModel):
+    config_class = Gemma3TextConfig
+    base_model_prefix = "encoder"
+
+    def __init__(self, config):
+        cfg = copy.deepcopy(config)
+        if hasattr(cfg, "use_bidirectional_attention"):
+            cfg.use_bidirectional_attention = True
+        cfg.use_cache = False
+        super().__init__(cfg)
+        self.encoder = Gemma3TextModel(cfg)
+        self.post_init()
+
+    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
+                inputs_embeds=None, output_attentions=None, output_hidden_states=None,
+                return_dict=True, **kwargs):
+        return self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=False,
+            is_causal=False,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
 class Gemma3EncoderForMaskedLM(Gemma3PreTrainedModel):
     config_class = Gemma3TextConfig
     base_model_prefix = "encoder"
     _tied_weights_keys = ["lm_head.weight"]
     _keys_to_ignore_on_load_missing = [r"lm_head\.weight"]
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        """Override to preserve sliding_window from config.json."""
-        import json
-        import os
-
-        # Read original sliding_window from config.json before it gets modified
-        original_sliding_window = None
-        try:
-            # Use transformers utility to resolve config path (handles both local and Hub)
-            from transformers.utils import CONFIG_NAME
-            from transformers.utils.hub import cached_file
-
-            config_path = None
-            if os.path.isdir(pretrained_model_name_or_path):
-                # Local path
-                config_path = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-            else:
-                # Hub path - this will download/cache if needed
-                try:
-                    config_path = cached_file(
-                        pretrained_model_name_or_path,
-                        CONFIG_NAME,
-                        cache_dir=kwargs.get("cache_dir"),
-                        force_download=kwargs.get("force_download", False),
-                        resume_download=kwargs.get("resume_download", False),
-                    )
-                except Exception:
-                    pass
-
-            if config_path and os.path.exists(config_path):
-                with open(config_path, "r", encoding="utf-8") as f:
-                    config_dict = json.load(f)
-                original_sliding_window = config_dict.get("sliding_window")
-        except Exception:
-            # If we can't read the config, continue anyway
-            pass
-
-        # Load model normally
-        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-
-        # Restore original sliding_window if it was modified by Gemma3TextModel
-        if original_sliding_window is not None:
-            current_sw = getattr(model.config, "sliding_window", None)
-            if current_sw != original_sliding_window:
-                model.config.sliding_window = original_sliding_window
-
-        return model
-
     def __init__(self, config: Gemma3TextConfig):
         cfg = copy.deepcopy(config)
         if hasattr(cfg, "use_bidirectional_attention"):
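
The new Gemma3EncoderModel is a thin wrapper around Gemma3TextModel with bidirectional attention forced on (is_causal=False) and KV caching disabled. A minimal usage sketch for pulling sentence embeddings out of it ("<repo-id>" is again a placeholder Hub id, and mean pooling is just one common choice, not something this commit prescribes):

    import torch
    from transformers import AutoModel, AutoTokenizer

    # "<repo-id>" is a placeholder for this model's Hub id.
    tokenizer = AutoTokenizer.from_pretrained("<repo-id>")
    encoder = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
    encoder.eval()

    batch = tokenizer(["A bidirectional Gemma3 encoder."], return_tensors="pt")
    with torch.no_grad():
        outputs = encoder(**batch)

    # Every token attends to the full sequence (is_causal=False), so the last
    # hidden states can be mean-pooled over non-padding tokens into one vector.
    hidden = outputs.last_hidden_state                        # (batch, seq, hidden)
    mask = batch["attention_mask"].unsqueeze(-1).float()      # (batch, seq, 1)
    embedding = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, hidden)
    print(embedding.shape)
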