Spaces:
Runtime error
Runtime error
| import torch | |
| import logging | |
| import onnxruntime as ort | |
| from time import time | |
| from typing import Union | |
| from configs import ModelConfig, InferenceConfig | |
| from utils import ( | |
| POSE_BASED_MODELS, | |
| RGB_BASED_MODELS, | |
| HUGGINGFACE_RGB_BASED_MODELS, | |
| TORCHHUB_RGB_BASED_MODELS, | |
| ) | |
| from transformers import ( | |
| ImageProcessingMixin, | |
| FeatureExtractionMixin, | |
| AutoModelForVideoClassification, | |
| AutoModel, | |
| Pipeline, | |
| pipeline, | |
| ) | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from visualization import draw_text_on_image | |
| from utils import exists_on_hf | |
| from models import ( | |
| Swin3DConfig, Swin3DImageProcessor, Swin3DForVideoClassification, | |
| S3DConfig, S3DImageProcessor, S3DForVideoClassification, | |
| VideoResNetConfig, VideoResNetImageProcessor, VideoResNetForVideoClassification, | |
| MViTConfig, MViTImageProcessor, MViTForVideoClassification, | |
| SLGCNConfig, SLGCNFeatureExtractor, SLGCNForGraphClassification, | |
| SPOTERConfig, SPOTERFeatureExtractor, SPOTERForGraphClassification, | |
| DSTASLRConfig, DSTASLRFeatureExtractor, DSTASLRForGraphClassification, | |
| VideoMAEConfig, VideoMAEImageProcessor, VideoMAEForVideoClassification | |
| ) | |
| from pipelines import ( | |
| VideoClassificationPipeline, | |
| SLGCNGraphClassificationPipeline, | |
| SPOTERGraphClassificationPipeline, | |
| ) | |
def load_model(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
    do_train: bool = False,
) -> tuple:
    '''
    Load a (config, processor, model) triple for inference or training.

    Parameters
    ----------
    model_config : ModelConfig
        Model configuration; `arch` selects the model family and
        `pretrained` names the checkpoint to pull from the Hugging Face Hub.
    label2id : dict, optional
        Mapping of gloss labels to class ids (used for training only).
    id2label : dict, optional
        Mapping of class ids to gloss labels (used for training only).
    do_train : bool, optional
        When True, delegate to the training loaders instead of loading a
        frozen inference model, by default False.

    Returns
    -------
    tuple
        (config, processor, model). The model is put in eval mode when
        loaded for inference.
    '''
    if do_train:
        if model_config.arch in POSE_BASED_MODELS:
            return load_pose_model_for_training(model_config, label2id, id2label)
        return load_rgb_model_for_training(model_config, label2id, id2label)
    # Inference path: pose models pair with a feature extractor + AutoModel,
    # RGB models with an image processor + AutoModelForVideoClassification.
    if model_config.arch in POSE_BASED_MODELS:
        processor = FeatureExtractionMixin.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        model = AutoModel.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
    else:
        processor = ImageProcessingMixin.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        model = AutoModelForVideoClassification.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
    model.eval()
    return model.config, processor, model
