Omar committed on
Commit
98e130a
·
1 Parent(s): ec02a3e

Upgrade transformers and add fallbacks

Browse files
backend/architecture_parser.py CHANGED
@@ -11,6 +11,14 @@ import torch
11
  import torch.nn as nn
12
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
13
 
 
 
 
 
 
 
 
 
14
 
15
  def format_params(count: int) -> str:
16
  """Format parameter count in human-readable form."""
@@ -623,22 +631,31 @@ def load_model_for_inspection(model_id: str) -> Tuple[nn.Module, AutoConfig]:
623
  model = None
624
  errors = []
625
 
626
- try:
627
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
628
- except Exception as e:
629
- errors.append(f"CausalLM: {e}")
630
-
631
- if model is None:
632
- try:
633
- model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
634
- except Exception as e:
635
- errors.append(f"Seq2SeqLM: {e}")
 
 
 
 
 
 
 
636
 
637
- if model is None:
638
  try:
639
- model = AutoModel.from_config(config, trust_remote_code=True)
 
 
640
  except Exception as e:
641
- errors.append(f"AutoModel: {e}")
642
 
643
  if model is None:
644
  raise ValueError(f"Could not load model architecture. Errors: {errors}")
@@ -902,22 +919,31 @@ def load_model_from_config(config_dict: Dict[str, Any]) -> Tuple[nn.Module, Auto
902
  model = None
903
  errors = []
904
 
905
- try:
906
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
907
- except Exception as e:
908
- errors.append(f"CausalLM: {e}")
909
-
910
- if model is None:
911
- try:
912
- model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
913
- except Exception as e:
914
- errors.append(f"Seq2SeqLM: {e}")
 
 
 
 
 
 
 
915
 
916
- if model is None:
917
  try:
918
- model = AutoModel.from_config(config, trust_remote_code=True)
 
 
919
  except Exception as e:
920
- errors.append(f"AutoModel: {e}")
921
 
922
  if model is None:
923
  raise ValueError(f"Could not load model from config. Errors: {errors}")
 
11
  import torch.nn as nn
12
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
13
 
14
+ # Monkeypatch for transformers import issues in some environment/model combinations
15
+ try:
16
+ import transformers.utils.import_utils as import_utils
17
+ if not hasattr(import_utils, "is_torch_fx_available"):
18
+ import_utils.is_torch_fx_available = lambda: False
19
+ except (ImportError, AttributeError):
20
+ pass
21
+
22
 
23
  def format_params(count: int) -> str:
24
  """Format parameter count in human-readable form."""
 
631
  model = None
632
  errors = []
633
 
634
+ # Try to guess the model class from config
635
+ archs = getattr(config, "architectures", [])
636
+ is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
637
+
638
+ # Determine order of AutoModel classes to try
639
+ if is_encoder_decoder or any("Seq2Seq" in a or "ConditionalGeneration" in a for a in archs):
640
+ model_classes = [
641
+ (AutoModelForSeq2SeqLM, "Seq2SeqLM"),
642
+ (AutoModelForCausalLM, "CausalLM"),
643
+ (AutoModel, "AutoModel")
644
+ ]
645
+ else:
646
+ model_classes = [
647
+ (AutoModelForCausalLM, "CausalLM"),
648
+ (AutoModel, "AutoModel"),
649
+ (AutoModelForSeq2SeqLM, "Seq2SeqLM")
650
+ ]
651
 
652
+ for model_class, label in model_classes:
653
  try:
654
+ model = model_class.from_config(config, trust_remote_code=True)
655
+ if model is not None:
656
+ break
657
  except Exception as e:
658
+ errors.append(f"{label}: {e}")
659
 
660
  if model is None:
661
  raise ValueError(f"Could not load model architecture. Errors: {errors}")
 
919
  model = None
920
  errors = []
921
 
922
+ # Try to guess the model class from config
923
+ archs = getattr(config, "architectures", [])
924
+ is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
925
+
926
+ # Determine order of AutoModel classes to try
927
+ if is_encoder_decoder or any("Seq2Seq" in a or "ConditionalGeneration" in a for a in archs):
928
+ model_classes = [
929
+ (AutoModelForSeq2SeqLM, "Seq2SeqLM"),
930
+ (AutoModelForCausalLM, "CausalLM"),
931
+ (AutoModel, "AutoModel")
932
+ ]
933
+ else:
934
+ model_classes = [
935
+ (AutoModelForCausalLM, "CausalLM"),
936
+ (AutoModel, "AutoModel"),
937
+ (AutoModelForSeq2SeqLM, "Seq2SeqLM")
938
+ ]
939
 
940
+ for model_class, label in model_classes:
941
  try:
942
+ model = model_class.from_config(config, trust_remote_code=True)
943
+ if model is not None:
944
+ break
945
  except Exception as e:
946
+ errors.append(f"{label}: {e}")
947
 
948
  if model is None:
949
  raise ValueError(f"Could not load model from config. Errors: {errors}")
backend/requirements.txt CHANGED
@@ -3,7 +3,7 @@ uvicorn[standard]==0.27.0
3
  httpx==0.26.0
4
  pydantic==2.5.3
5
  python-multipart==0.0.6
6
- transformers>=4.36.0
7
  torch>=2.0.0
8
  accelerate>=0.25.0
9
  huggingface_hub>=0.20.0
 
3
  httpx==0.26.0
4
  pydantic==2.5.3
5
  python-multipart==0.0.6
6
+ transformers>=4.54.0
7
  torch>=2.0.0
8
  accelerate>=0.25.0
9
  huggingface_hub>=0.20.0