fix-version-problems
#2
by
moritzknaust - opened
- modeling_prismatic.py +4 -3
modeling_prismatic.py
CHANGED
|
@@ -41,7 +41,7 @@ IGNORE_INDEX = -100
|
|
| 41 |
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
| 42 |
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
| 43 |
result = fn(*args, **kwargs)
|
| 44 |
-
return result[0] if isinstance(result, tuple) else result
|
| 45 |
|
| 46 |
return wrapper
|
| 47 |
|
|
@@ -207,7 +207,8 @@ class PrismaticPreTrainedModel(PreTrainedModel):
|
|
| 207 |
@property
|
| 208 |
def _supports_sdpa(self) -> bool:
|
| 209 |
"""Check LLM supports SDPA Attention"""
|
| 210 |
-
return self.language_model._supports_sdpa
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
@@ -219,7 +220,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
| 219 |
raise ValueError("Missing config field `use_fused_vision_backbone`")
|
| 220 |
|
| 221 |
if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
|
| 222 |
-
|
| 223 |
"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
|
| 224 |
"if you urgently need support for latest TIMM versions."
|
| 225 |
)
|
|
|
|
| 41 |
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
| 42 |
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
| 43 |
result = fn(*args, **kwargs)
|
| 44 |
+
return result[0] if (isinstance(result, tuple) or isinstance(result, list)) else result
|
| 45 |
|
| 46 |
return wrapper
|
| 47 |
|
|
|
|
| 207 |
@property
|
| 208 |
def _supports_sdpa(self) -> bool:
|
| 209 |
"""Check LLM supports SDPA Attention"""
|
| 210 |
+
# TODO(moritzknaust): This is a hack to replace the original "return self.language_model._supports_sdpa"
|
| 211 |
+
return False
|
| 212 |
|
| 213 |
|
| 214 |
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
|
|
| 220 |
raise ValueError("Missing config field `use_fused_vision_backbone`")
|
| 221 |
|
| 222 |
if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
|
| 223 |
+
logger.warning(
|
| 224 |
"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
|
| 225 |
"if you urgently need support for latest TIMM versions."
|
| 226 |
)
|