Spaces:
Runtime error
Runtime error
| import os, glob, json, base64, re | |
| from io import BytesIO | |
| from PIL import Image, PngImagePlugin | |
| from image_process import image_canny,image_pose_mask,image_pose_mask_numpy | |
| from generate_img import generate_image, generate_image_sketch | |
# Absolute path of the directory holding saved pose PNGs.
# An empty string means "not configured yet"; set_save_dir() fills it in.
_SAVED_POSES_DIR = ''

# Module-level cache dict. Only commented-out code in this file touches it
# (NOTE(review): confirm it is still needed).
image_cache = {}
def set_save_dir(dir: str):
    """Configure the directory that pose PNGs are read from and written to.

    The argument is converted with str() (so path-like objects work) and
    canonicalised with os.path.realpath before being stored in the
    module-level _SAVED_POSES_DIR.
    """
    global _SAVED_POSES_DIR
    resolved = os.path.realpath(str(dir))
    _SAVED_POSES_DIR = resolved
def get_save_dir():
    """Return the configured pose directory.

    Raises:
        RuntimeError: if set_save_dir() has not been called yet.
    """
    # An explicit check instead of `assert`: assertions are stripped when
    # Python runs with -O, which would let an empty directory slip through
    # and resolve pose paths relative to the current working directory.
    if not _SAVED_POSES_DIR:
        raise RuntimeError('pose save directory is not configured; call set_save_dir() first')
    return _SAVED_POSES_DIR
def get_saved_path(name: str):
    """Join *name* onto the configured pose directory.

    The result is not normalised or resolved; callers validate *name* first.
    """
    base = get_save_dir()
    return os.path.join(base, name)
def atoi(text):
    """Return int(text) if *text* is a run of decimal digits, else *text* unchanged.

    str.isdecimal is used rather than str.isdigit: isdigit is also true for
    characters such as superscripts ('²') that int() cannot parse, which made
    this function raise ValueError on such input.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Sort key that orders embedded numbers numerically ('a2' < 'a10').

    Returns the pieces of *text* split around runs of digits, with each digit
    run converted to int.
    """
    chunks = re.split(r'(\d+)', text)
    # With a single capturing group, re.split alternates
    # text / digit-run / text / ..., so odd positions are exactly the \d+
    # matches. \d matches Unicode decimal digits, all of which int() accepts;
    # converting by position avoids str.isdigit(), which is also true for
    # characters like '²' and made int() raise.
    return [int(chunk) if index % 2 else chunk for index, chunk in enumerate(chunks)]
def sorted_glob(path):
    """Return the glob matches for *path* as a list in natural (numeric-aware) order."""
    matches = glob.glob(path)
    matches.sort(key=natural_keys)
    return matches
def name2path(name: str):
    """Map a pose name to '<save dir>/<name>.png'.

    Raises:
        ValueError: if *name* is not a str, is empty, contains '.', '/' or
            '\\', or the joined path does not start with the save directory.
    """
    if not isinstance(name, str):
        raise ValueError(f'str object expected, but {type(name)}')
    if not name:
        raise ValueError('empty name')
    # Forbidding '.' rules out '..' traversal and extra extensions; forbidding
    # both separators rules out sub-directories on POSIX and Windows.
    if any(forbidden in name for forbidden in ('.', '/', '\\')):
        raise ValueError(f'invalid name: {name}')
    candidate = get_saved_path(f'{name}.png')
    # Defence in depth: os.path.join discards the base when the second part
    # is absolute/drive-qualified, so confirm we stayed under the save dir.
    if candidate.startswith(get_save_dir()):
        return candidate
    raise ValueError(f'invalid name: {name}')
def saved_poses():
    """Lazily yield an opened PIL image for every '*.png' in the save directory.

    Files are visited in natural order; the caller owns (and should close)
    each yielded image.
    """
    pattern = os.path.join(get_save_dir(), '*.png')
    for png_path in sorted_glob(pattern):
        yield Image.open(png_path)