def load_rgb_model_for_training(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
) -> tuple:
    '''
    Build a (config, processor, model) triple for training an RGB-based model.

    Hub checkpoints (non-videomae Hugging Face archs that exist on the Hub)
    are loaded directly; otherwise the local config/processor/model classes
    for the architecture are instantiated and registered for auto-class use.

    Parameters
    ----------
    model_config : ModelConfig
        Model configuration; `arch` selects the class triple and its fields
        are forwarded to the config constructor.
    label2id : dict, optional
        Mapping of gloss labels to class ids.
    id2label : dict, optional
        Mapping of class ids to gloss labels.

    Returns
    -------
    tuple
        (config, processor, model) ready for training.
    '''
    if model_config.arch in HUGGINGFACE_RGB_BASED_MODELS:
        if model_config.arch == "videomae":
            config_class = VideoMAEConfig
            processor_class = VideoMAEImageProcessor
            model_class = VideoMAEForVideoClassification
        elif exists_on_hf(model_config.pretrained):
            processor = ImageProcessingMixin.from_pretrained(
                model_config.pretrained,
                trust_remote_code=True,
                cache_dir="models/huggingface",
            )
            # Pass the label maps as KEYWORDS: extra positional arguments to
            # `from_pretrained` are forwarded to the model constructor as
            # `*model_args` and would not be interpreted as label2id/id2label.
            # This also matches load_pose_model_for_training.
            model = AutoModelForVideoClassification.from_pretrained(
                model_config.pretrained,
                label2id=label2id,
                id2label=id2label,
                ignore_mismatched_sizes=True,
                trust_remote_code=True,
                cache_dir="models/huggingface",
            )
            return model.config, processor, model
    elif model_config.arch in TORCHHUB_RGB_BASED_MODELS:
        if model_config.arch in ['swin3d_t', 'swin3d_s', 'swin3d_b']:
            config_class = Swin3DConfig
            processor_class = Swin3DImageProcessor
            model_class = Swin3DForVideoClassification
        elif model_config.arch in ['r3d_18', 'mc3_18', 'r2plus1d_18']:
            config_class = VideoResNetConfig
            processor_class = VideoResNetImageProcessor
            model_class = VideoResNetForVideoClassification
        elif model_config.arch in ['s3d']:
            config_class = S3DConfig
            processor_class = S3DImageProcessor
            model_class = S3DForVideoClassification
        elif model_config.arch in ['mvit_v1_b', 'mvit_v2_s']:
            config_class = MViTConfig
            processor_class = MViTImageProcessor
            model_class = MViTForVideoClassification
        else:
            logging.error(f"Model {model_config.arch} is not supported")
            exit(1)
    # Register the chosen classes so AutoModel/AutoImageProcessor can resolve
    # them when the trained checkpoint is pushed to the Hub.
    config_class.register_for_auto_class()
    processor_class.register_for_auto_class("AutoImageProcessor")
    model_class.register_for_auto_class("AutoModel")
    model_class.register_for_auto_class("AutoModelForVideoClassification")
    logging.info(f"{model_config.arch} classes registered")
    config = config_class(**vars(model_config))
    processor = processor_class(config=config)
    model = model_class(config=config, label2id=label2id, id2label=id2label)
    return config, processor, model
