Navya-Sree commited on
Commit
ac649df
·
verified ·
1 Parent(s): b4768ab

Update cultural_model.py

Browse files
Files changed (1) hide show
  1. cultural_model.py +11 -12
cultural_model.py CHANGED
@@ -1,24 +1,23 @@
1
  from transformers import M2M100ForConditionalGeneration
2
 
3
  class CulturalM2M100(M2M100ForConditionalGeneration):
4
- """
5
- Custom M2M100 model with cultural preservation features
6
- Inherits from M2M100ForConditionalGeneration and adds:
7
- - Cultural context tokens
8
- - Special generation parameters
9
- """
10
 
11
  def __init__(self, config):
12
  super().__init__(config)
13
- self.cultural_token_id = 250001 # Should match tokenizer
14
-
15
  def generate(self, *args, **kwargs):
16
- """Override generate method with cultural preservation"""
17
- if kwargs.pop("cultural_preservation", False):
 
 
 
18
  kwargs["forced_bos_token_id"] = self.get_cultural_token()
19
- kwargs["max_length"] = kwargs.get("max_length", 200) # Longer for context
 
20
  return super().generate(*args, **kwargs)
21
 
22
  def get_cultural_token(self):
23
- """Get the special cultural context token ID"""
24
  return self.cultural_token_id
 
1
  from transformers import M2M100ForConditionalGeneration
2
 
3
  class CulturalM2M100(M2M100ForConditionalGeneration):
4
+ """Custom model with cultural preservation features"""
 
 
 
 
 
5
 
6
  def __init__(self, config):
7
  super().__init__(config)
8
+ self.cultural_token_id = 250001 # Must match tokenizer
9
+
10
  def generate(self, *args, **kwargs):
11
+ """Override generation with cultural context"""
12
+ cultural_preservation = kwargs.pop("cultural_preservation", False)
13
+
14
+ if cultural_preservation:
15
+ # Force cultural context token
16
  kwargs["forced_bos_token_id"] = self.get_cultural_token()
17
+ kwargs["max_length"] = kwargs.get("max_length", 512) + 50
18
+
19
  return super().generate(*args, **kwargs)
20
 
21
  def get_cultural_token(self):
22
+ """Get cultural preservation token ID"""
23
  return self.cultural_token_id