weiyuyeh commited on
Commit
87b6625
·
1 Parent(s): 071b060

use multi config file

Browse files
app.py CHANGED
@@ -7,6 +7,8 @@ import sys
7
  target_paths = {
8
  "data": "/home/user/app/upload/data.zip",
9
  "data_dir": "/home/user/app/upload/data",
 
 
10
  "config": "/home/user/app/src/multiview_consist_edit/config/infer_tryon_multi.yaml",
11
  "output_data": "/home/user/app/image_output_tryon_mvhumannet",
12
  "output_zip": "/home/user/app/outputs/result.zip",
@@ -18,10 +20,17 @@ def unzip_data():
18
  shutil.rmtree(target_paths["data_dir"])
19
  os.makedirs(target_paths["data_dir"], exist_ok=True)
20
  shutil.unpack_archive(target_paths["data"], target_paths["data_dir"])
21
- return target_paths["data_dir"]
22
  else:
23
  raise FileNotFoundError("Data file not found at " + target_paths["data"])
24
-
 
 
 
 
 
 
 
25
 
26
  def zip_outputs():
27
  if os.path.exists(target_paths["output_zip"]):
@@ -30,20 +39,32 @@ def zip_outputs():
30
  return target_paths["output_zip"]
31
 
32
 
33
- def start_inference_stream():
34
- process = subprocess.Popen(
35
- ["python", "src/multiview_consist_edit/infer_tryon_multi.py"],
36
- stdout=subprocess.PIPE,
37
- stderr=subprocess.STDOUT,
38
- text=True,
39
- bufsize=1,
40
- universal_newlines=True
41
- )
42
-
43
- output = []
44
- for line in process.stdout:
45
- output.append(line)
46
- yield "".join(output)
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def install_package(package_name):
49
  try:
@@ -93,10 +114,10 @@ print("package version set complete")
93
 
94
  def save_files(data_file, config_file):
95
  os.makedirs(os.path.dirname(target_paths["data"]), exist_ok=True)
96
- os.makedirs(os.path.dirname(target_paths["config"]), exist_ok=True)
97
 
98
  shutil.copy(data_file.name, target_paths["data"])
99
- shutil.copy(config_file.name, target_paths["config"])
100
  unzip_data()
101
  return "檔案已成功上傳、儲存並解壓縮了!"
102
 
@@ -105,7 +126,9 @@ with gr.Blocks(theme=gr.themes.Origin()) as demo:
105
  gr.Markdown("## 請先上傳檔案")
106
  with gr.Row():
107
  data_input = gr.File(label="上傳資料壓縮檔", file_types=[".zip"])
108
- config_input = gr.File(label="Config 檔", file_types=[".yaml", ".yml"])
 
 
109
 
110
  upload_button = gr.Button("上傳並儲存")
111
  output = gr.Textbox(label="狀態")
@@ -116,36 +139,36 @@ with gr.Blocks(theme=gr.themes.Origin()) as demo:
116
  log_output = gr.Textbox(label="Inference Log", lines=20)
117
  infer_btn = gr.Button("Start Inference")
118
 
119
- gr.Markdown("## Pip Installer")
120
- with gr.Column():
121
- with gr.Row():
122
- pkg_input = gr.Textbox(lines=1, placeholder="輸入想安裝的套件名稱,例如 diffusers 或 numpy==1.2.0")
123
- install_output = gr.Textbox(label="Install Output", lines=10)
124
- install_btn = gr.Button("Install Package")
125
 
126
- gr.Markdown("## Pip Uninstaller")
127
- with gr.Column():
128
- with gr.Row():
129
- pkg_input2 = gr.Textbox(lines=1, placeholder="輸入想解除安裝的套件名稱,例如 diffusers 或 numpy")
130
- uninstall_output = gr.Textbox(label="Uninstall Output", lines=10)
131
- uninstall_btn = gr.Button("Uninstall Package")
132
 
133
- gr.Markdown("## Pip show")
134
- with gr.Column():
135
- with gr.Row():
136
- show_input = gr.Textbox(label="輸入套件名稱(如 diffusers)")
137
- show_output = gr.Textbox(label="套件資訊", lines=10)
138
- show_btn = gr.Button("pip show")
139
 
140
  gr.Markdown("## Download results")
141
  with gr.Column():
142
  file_output = gr.File(label="點擊下載", interactive=True)
143
  download_btn = gr.Button("下載結果")
144
 
145
- show_btn.click(fn=show_package, inputs=show_input, outputs=show_output)
146
  download_btn.click(fn=zip_outputs, outputs=file_output)
147
- install_btn.click(fn=install_package, inputs=pkg_input, outputs=install_output)
148
- infer_btn.click(fn=start_inference_stream, outputs=log_output)
149
- uninstall_btn.click(fn=uninstall_package, inputs=pkg_input2, outputs=uninstall_output)
150
  upload_button.click(fn=save_files,inputs=[data_input, config_input],outputs=output)
151
  demo.launch()
 
7
  target_paths = {
8
  "data": "/home/user/app/upload/data.zip",
9
  "data_dir": "/home/user/app/upload/data",
10
+ "config_zip": "/home/user/app/upload/config.zip",
11
+ "configs": "/home/user/app/upload/config",
12
  "config": "/home/user/app/src/multiview_consist_edit/config/infer_tryon_multi.yaml",
13
  "output_data": "/home/user/app/image_output_tryon_mvhumannet",
14
  "output_zip": "/home/user/app/outputs/result.zip",
 
20
  shutil.rmtree(target_paths["data_dir"])
21
  os.makedirs(target_paths["data_dir"], exist_ok=True)
22
  shutil.unpack_archive(target_paths["data"], target_paths["data_dir"])
23
+ # return target_paths["data_dir"]
24
  else:
25
  raise FileNotFoundError("Data file not found at " + target_paths["data"])
26
+ if os.path.exists(target_paths["config_zip"]):
27
+ if os.path.exists(target_paths["configs"]):
28
+ shutil.rmtree(target_paths["configs"])
29
+ os.makedirs(target_paths["configs"], exist_ok=True)
30
+ shutil.unpack_archive(target_paths["config_zip"], target_paths["configs"])
31
+ # return target_paths["configs"]
32
+ else:
33
+ raise FileNotFoundError("Config file not found at " + target_paths["config_zip"])
34
 
35
  def zip_outputs():
36
  if os.path.exists(target_paths["output_zip"]):
 
39
  return target_paths["output_zip"]
40
 
41
 
42
+ def start_inference_stream(config_count):
43
+ config_dir = target_paths["configs"]
44
+ config_files = sorted([
45
+ f for f in os.listdir(config_dir)
46
+ if f.endswith(".yaml") or f.endswith(".yml")
47
+ ])
48
+ if config_count < len(config_files):
49
+ config_files = config_files[:config_count]
50
+
51
+ for cfg in config_files:
52
+ src_path = os.path.join(config_dir, cfg)
53
+ shutil.copy(src_path, target_paths["config"])
54
+
55
+ process = subprocess.Popen(
56
+ ["python", "src/multiview_consist_edit/infer_tryon_multi.py"],
57
+ stdout=subprocess.PIPE,
58
+ stderr=subprocess.STDOUT,
59
+ text=True,
60
+ bufsize=1,
61
+ universal_newlines=True
62
+ )
63
+
64
+ output = []
65
+ for line in process.stdout:
66
+ output.append(line)
67
+ yield "".join(output)
68
 
69
  def install_package(package_name):
70
  try:
 
114
 
115
  def save_files(data_file, config_file):
116
  os.makedirs(os.path.dirname(target_paths["data"]), exist_ok=True)
117
+ os.makedirs(os.path.dirname(target_paths["config_zip"]), exist_ok=True)
118
 
119
  shutil.copy(data_file.name, target_paths["data"])
120
+ shutil.copy(config_file.name, target_paths["config_zip"])
121
  unzip_data()
122
  return "檔案已成功上傳、儲存並解壓縮了!"
123
 
 
126
  gr.Markdown("## 請先上傳檔案")
127
  with gr.Row():
128
  data_input = gr.File(label="上傳資料壓縮檔", file_types=[".zip"])
129
+ with gr.Column():
130
+ config_input = gr.File(label="Config 壓縮檔", file_types=[".zip"])
131
+ config_count = gr.Number(label="Config 總數", precision=0, value=1)
132
 
133
  upload_button = gr.Button("上傳並儲存")
134
  output = gr.Textbox(label="狀態")
 
139
  log_output = gr.Textbox(label="Inference Log", lines=20)
140
  infer_btn = gr.Button("Start Inference")
141
 
142
+ # gr.Markdown("## Pip Installer")
143
+ # with gr.Column():
144
+ # with gr.Row():
145
+ # pkg_input = gr.Textbox(lines=1, placeholder="輸入想安裝的套件名稱,例如 diffusers 或 numpy==1.2.0")
146
+ # install_output = gr.Textbox(label="Install Output", lines=10)
147
+ # install_btn = gr.Button("Install Package")
148
 
149
+ # gr.Markdown("## Pip Uninstaller")
150
+ # with gr.Column():
151
+ # with gr.Row():
152
+ # pkg_input2 = gr.Textbox(lines=1, placeholder="輸入想解除安裝的套件名稱,例如 diffusers 或 numpy")
153
+ # uninstall_output = gr.Textbox(label="Uninstall Output", lines=10)
154
+ # uninstall_btn = gr.Button("Uninstall Package")
155
 
156
+ # gr.Markdown("## Pip show")
157
+ # with gr.Column():
158
+ # with gr.Row():
159
+ # show_input = gr.Textbox(label="輸入套件名稱(如 diffusers)")
160
+ # show_output = gr.Textbox(label="套件資訊", lines=10)
161
+ # show_btn = gr.Button("pip show")
162
 
163
  gr.Markdown("## Download results")
164
  with gr.Column():
165
  file_output = gr.File(label="點擊下載", interactive=True)
166
  download_btn = gr.Button("下載結果")
167
 
168
+ # show_btn.click(fn=show_package, inputs=show_input, outputs=show_output)
169
  download_btn.click(fn=zip_outputs, outputs=file_output)
170
+ # install_btn.click(fn=install_package, inputs=pkg_input, outputs=install_output)
171
+ infer_btn.click(fn=start_inference_stream,input=config_count, outputs=log_output)
172
+ # uninstall_btn.click(fn=uninstall_package, inputs=pkg_input2, outputs=uninstall_output)
173
  upload_button.click(fn=save_files,inputs=[data_input, config_input],outputs=output)
174
  demo.launch()
src/multiview_consist_edit/config/infer_tryon_multi_T.yaml CHANGED
@@ -40,7 +40,10 @@ infer_data:
40
  clip_model_path: 'openai/clip-vit-base-patch32'
41
  is_train: false
42
  mode: 'pair'
43
- output_front: false
 
 
 
44
  front_id: 0 # front view id, used for thuman2 dataset
45
  is_use_all_views: false # if your all views count <= 16, you can set it to true, otherwise false
46
 
 
40
  clip_model_path: 'openai/clip-vit-base-patch32'
41
  is_train: false
42
  mode: 'pair'
43
+ # the view ids you want to use
44
+ # it will cover other settings like output_front, front_id, is_use_all_views
45
+ view_ids: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
46
+ output_front: true
47
  front_id: 0 # front view id, used for thuman2 dataset
48
  is_use_all_views: false # if your all views count <= 16, you can set it to true, otherwise false
49
 
src/multiview_consist_edit/data/Thuman2_multi.py CHANGED
@@ -48,7 +48,7 @@ def crop_image(human_img_orig):
48
 
49
  class Thuman2_Dataset(Dataset):
50
  def __init__(
51
- self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8, output_front=True,front_id=0, is_use_all_views=False
52
  ):
53
  c_names_front = []
54
  c_names_back = []
@@ -81,8 +81,10 @@ class Thuman2_Dataset(Dataset):
81
  self.is_train = is_train
82
  self.sample_size = sample_size
83
  self.multi_length = multi_length
 
84
  self.front_id = front_id
85
  self.is_use_all_views = is_use_all_views
 
86
  self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path,local_files_only=False)
87
 
88
  self.pixel_transforms = transforms.Compose([
@@ -117,7 +119,6 @@ class Thuman2_Dataset(Dataset):
117
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
118
  ])
119
  self.color_transform = transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.0)
