Update modeling_markupdm.py
Browse files- modeling_markupdm.py +3 -3
modeling_markupdm.py
CHANGED
|
@@ -105,12 +105,12 @@ class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin): # type: ignore
|
|
| 105 |
def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
|
| 106 |
assert "config" in kwargs, "Config must be provided"
|
| 107 |
config = kwargs["config"]
|
| 108 |
-
|
| 109 |
|
| 110 |
# Initialize text model
|
| 111 |
text_model = AutoModelForCausalLM.from_config(
|
| 112 |
config.text_model,
|
| 113 |
-
|
| 114 |
attn_implementation=config._attn_implementation,
|
| 115 |
)
|
| 116 |
|
|
@@ -119,7 +119,7 @@ class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin): # type: ignore
|
|
| 119 |
vision_model = AutoModel.from_config(
|
| 120 |
config.vision_model,
|
| 121 |
trust_remote_code=True,
|
| 122 |
-
|
| 123 |
)
|
| 124 |
|
| 125 |
return super().from_pretrained( # type: ignore
|
|
|
|
| 105 |
def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
|
| 106 |
assert "config" in kwargs, "Config must be provided"
|
| 107 |
config = kwargs["config"]
|
| 108 |
+
dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))
|
| 109 |
|
| 110 |
# Initialize text model
|
| 111 |
text_model = AutoModelForCausalLM.from_config(
|
| 112 |
config.text_model,
|
| 113 |
+
dtype=dtype,
|
| 114 |
attn_implementation=config._attn_implementation,
|
| 115 |
)
|
| 116 |
|
|
|
|
| 119 |
vision_model = AutoModel.from_config(
|
| 120 |
config.vision_model,
|
| 121 |
trust_remote_code=True,
|
| 122 |
+
dtype=dtype,
|
| 123 |
)
|
| 124 |
|
| 125 |
return super().from_pretrained( # type: ignore
|