Spaces:
Sleeping
Sleeping
| from .base import BaseVideoModel | |
| from packaging import version | |
| import torch | |
| from typing import Optional, Union, Dict | |
# IMPORTANT: minimum library versions required by this package are declared here.
transformers_required_version = version.parse("5.0.0")

import transformers
from transformers import BitsAndBytesConfig

# Parsed version of the transformers package that is actually installed.
transformers_version = version.parse(transformers.__version__)

# Guard clause: importing this module fails with ValueError when the installed
# transformers release is older than the required one.
if transformers_version < transformers_required_version:
    raise ValueError(f"Transformers v5 models require transformers>=5.0.0, but found {transformers.__version__}. Transformers v5 models will not be available. Please upgrade to transformers>=5.0.0 or switch conda environments to use Transformers v5 models.")

# Only reached when the version check above passed.
from .qwen2_5vl import Qwen2_5VLModel
from .qwen3vl import Qwen3VLModel
from .internvl import InternVLModel
from .llava_video import LLaVAVideoModel

TRANSFORMERS_MODELS_AVAILABLE = True
# Map a model path to the wrapper class that loads it, then instantiate it.
def load_model(
    model_path: str,
    dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
    device_map: Optional[Union[str, Dict]] = "auto",
    attn_implementation: Optional[str] = "flash_attention_2",
    load_8bit: Optional[bool] = False,
    load_4bit: Optional[bool] = False,
) -> BaseVideoModel:
    """Instantiate the video-model wrapper that matches ``model_path``.

    The wrapper class is picked by case-sensitive substring checks on
    ``model_path``, evaluated in this order:

    1. ``"LLaVA-Video"`` -> ``LLaVAVideoModel``
    2. ``"Qwen"`` -> ``Qwen3VLModel`` when the path also contains
       ``"Qwen3"``, otherwise ``Qwen2_5VLModel``
    3. ``"Intern"`` -> ``InternVLModel``

    Args:
        model_path: Model identifier or path; also passed to the wrapper as
            its first positional argument.
        dtype: Forwarded unchanged to the wrapper constructor.
        device_map: Forwarded unchanged to the wrapper constructor.
        attn_implementation: Forwarded unchanged to the wrapper constructor.
        load_8bit: Forwarded unchanged to the wrapper constructor.
        load_4bit: Forwarded unchanged to the wrapper constructor.

    Returns:
        The constructed wrapper instance.

    Raises:
        ValueError: If ``model_path`` matches none of the supported families.
            Without this check the function would return ``None`` and the
            failure would only show up later at the first use of the model.
    """
    if "LLaVA-Video" in model_path:
        model_cls = LLaVAVideoModel
    elif "Qwen" in model_path:
        # "Qwen3" must be tested inside the generic "Qwen" branch so that
        # every other Qwen path keeps falling back to the Qwen2.5-VL wrapper.
        model_cls = Qwen3VLModel if "Qwen3" in model_path else Qwen2_5VLModel
    elif "Intern" in model_path:
        model_cls = InternVLModel
    else:
        raise ValueError(
            f"Unsupported model path {model_path!r}: expected it to contain "
            "one of 'LLaVA-Video', 'Qwen' or 'Intern'."
        )

    # All wrappers are constructed with the same arguments, so build once here.
    return model_cls(
        model_path,
        dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_implementation,
        load_8bit=load_8bit,
        load_4bit=load_4bit,
    )