def all_poses():
    """Yield a dict for every saved pose PNG in the save directory.

    Each dict has 'name', 'image' (base64 of the PNG re-encoded in memory,
    as an ASCII str) and 'screen', 'camera', 'joints' (parsed from the PNG's
    JSON text chunks). Files lacking any of those text chunks are skipped
    instead of aborting the whole listing.
    """
    required = ('name', 'screen', 'camera', 'joints')
    for img in saved_poses():
        with img:
            # Pillow's PNG images always expose `.text`, so the old
            # hasattr() test never skipped anything and a PNG without pose
            # chunks raised KeyError. Check for the keys we actually read.
            meta = getattr(img, 'text', None) or {}
            if not all(key in meta for key in required):
                continue
            buffer = BytesIO()
            img.save(buffer, format='png')
        # The image is closed at this point; only plain data is yielded.
        yield {
            'name': meta['name'],
            'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
            'screen': json.loads(meta['screen']),
            'camera': json.loads(meta['camera']),
            'joints': json.loads(meta['joints']),
        }
def save_pose(data: dict):
    """Store a pose as a quarter-size PNG thumbnail with its parameters embedded.

    Args:
        data: mapping with 'name' (pose name, validated by name2path),
            'screen', 'camera', 'joints' (JSON-serialisable values stored as
            JSON text chunks in the PNG) and 'image' (base64 screenshot,
            optionally prefixed by a "data:<mime>;base64," header).

    Raises:
        KeyError: if *data* lacks one of the keys above.
        ValueError: if name2path rejects the name.
    """
    name = data['name']
    screen = data['screen']
    camera = data['camera']
    joints = data['joints']

    info = PngImagePlugin.PngInfo()
    info.add_text('name', name)
    info.add_text('screen', json.dumps(screen))
    info.add_text('camera', json.dumps(camera))
    info.add_text('joints', json.dumps(joints))

    filepath = name2path(name)

    encoded = data['image']
    # Strip a data-URL header by locating its comma rather than assuming the
    # exact length of 'data:image/png;base64,'; plain base64 passes through.
    if encoded.startswith('data:'):
        encoded = encoded.partition(',')[2]

    with Image.open(BytesIO(base64.b64decode(encoded))) as screenshot:
        # Centre the screenshot on a square dark-grey canvas (letterboxing).
        unit = max(screenshot.width, screenshot.height)
        mx, my = (unit - screenshot.width) // 2, (unit - screenshot.height) // 2
        canvas = Image.new('RGB', (unit, unit), color=(68, 68, 68))
        canvas.paste(screenshot, (mx, my))

    # Quarter-size thumbnail; clamp to 1px so very small screenshots still
    # produce a non-empty image.
    thumb_size = (max(1, canvas.width // 4), max(1, canvas.height // 4))
    canvas.resize(thumb_size).save(filepath, pnginfo=info)
def delete_pose(name: str):
    """Remove the stored PNG for pose *name*.

    name2path raises ValueError for an invalid name; os.remove raises
    (e.g. FileNotFoundError) if the file is missing.
    """
    os.remove(name2path(name))
def load_pose(name: str):
    """Read a single saved pose back from its PNG file.

    Returns:
        dict with 'name', 'image' (base64 of the PNG re-encoded in memory,
        as an ASCII str) and 'screen', 'camera', 'joints' parsed from the
        PNG's JSON text chunks.

    Raises:
        ValueError: if *name* is invalid, or the file lacks the pose text
            chunks ('not pose data: ...').
        FileNotFoundError: if no file exists for that name.
    """
    filepath = name2path(name)
    with Image.open(filepath) as img:
        # Pillow's PNG images always expose `.text`, so hasattr() alone never
        # triggered the intended ValueError and missing chunks escaped as
        # KeyError. Check the keys that are read below.
        meta = getattr(img, 'text', None) or {}
        if not all(key in meta for key in ('name', 'screen', 'camera', 'joints')):
            raise ValueError(f'not pose data: {filepath}')
        buffer = BytesIO()
        img.save(buffer, format='png')
    return {
        'name': meta['name'],
        'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
        'screen': json.loads(meta['screen']),
        'camera': json.loads(meta['camera']),
        'joints': json.loads(meta['joints']),
    }
def base64_PIL(data: str):
    """Decode base64 text holding an encoded image file and open it with PIL.

    Expects bare base64; callers in this file strip any data-URL header first.
    """
    raw = base64.b64decode(data)
    return Image.open(BytesIO(raw))
def PIL_base64(data, image_format='PNG'):
    """Encode a PIL image as base64 text of an encoded image file.

    The output can be decoded by base64_PIL (base64-decode + Image.open).
    Encoding Image.tobytes() instead, as before, yields raw pixel data with
    no size/mode header, which no image decoder can open.

    Args:
        data: PIL image to encode.
        image_format: Pillow format name passed to Image.save (default 'PNG').

    Returns:
        The base64 text as a str.
    """
    buffer = BytesIO()
    data.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def resizeImg(image1, image2):
    """Return the result of resizing *image2* to *image1*'s (width, height)."""
    # image1 only supplies the target dimensions.
    target_w, target_h = image1.size
    return image2.resize((target_w, target_h))
| # def get_img(data): | |
| # #执行逻辑 | |
| # if (data[0]): | |
| # bgImgBase64 = data[0]['bgImg'][len('data:image/png;base64,'):] | |
| # maskImgBase64 = data[0]['maskImg'][len('data:image/png;base64,'):] | |
| # image_cache['bgImgBase64'] = bgImgBase64 | |
| # image_cache['maskImgBase64'] = maskImgBase64 | |
| # return 'success' | |
def generate_img(data, image_prompt, image_n_prompt):
    """Generate an image from a background plus an OpenPose render.

    Args:
        data: sequence whose first element is a mapping with 'bgImg' and
            'maskImg', each base64 PNG text, optionally prefixed by a
            "data:<mime>;base64," header. Presumably the pose editor's JSON
            payload (TODO confirm).
        image_prompt: positive prompt forwarded to generate_image.
        image_n_prompt: negative prompt forwarded to generate_image.

    Returns:
        A one-element list holding generate_image's result, or None when
        *data* is empty or its first element is falsy.
    """
    # `data[0]` alone raised IndexError on an empty payload.
    if not data or not data[0]:
        return None

    def _b64_body(value):
        # Drop a data-URL header if present by locating its comma, rather
        # than assuming the exact length of a PNG header.
        return value.partition(',')[2] if value.startswith('data:') else value

    payload = data[0]
    bg_img = _b64_body(payload['bgImg'])
    mask_img_openpose = _b64_body(payload['maskImg'])
    print((len(bg_img), len(mask_img_openpose)))
    print((image_prompt, image_n_prompt))
    # Per the variable names: image_pose_mask() derives the mask from the
    # OpenPose render, and the render itself is the ControlNet input
    # (NOTE(review): confirm against image_process/generate_img).
    maskImg_base64 = image_pose_mask(mask_img_openpose)
    controlnet_img_pil = base64_PIL(mask_img_openpose)
    bg_img_pil = base64_PIL(bg_img)
    mask_img_pil = base64_PIL(maskImg_base64)
    # Give the background the mask's dimensions before generation.
    bg_img_pil = resizeImg(mask_img_pil, bg_img_pil)
    img = generate_image(image_prompt, image_n_prompt, controlnet_img_pil, bg_img_pil, mask_img_pil)
    return [img]
def get_image_sketch(image, image_prompt, image_n_prompt):
    """Generate an image from a sketch drawn over a base picture.

    Args:
        image: mapping with 'image' (base picture) and 'mask' (sketch layer).
            Both are passed to Image.fromarray, so presumably numpy arrays,
            e.g. from a Gradio sketch input (TODO confirm).
        image_prompt: positive prompt forwarded to generate_image_sketch.
        image_n_prompt: negative prompt forwarded to generate_image_sketch.

    Returns:
        generate_image_sketch's result as-is (not wrapped in a list).
    """
    base_array = image['image']
    sketch_array = image['mask']
    # image_pose_mask_numpy's output is base64 data, decoded back to PIL here.
    mask_pil = base64_PIL(image_pose_mask_numpy(sketch_array))
    base_pil = Image.fromarray(base_array)
    sketch_pil = Image.fromarray(sketch_array)
    return generate_image_sketch(image_prompt, image_n_prompt, sketch_pil, base_pil, mask_pil)