120
- self.output_front = output_front
121
 
122
  def __len__(self):
123
  if len(self.cloth_ids) >= 1:
@@ -149,27 +150,30 @@ class Thuman2_Dataset(Dataset):
149
  if self.is_train:
150
  select_images = random.sample(images, self.multi_length)
151
  else:
152
- # select_idxs = [0,3,6,9,12, 15,18,21,24,27, 79,76,73,70,67,64]
153
- L = len(images)
154
- select_idxs = []
155
- if self.is_use_all_views:
156
- sl = L/2
157
- else:
158
- sl = 16
159
- if self.output_front:
160
- begin = 0
161
- while begin < L//4:
162
- select_idxs.append((int(begin) + self.front_id) % L)
163
- begin += L/2/sl
164
- begin = L*3//4
165
- while begin < L:
166
- select_idxs.append((int(begin) + self.front_id) % L)
167
- begin += L/2/sl
168
  else:
169
- begin = L//4
170
- while begin < L*3//4:
171
- select_idxs.append((int(begin) + self.front_id) % L)
172
- begin += L/2/sl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # print(sorted(select_idxs))
174
  # select_idxs = [0,3,6,9,12, 15,18,21,24,27, L-1,L-4,L-7,L-10,L-13,L-16]
175
  select_images = []
 
