Spaces:
Running
Running
Omar
committed on
Commit
·
98e130a
1
Parent(s):
ec02a3e
Upgrade transformers and add fallbacks
Browse files
- backend/architecture_parser.py +52 -26
- backend/requirements.txt +1 -1
backend/architecture_parser.py
CHANGED
|
@@ -11,6 +11,14 @@ import torch
|
|
| 11 |
import torch.nn as nn
|
| 12 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def format_params(count: int) -> str:
|
| 16 |
"""Format parameter count in human-readable form."""
|
|
@@ -623,22 +631,31 @@ def load_model_for_inspection(model_id: str) -> Tuple[nn.Module, AutoConfig]:
|
|
| 623 |
model = None
|
| 624 |
errors = []
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
if
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
-
|
| 638 |
try:
|
| 639 |
-
model =
|
|
|
|
|
|
|
| 640 |
except Exception as e:
|
| 641 |
-
errors.append(f"
|
| 642 |
|
| 643 |
if model is None:
|
| 644 |
raise ValueError(f"Could not load model architecture. Errors: {errors}")
|
|
@@ -902,22 +919,31 @@ def load_model_from_config(config_dict: Dict[str, Any]) -> Tuple[nn.Module, Auto
|
|
| 902 |
model = None
|
| 903 |
errors = []
|
| 904 |
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
if
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 915 |
|
| 916 |
-
|
| 917 |
try:
|
| 918 |
-
model =
|
|
|
|
|
|
|
| 919 |
except Exception as e:
|
| 920 |
-
errors.append(f"
|
| 921 |
|
| 922 |
if model is None:
|
| 923 |
raise ValueError(f"Could not load model from config. Errors: {errors}")
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 13 |
|
| 14 |
+
# Monkeypatch for transformers import issues in some environment/model combinations
|
| 15 |
+
try:
|
| 16 |
+
import transformers.utils.import_utils as import_utils
|
| 17 |
+
if not hasattr(import_utils, "is_torch_fx_available"):
|
| 18 |
+
import_utils.is_torch_fx_available = lambda: False
|
| 19 |
+
except (ImportError, AttributeError):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
|
| 23 |
def format_params(count: int) -> str:
|
| 24 |
"""Format parameter count in human-readable form."""
|
|
|
|
| 631 |
model = None
|
| 632 |
errors = []
|
| 633 |
|
| 634 |
+
# Try to guess the model class from config
|
| 635 |
+
archs = getattr(config, "architectures", [])
|
| 636 |
+
is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
|
| 637 |
+
|
| 638 |
+
# Determine order of AutoModel classes to try
|
| 639 |
+
if is_encoder_decoder or any("Seq2Seq" in a or "ConditionalGeneration" in a for a in archs):
|
| 640 |
+
model_classes = [
|
| 641 |
+
(AutoModelForSeq2SeqLM, "Seq2SeqLM"),
|
| 642 |
+
(AutoModelForCausalLM, "CausalLM"),
|
| 643 |
+
(AutoModel, "AutoModel")
|
| 644 |
+
]
|
| 645 |
+
else:
|
| 646 |
+
model_classes = [
|
| 647 |
+
(AutoModelForCausalLM, "CausalLM"),
|
| 648 |
+
(AutoModel, "AutoModel"),
|
| 649 |
+
(AutoModelForSeq2SeqLM, "Seq2SeqLM")
|
| 650 |
+
]
|
| 651 |
|
| 652 |
+
for model_class, label in model_classes:
|
| 653 |
try:
|
| 654 |
+
model = model_class.from_config(config, trust_remote_code=True)
|
| 655 |
+
if model is not None:
|
| 656 |
+
break
|
| 657 |
except Exception as e:
|
| 658 |
+
errors.append(f"{label}: {e}")
|
| 659 |
|
| 660 |
if model is None:
|
| 661 |
raise ValueError(f"Could not load model architecture. Errors: {errors}")
|
|
|
|
| 919 |
model = None
|
| 920 |
errors = []
|
| 921 |
|
| 922 |
+
# Try to guess the model class from config
|
| 923 |
+
archs = getattr(config, "architectures", [])
|
| 924 |
+
is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
|
| 925 |
+
|
| 926 |
+
# Determine order of AutoModel classes to try
|
| 927 |
+
if is_encoder_decoder or any("Seq2Seq" in a or "ConditionalGeneration" in a for a in archs):
|
| 928 |
+
model_classes = [
|
| 929 |
+
(AutoModelForSeq2SeqLM, "Seq2SeqLM"),
|
| 930 |
+
(AutoModelForCausalLM, "CausalLM"),
|
| 931 |
+
(AutoModel, "AutoModel")
|
| 932 |
+
]
|
| 933 |
+
else:
|
| 934 |
+
model_classes = [
|
| 935 |
+
(AutoModelForCausalLM, "CausalLM"),
|
| 936 |
+
(AutoModel, "AutoModel"),
|
| 937 |
+
(AutoModelForSeq2SeqLM, "Seq2SeqLM")
|
| 938 |
+
]
|
| 939 |
|
| 940 |
+
for model_class, label in model_classes:
|
| 941 |
try:
|
| 942 |
+
model = model_class.from_config(config, trust_remote_code=True)
|
| 943 |
+
if model is not None:
|
| 944 |
+
break
|
| 945 |
except Exception as e:
|
| 946 |
+
errors.append(f"{label}: {e}")
|
| 947 |
|
| 948 |
if model is None:
|
| 949 |
raise ValueError(f"Could not load model from config. Errors: {errors}")
|
backend/requirements.txt
CHANGED
|
@@ -3,7 +3,7 @@ uvicorn[standard]==0.27.0
|
|
| 3 |
httpx==0.26.0
|
| 4 |
pydantic==2.5.3
|
| 5 |
python-multipart==0.0.6
|
| 6 |
-
transformers>=4.
|
| 7 |
torch>=2.0.0
|
| 8 |
accelerate>=0.25.0
|
| 9 |
huggingface_hub>=0.20.0
|
|
|
|
| 3 |
httpx==0.26.0
|
| 4 |
pydantic==2.5.3
|
| 5 |
python-multipart==0.0.6
|
| 6 |
+
transformers>=4.54.0
|
| 7 |
torch>=2.0.0
|
| 8 |
accelerate>=0.25.0
|
| 9 |
huggingface_hub>=0.20.0
|