ktrk115 commited on
Commit
248b260
·
verified ·
1 Parent(s): 74a5657

Update modeling_markupdm.py

Browse files
Files changed (1) hide show
  1. modeling_markupdm.py +3 -3
modeling_markupdm.py CHANGED
@@ -105,12 +105,12 @@ class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin): # type: ignore
105
  def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
106
  assert "config" in kwargs, "Config must be provided"
107
  config = kwargs["config"]
108
- torch_dtype = kwargs.get("torch_dtype", None)
109
 
110
  # Initialize text model
111
  text_model = AutoModelForCausalLM.from_config(
112
  config.text_model,
113
- torch_dtype=torch_dtype,
114
  attn_implementation=config._attn_implementation,
115
  )
116
 
@@ -119,7 +119,7 @@ class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin): # type: ignore
119
  vision_model = AutoModel.from_config(
120
  config.vision_model,
121
  trust_remote_code=True,
122
- torch_dtype=torch_dtype,
123
  )
124
 
125
  return super().from_pretrained( # type: ignore
 
105
  def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
106
  assert "config" in kwargs, "Config must be provided"
107
  config = kwargs["config"]
108
+ dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))
109
 
110
  # Initialize text model
111
  text_model = AutoModelForCausalLM.from_config(
112
  config.text_model,
113
+ dtype=dtype,
114
  attn_implementation=config._attn_implementation,
115
  )
116
 
 
119
  vision_model = AutoModel.from_config(
120
  config.vision_model,
121
  trust_remote_code=True,
122
+ dtype=dtype,
123
  )
124
 
125
  return super().from_pretrained( # type: ignore