48
 
49
  class Thuman2_Dataset(Dataset):
50
  def __init__(
51
+ self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8, view_ids=None, output_front=True, front_id=0, is_use_all_views=False
52
  ):
53
  c_names_front = []
54
  c_names_back = []
 
81
  self.is_train = is_train
82
  self.sample_size = sample_size
83
  self.multi_length = multi_length
84
+ self.view_ids = view_ids
85
  self.front_id = front_id
86
  self.is_use_all_views = is_use_all_views
87
+ self.output_front = output_front
88
  self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path,local_files_only=False)
89
 
90
  self.pixel_transforms = transforms.Compose([
 
119
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
120
  ])
121
  self.color_transform = transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.0)
 
122
 
123
  def __len__(self):
124
  if len(self.cloth_ids) >= 1:
 
150
  if self.is_train:
151
  select_images = random.sample(images, self.multi_length)
152
  else:
153
+ if self.view_ids is not None and len(self.view_ids) > 0:
154
+ select_idxs = self.view_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  else:
156
+ # select_idxs = [0,3,6,9,12, 15,18,21,24,27, 79,76,73,70,67,64]
157
+ L = len(images)
158
+ select_idxs = []
159
+ if self.is_use_all_views:
160
+ sl = L/2
161
+ else:
162
+ sl = 16
163
+ if self.output_front:
164
+ begin = 0
165
+ while begin < L//4:
166
+ select_idxs.append((int(begin) + self.front_id) % L)
167
+ begin += L/2/sl
168
+ begin = L*3//4
169
+ while begin < L:
170
+ select_idxs.append((int(begin) + self.front_id) % L)
171
+ begin += L/2/sl
172
+ else:
173
+ begin = L//4
174
+ while begin < L*3//4:
175
+ select_idxs.append((int(begin) + self.front_id) % L)
176
+ begin += L/2/sl
177
  # print(sorted(select_idxs))
178
  # select_idxs = [0,3,6,9,12, 15,18,21,24,27, L-1,L-4,L-7,L-10,L-13,L-16]
179
  select_images = []