Ldrago116 committed on
Commit
b8f92c8
·
verified ·
1 Parent(s): dd2fd28

Upload 4 files

Browse files
Files changed (3) hide show
  1. README.md +4 -4
  2. test.py +30 -0
  3. utils.py +101 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: WlSL
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.25.0
8
  app_file: app.py
 
1
  ---
2
+ title: WSL
3
+ emoji: 🏢
4
+ colorFrom: green
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.25.0
8
  app_file: app.py
test.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from utils import classfication
4
+ source_video = "videos/main.mp4"
5
def video_identity(video):
    """Stage an uploaded clip, run classification, and clean up afterwards.

    Parameters
    ----------
    video : str
        Path of the uploaded video (Gradio hands the temp-file path to the fn).

    Returns
    -------
    The predicted label returned by ``classfication()``.
    """
    print(type(video))
    try:
        # Make sure the staging directory ("videos") exists, then move the
        # upload to the fixed location the datamodule reads from.
        os.makedirs(os.path.dirname(source_video), exist_ok=True)
        os.replace(video, source_video)
    except Exception as e:
        # Best-effort staging: log and still attempt classification.
        print(f"Error: {e}")
    prediction = classfication()

    # BUG FIX: the original called os.listdir(source_video), but source_video
    # is a *file* path ("videos/main.mp4"), which raises NotADirectoryError.
    # List the containing directory instead and delete the staged files.
    video_dir = os.path.dirname(source_video)
    for name in os.listdir(video_dir):
        file_path = os.path.join(video_dir, name)
        if os.path.isfile(file_path):
            os.remove(file_path)

    return prediction
22
+
23
+
24
# Gradio UI: one video upload in, one text label out.
demo = gr.Interface(
    fn=video_identity,
    inputs=gr.Video(),
    outputs="text",
)

if __name__ == "__main__":
    # share=True exposes a temporary public link in addition to localhost.
    demo.launch(share=True)
utils.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flash
2
+ from flash.core.data.utils import download_data
3
+ from flash.video import VideoClassificationData, VideoClassifier
4
+ import torch
5
+ from flash.video.classification.input_transform import VideoClassificationInputTransform
6
+ from pytorchvideo.transforms import (
7
+ ApplyTransformToKey,
8
+ ShortSideScale,
9
+ UniformTemporalSubsample,
10
+ UniformCropVideo,
11
+ )
12
+ from dataclasses import dataclass
13
+ from typing import Callable
14
+
15
+ import torch
16
+ from torch import Tensor
17
+
18
+ from flash.core.data.io.input import DataKeys
19
+ from flash.core.data.io.input_transform import InputTransform
20
+ from flash.core.data.transforms import ApplyToKeys
21
+ from flash.core.utilities.imports import (
22
+ _KORNIA_AVAILABLE,
23
+ _PYTORCHVIDEO_AVAILABLE,
24
+ requires,
25
+ )
26
+ from torchvision.transforms import Compose, CenterCrop
27
+ from torchvision.transforms import RandomCrop
28
+ from torch import nn
29
+ import kornia.augmentation as K
30
+ from torchvision import transforms as T
31
+ torch.set_float32_matmul_precision('high')
32
+
33
def normalize(x: Tensor) -> Tensor:
    """Rescale raw pixel intensities (0-255) into the unit interval."""
    unit_scaled = x / 255.0
    return unit_scaled
35
+
36
class TransformDataModule(InputTransform):
    """Custom Flash ``InputTransform`` for video classification.

    Per-sample: keep ``temporal_sub_sample`` frames, scale pixels to [0, 1],
    then crop to ``image_size`` (center crop for eval/predict, random crop for
    training). Per-batch, on device: Kornia mean/std normalization.

    NOTE(review): Flash resolves these hooks by method name
    (``per_sample_transform`` etc.) — do not rename them.
    """

    # Spatial crop size (pixels) applied to every clip.
    image_size: int = 256
    # Frames kept per clip. This is the only change in our custom transform.
    temporal_sub_sample: int = 16
    # Per-channel normalization stats — presumably Kinetics defaults; TODO confirm.
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    # Tensor layout passed to Kornia's VideoSequential.
    data_format: str = "BCTHW"
    # Apply the batch transform independently per frame.
    same_on_frame: bool = False

    def per_sample_transform(self) -> Callable:
        """Eval/predict pipeline: subsample frames, scale to [0, 1], center-crop."""
        per_sample_transform = [CenterCrop(self.image_size)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                # Targets become tensors so they batch cleanly.
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self) -> Callable:
        """Training pipeline: same as eval but with a padded random crop."""
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_batch_transform_on_device(self) -> Callable:
        """Normalize whole batches on the accelerator via Kornia."""
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
85
+
86
+
87
+
88
# Load the trained classifier from a local checkpoint.
# NOTE(review): this runs at import time — importing utils.py does disk I/O
# and fails hard if the checkpoint is missing; consider lazy loading.
model = VideoClassifier.load_from_checkpoint("video_classfication/checkpoints/epoch=99-step=1000.ckpt")


# Prediction datamodule: classifies every clip found in the "videos" folder,
# one clip per batch, using the custom transform above.
datamodule_p = VideoClassificationData.from_folders(
    predict_folder="videos",
    batch_size=1,
    transform=TransformDataModule()
)
# Trainer is only ever used for .predict(); max_epochs is irrelevant there.
trainer = flash.Trainer(
    max_epochs=5,
)
99
def classfication():
    """Classify the staged video(s) and return the first predicted label.

    Uses the module-level ``trainer``, ``model`` and ``datamodule_p``.
    (Name keeps the original spelling — test.py imports it as-is.)
    """
    label_batches = trainer.predict(model, datamodule=datamodule_p, output="labels")
    first_batch = label_batches[0]
    return first_batch[0]