def load_pose_model_for_training(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
) -> tuple:
    '''
    Build a (config, processor, model) triple for training a pose-based model.

    Checkpoints already published on the Hugging Face Hub are loaded
    directly; otherwise the matching local config/extractor/model classes
    are instantiated from `model_config` and registered for auto-class use.
    '''
    if exists_on_hf(model_config.pretrained):
        processor = FeatureExtractionMixin.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        model = AutoModel.from_pretrained(
            model_config.pretrained,
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        return model.config, processor, model
    elif model_config.arch in POSE_BASED_MODELS:
        # Dispatch table: architecture -> (config, feature extractor, model).
        class_triples = {
            "spoter": (SPOTERConfig, SPOTERFeatureExtractor, SPOTERForGraphClassification),
            "sl_gcn": (SLGCNConfig, SLGCNFeatureExtractor, SLGCNForGraphClassification),
            "dsta_slr": (DSTASLRConfig, DSTASLRFeatureExtractor, DSTASLRForGraphClassification),
        }
        if model_config.arch not in class_triples:
            logging.error(f"Model {model_config.arch} is not supported")
            exit(1)
        config_class, processor_class, model_class = class_triples[model_config.arch]
    # Register so AutoFeatureExtractor/AutoModel can resolve the custom classes.
    config_class.register_for_auto_class()
    processor_class.register_for_auto_class("AutoFeatureExtractor")
    model_class.register_for_auto_class("AutoModel")
    logging.info(f"Registering {model_config.arch} classes")
    config = config_class(**vars(model_config))
    processor = processor_class(config=config)
    model = model_class(config=config, label2id=label2id, id2label=id2label)
    return config, processor, model
class Predictions:
    '''
    Container for one sample's top-k predictions plus timing metadata.

    Attributes
    ----------
    predictions : list[dict] | None
        Each entry has 'gloss' (str) and 'score' (float in [0, 1]).
    inference_time : float
        Model forward-pass duration in seconds.
    start_time, end_time : float
        Sample boundaries within the source video, in seconds.
    '''

    def __init__(
        self,
        predictions: list[dict] = None,
        inference_time: float = 0,
        start_time: float = 0,
        end_time: float = 0,
    ) -> None:
        self.predictions = predictions
        self.inference_time = inference_time
        self.start_time = start_time
        self.end_time = end_time

    def visualize(
        self,
        frame: torch.Tensor,
        position: tuple = (20, 100),
        prefix: str = "Predictions",
        color: tuple = (0, 0, 255),
    ) -> None:
        '''Draw the prediction summary onto `frame` and return the result of
        `draw_text_on_image` (despite the None annotation kept for
        compatibility).'''
        text = prefix + ": " + self.get_pred_message()
        return draw_text_on_image(
            image=frame,
            text=text,
            position=position,
            color=color,
            font_size=20,
        )

    def get_pred_message(self) -> str:
        '''Return predictions formatted as "gloss (score%)" joined by commas,
        or "" when there are no predictions.'''
        # Guard on the predictions themselves: previously this only checked
        # that ALL fields were falsy, so a set inference_time with
        # predictions=None crashed when joining over None.
        if not self.predictions:
            return ""
        return ', '.join(
            f"{pred['gloss']} ({pred['score']*100:.2f}%)"
            for pred in self.predictions
        )

    def __str__(self) -> str:
        if not any((
            self.start_time,
            self.end_time,
            self.inference_time,
            self.predictions,
        )):
            return ""
        predictions = self.get_pred_message()
        message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
        return message.format(self.start_time, self.end_time, self.inference_time, predictions)

    def merge_results(self, results: dict = None) -> dict:
        '''Append this instance's fields to `results` (created when None) and
        return the accumulator dict.'''
        if results is None:
            results = {
                "start_time": [],
                "end_time": [],
                "inference_time": [],
                "prediction": [],
            }
        results["start_time"].append(self.start_time)
        results["end_time"].append(self.end_time)
        results["inference_time"].append(self.inference_time)
        results["prediction"].append(self.predictions)
        return results
def get_predictions(
    inputs: torch.Tensor,
    model: Union[ort.InferenceSession, AutoModel],
    id2gloss: dict,
    k: int = 3,
) -> Predictions:
    '''
    Get the top-k predictions for a single sample.

    Parameters
    ----------
    inputs : torch.Tensor
        Model inputs (Time, Height, Width, Channels); assumes batch size 1
        so the top-k tensors squeeze to 1-D.
    model : Union[ort.InferenceSession, AutoModel]
        Model to get predictions from (ONNX Runtime session or torch model).
    id2gloss : dict
        Mapping of class indices (as strings) to glosses.
    k : int, optional
        Number of predictions to return, by default 3.

    Returns
    -------
    Predictions
        Top-k predictions and the measured inference time.
    '''
    if inputs is None:
        return Predictions()
    # Get logits
    start_time = time()
    if isinstance(model, ort.InferenceSession):
        inputs = inputs.cpu().numpy()
        logits = torch.from_numpy(model.run(None, {"pixel_values": inputs})[0])
    else:
        logits = model(inputs.to(model.device)).logits
    inference_time = time() - start_time
    # Softmax over the FULL class distribution first, then take top-k:
    # softmaxing only the k selected logits would renormalize over k classes
    # and inflate the reported scores.
    probs = torch.nn.functional.softmax(logits, dim=1)
    topk_scores, topk_indices = torch.topk(probs, k, dim=1)
    # Move to CPU before .numpy(): the logits may live on an accelerator.
    topk_scores = topk_scores.squeeze().detach().cpu().numpy()
    topk_indices = topk_indices.squeeze().detach().cpu().numpy()
    predictions = [
        {
            'gloss': id2gloss[str(topk_indices[i])],
            'score': topk_scores[i],
        }
        for i in range(k)
    ]
    return Predictions(predictions=predictions, inference_time=inference_time)
def register_pipeline(model_config: ModelConfig) -> Pipeline:
    '''
    Register the "video-classification" task for the configured architecture
    and return a ready-to-use pipeline wrapping the loaded model.

    SPOTER and the graph models (sl_gcn, dsta_slr) use their dedicated
    graph-classification pipelines; every other architecture falls back to
    the generic video-classification pipeline.
    '''
    _, processor, model = load_model(model_config)
    arch = model_config.arch
    if arch == "spoter":
        PIPELINE_REGISTRY.register_pipeline(
            "video-classification",
            pipeline_class=SPOTERGraphClassificationPipeline,
            pt_model=AutoModel,
            default={"pt": ("vsltranslation/spoter_v3.0", "main")},
            type="multimodal",
        )
        return SPOTERGraphClassificationPipeline(model=model, feature_extractor=processor)
    if arch in ("sl_gcn", "dsta_slr"):
        PIPELINE_REGISTRY.register_pipeline(
            "video-classification",
            pipeline_class=SLGCNGraphClassificationPipeline,
            pt_model=AutoModel,
            default={"pt": ("vsltranslation/sl_gcn_joint_v1.0", "main")},
            type="multimodal",
        )
        return SLGCNGraphClassificationPipeline(model=model, feature_extractor=processor)
    # RGB-based architectures: generic video pipeline with an image processor.
    PIPELINE_REGISTRY.register_pipeline(
        "video-classification",
        pipeline_class=VideoClassificationPipeline,
        pt_model=AutoModelForVideoClassification,
        default={"pt": ("vsltranslation/swin3d_t_v1.0", "main")},
        type="multimodal",
    )
    return VideoClassificationPipeline(model=model, image_processor=processor)
def load_pipeline(
    model_config: ModelConfig,
    inference_config: InferenceConfig,
) -> Pipeline:
    '''
    Instantiate a "video-classification" pipeline for the configured model.

    Pose-based architectures get a feature extractor plus the pose-stream
    options (bone/motion streams); RGB-based architectures get an image
    processor. All other pipeline options come from `inference_config`.
    '''
    shared_kwargs = dict(
        model=model_config.pretrained,
        device=inference_config.device,
        model_kwargs={"cache_dir": inference_config.cache_dir},
        trust_remote_code=True,
        use_onnx=inference_config.use_onnx,
        top_k=inference_config.top_k,
    )
    if model_config.arch in POSE_BASED_MODELS:
        return pipeline(
            "video-classification",
            feature_extractor=model_config.pretrained,
            bone_stream=inference_config.bone_stream,
            motion_stream=inference_config.motion_stream,
            **shared_kwargs,
        )
    return pipeline(
        "video-classification",
        image_processor=model_config.pretrained,
        **shared_kwargs,
    )
def get_input_shape(
    arch: str,
    processor: Union[ImageProcessingMixin, FeatureExtractionMixin],
    batch_size: int = 1,
) -> tuple:
    '''
    Get the input shape for the model.

    Parameters
    ----------
    arch : str
        Model architecture name; must belong to RGB_BASED_MODELS or
        POSE_BASED_MODELS.
    processor : Union[ImageProcessingMixin, FeatureExtractionMixin]
        Model processor holding the size/frame/keypoint attributes.
    batch_size : int, optional
        Batch size, by default 1.

    Returns
    -------
    tuple
        Input shape:
        - RGB models: (batch, frames, 3, height, width)
        - spoter: (batch, frames, points, channels)
        - sl_gcn / dsta_slr: (batch, channels, window, points, people)

    Exits the process (status 1) for unsupported architectures, matching
    this module's error-handling convention.
    '''
    if arch in RGB_BASED_MODELS:
        return (
            batch_size,
            processor.num_frames,
            3,  # RGB channels
            processor.size["height"],
            processor.size["width"],
        )
    if arch in POSE_BASED_MODELS:
        if arch == "spoter":
            return (
                batch_size,
                processor.num_frames,
                processor.num_points,
                processor.in_channels,
            )
        if arch in ("sl_gcn", "dsta_slr"):
            return (
                batch_size,
                processor.in_channels,
                processor.window_size,
                processor.num_points,
                processor.num_people,
            )
    # Single exit path for every unsupported architecture (previously
    # duplicated in two branches).
    logging.error(f"Model {arch} is not supported")
    exit(1)