Spaces:
Sleeping
Sleeping
| import asyncio | |
| import hashlib | |
| import inspect | |
| import math | |
| import random | |
| import time | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from functools import partial, wraps | |
| from io import BytesIO | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Coroutine, | |
| List, | |
| Literal, | |
| Optional, | |
| Protocol, | |
| Tuple, | |
| TypeVar, | |
| ) | |
| import httpx | |
| from PIL.Image import Image as IMG | |
| from pil_utils import BuildImage, Text2Image | |
| from pil_utils.types import ColorType, FontStyle, FontWeight | |
| from typing_extensions import ParamSpec | |
| from .config import meme_config | |
| from .exception import MemeGeneratorException | |
| if TYPE_CHECKING: | |
| from .meme import Meme | |
# Type variables used by ``run_sync`` to carry the wrapped callable's
# signature through the decorator: P captures its parameters, R its return type.
P = ParamSpec("P")
R = TypeVar("R")
def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
    """Wrap a synchronous function so that it can be awaited.

    The wrapped call is handed to ``loop.run_in_executor(None, ...)`` on the
    currently running event loop, i.e. it runs in the loop's default executor.

    Args:
        call: The synchronous function to wrap.

    Returns:
        A coroutine function taking the same arguments as ``call`` and
        resolving to its return value. The wrapper keeps ``call``'s metadata
        (``__name__``, ``__doc__``, ``__wrapped__``, ...).
    """

    @wraps(call)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, so bind both
        # positional and keyword arguments up front with partial.
        bound_call = partial(call, *args, **kwargs)
        return await loop.run_in_executor(None, bound_call)

    return _wrapper
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Tell whether ``call`` is a callable coroutine function.

    Functions, methods and builtins are inspected directly. Classes always
    yield ``False``. For any other object the check is applied to its
    ``__call__`` attribute (``None`` when it has none).
    """
    if inspect.isroutine(call):
        target = call
    elif inspect.isclass(call):
        return False
    else:
        target = getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(target)
def save_gif(frames: List[IMG], duration: float) -> BytesIO:
    """Encode ``frames`` as an endlessly looping GIF that fits the size limit.

    The frames are saved once. If the result is larger than
    ``meme_config.gif.gif_max_size * 10**6`` bytes, the function retries
    recursively with a lighter input:

    * with more frames than ``meme_config.gif.gif_max_frames``, it keeps
      ``gif_max_frames`` evenly spaced frames and stretches the per-frame
      duration by the same ratio, so the total play time stays the same;
    * otherwise every frame is resized to 90% of its width and height.

    Args:
        frames: Frames to encode; the first one drives the save call.
        duration: Time each frame is shown, in seconds (converted to
            milliseconds for the encoder).

    Returns:
        An in-memory buffer holding the encoded GIF.
    """
    output = BytesIO()
    frames[0].save(
        output,
        format="GIF",
        save_all=True,
        append_images=frames[1:],
        duration=duration * 1000,
        loop=0,
        disposal=2,
        optimize=False,
    )

    # Within the size limit: nothing more to do.
    if output.getbuffer().nbytes <= meme_config.gif.gif_max_size * 10**6:
        return output

    # Too large and too many frames: drop frames first.
    n_frames = len(frames)
    gif_max_frames = meme_config.gif.gif_max_frames
    if n_frames > gif_max_frames:
        ratio = n_frames / gif_max_frames
        kept_frames = [frames[int(i * ratio)] for i in range(gif_max_frames)]
        return save_gif(kept_frames, duration * ratio)

    # Too large but the frame count is acceptable: shrink the frames instead.
    smaller_frames = [
        frame.resize((int(frame.width * 0.9), int(frame.height * 0.9)))
        for frame in frames
    ]
    return save_gif(smaller_frames, duration)
class Maker(Protocol):
    """Structural type of an image-processing callable."""

    def __call__(self, img: BuildImage) -> BuildImage:
        """Take an image and return the processed image."""
class GifMaker(Protocol):
    """Structural type of a factory producing one ``Maker`` per frame index."""

    def __call__(self, i: int) -> Maker:
        """Return the image-processing callable to use for frame ``i``."""
def get_avg_duration(image: IMG) -> float:
    """Return the mean per-frame duration of an animated image.

    Every frame is visited with ``image.seek`` and its ``info["duration"]``
    is summed; the image is left positioned on its last frame.

    Args:
        image: The image to inspect.

    Returns:
        ``0`` when the image is not animated, otherwise the average of the
        frames' ``duration`` values, in the unit stored in ``image.info``
        (callers in this module divide the result by 1000 to get seconds).
    """
    if not getattr(image, "is_animated", False):
        return 0

    n_frames = image.n_frames
    total_duration = 0
    for index in range(n_frames):
        image.seek(index)
        # A frame is not guaranteed to carry a "duration" entry; count such
        # frames as 0 instead of failing with KeyError.
        total_duration += image.info.get("duration", 0)
    return total_duration / n_frames
def split_gif(image: IMG) -> List[IMG]:
    """Return an independent copy of every frame of ``image``.

    After collecting the frames the image is rewound to frame 0, and when
    its ``info`` holds a ``"transparency"`` entry at that position the value
    is copied onto the first returned frame.

    Args:
        image: A multi-frame image exposing ``n_frames``, ``seek`` and ``copy``.

    Returns:
        One copied frame per index in ``range(image.n_frames)``.
    """
    # NOTE(review): an earlier revision scanned ``image.tile`` for
    # partial-update frames and meant to paste such frames onto the previous
    # one. That branch could never run (the previous frame was never
    # recorded) and would have stored ``None`` because ``Image.paste`` works
    # in place. It was removed; each frame is still returned exactly as
    # ``seek`` + ``copy`` produce it, which is what the function always did.
    frames: List[IMG] = []
    for index in range(image.n_frames):
        image.seek(index)
        frames.append(image.copy())

    image.seek(0)
    if "transparency" in image.info:
        frames[0].info["transparency"] = image.info["transparency"]
    return frames
def make_jpg_or_gif(
    img: BuildImage, func: Maker, keep_transparency: bool = False
) -> BytesIO:
    """Produce a JPEG from a still image, or a GIF from an animated one.

    Args:
        img: The input image.
        func: Image-processing function; receives an image and returns the
            processed image.
        keep_transparency: For animated input, whether to carry the input's
            ``"transparency"`` info entry over to the output frames.

    Returns:
        ``func(img).save_jpg()`` for a still image; otherwise the result of
        ``save_gif`` over the processed frames, using the input's average
        frame duration.
    """
    source = img.image
    if not getattr(source, "is_animated", False):
        return func(img).save_jpg()

    raw_frames = split_gif(source)
    # get_avg_duration yields milliseconds here; save_gif expects seconds.
    seconds_per_frame = get_avg_duration(source) / 1000
    processed = [func(BuildImage(raw)).image for raw in raw_frames]

    if keep_transparency:
        source.seek(0)
        if "transparency" in source.info:
            processed[0].info["transparency"] = source.info["transparency"]

    return save_gif(processed, seconds_per_frame)
def make_png_or_gif(
    img: BuildImage, func: Maker, keep_transparency: bool = False
) -> BytesIO:
    """Produce a PNG from a still image, or a GIF from an animated one.

    Args:
        img: The input image.
        func: Image-processing function; receives an image and returns the
            processed image.
        keep_transparency: For animated input, whether to carry the input's
            ``"transparency"`` info entry over to the output frames.

    Returns:
        ``func(img).save_png()`` for a still image; otherwise the result of
        ``save_gif`` over the processed frames, using the input's average
        frame duration.
    """
    source = img.image
    if not getattr(source, "is_animated", False):
        return func(img).save_png()

    raw_frames = split_gif(source)
    # get_avg_duration yields milliseconds here; save_gif expects seconds.
    seconds_per_frame = get_avg_duration(source) / 1000
    processed = [func(BuildImage(raw)).image for raw in raw_frames]

    if keep_transparency:
        source.seek(0)
        if "transparency" in source.info:
            processed[0].info["transparency"] = source.info["transparency"]

    return save_gif(processed, seconds_per_frame)
class FrameAlignPolicy(Enum):
    """How to lengthen the base GIF when the overlaid GIF lasts longer.

    Used by ``make_gif_or_combined_gif`` to decide whether, and how, the
    base GIF is extended so that both animations line up.
    """

    # Do not extend the base GIF.
    no_extend = 0
    # Extend by repeating the first frame.
    extend_first = 1
    # Extend by repeating the last frame.
    extend_last = 2
    # Extend by looping the whole base GIF.
    extend_loop = 3
def make_gif_or_combined_gif(
    img: BuildImage,
    maker: GifMaker,
    frame_num: int,
    duration: float,
    frame_align: FrameAlignPolicy = FrameAlignPolicy.no_extend,
    input_based: bool = False,
    keep_transparency: bool = False,
) -> BytesIO:
    """Build a GIF from a still image, or merge an animated input into one.

    For a still input, ``maker(i)`` is applied to ``img`` for each of the
    ``frame_num`` target frames. For an animated input, the input frames and
    the target frames are matched on a shared timeline: one animation is the
    "base" (its frame grid drives the output) and the other is "fitted" to
    it by picking, for every base frame, the fitted frame shown at that time.

    Args:
        img: Input image, e.g. an avatar.
        maker: Factory that, given a target frame index, returns the
            processing function for that frame.
        frame_num: Number of frames of the target GIF.
        duration: Interval between adjacent target frames, in seconds.
        frame_align: How to extend the base GIF when the fitted GIF lasts
            longer than it.
        input_based: When ``True`` the input GIF is the base; by default
            (``False``) the target GIF is the base.
        keep_transparency: For animated input, whether to carry the input's
            ``"transparency"`` info entry over to the output.

    Returns:
        An in-memory buffer holding the encoded GIF.
    """
    image = img.image
    if not getattr(image, "is_animated", False):
        return save_gif([maker(i)(img).image for i in range(frame_num)], duration)

    frame_num_in = image.n_frames
    duration_in = get_avg_duration(image) / 1000
    total_duration_in = frame_num_in * duration_in
    total_duration = frame_num * duration

    # NOTE(review): a frame interval of 0 on either side (a zero ``duration``
    # argument, or an input whose average duration is 0) is not handled: it
    # leads to a division by zero below or to a matching loop whose time
    # window is empty and therefore never matches. Confirm callers never
    # pass such input.
    if input_based:
        frame_num_base = frame_num_in
        frame_num_fit = frame_num
        duration_base = duration_in
        duration_fit = duration
        total_duration_base = total_duration_in
        total_duration_fit = total_duration
    else:
        frame_num_base = frame_num
        frame_num_fit = frame_num_in
        duration_base = duration
        duration_fit = duration_in
        total_duration_base = total_duration
        total_duration_fit = total_duration_in

    # Indices into the base animation, one per output frame.
    frame_idxs: List[int] = list(range(frame_num_base))
    diff_duration = total_duration_fit - total_duration_base
    diff_num = int(diff_duration / duration_base)

    # Only extend when the fitted GIF outlasts the base by at least one frame.
    if diff_duration >= duration_base:
        if frame_align == FrameAlignPolicy.extend_first:
            frame_idxs = [0] * diff_num + frame_idxs
        elif frame_align == FrameAlignPolicy.extend_last:
            frame_idxs += [frame_num_base - 1] * diff_num
        elif frame_align == FrameAlignPolicy.extend_loop:
            frame_num_total = frame_num_base
            # Repeat the base GIF until both total durations differ by at
            # most one base interval, or another repetition would exceed the
            # configured maximum frame count.
            while frame_num_total + frame_num_base <= meme_config.gif.gif_max_frames:
                frame_num_total += frame_num_base
                frame_idxs += list(range(frame_num_base))
                multiple = round(frame_num_total * duration_base / total_duration_fit)
                if (
                    math.fabs(
                        total_duration_fit * multiple - frame_num_total * duration_base
                    )
                    <= duration_base
                ):
                    break

    frames: List[IMG] = []
    frame_idx_fit = 0
    # Start time of the current pass through the fitted animation.
    time_start = 0
    for i, idx in enumerate(frame_idxs):
        # Advance through the fitted frames until reaching the one whose
        # time window contains this output frame's timestamp.
        while frame_idx_fit < frame_num_fit:
            if (
                frame_idx_fit * duration_fit
                <= i * duration_base - time_start
                < (frame_idx_fit + 1) * duration_fit
            ):
                if input_based:
                    idx_in = idx
                    idx_maker = frame_idx_fit
                else:
                    idx_in = frame_idx_fit
                    idx_maker = idx

                func = maker(idx_maker)
                image.seek(idx_in)
                frames.append(func(BuildImage(image.copy())).image)
                break
            else:
                frame_idx_fit += 1
                if frame_idx_fit >= frame_num_fit:
                    # The fitted animation ran out: restart it from frame 0.
                    frame_idx_fit = 0
                    time_start += total_duration_fit

    if keep_transparency:
        image.seek(0)
        if "transparency" in image.info:
            frames[0].info["transparency"] = image.info["transparency"]

    # Output frame i was sampled at time ``i * duration_base``, so the GIF
    # must be saved with the base interval. This equals ``duration`` unless
    # ``input_based`` is set, in which case it is the input's interval.
    return save_gif(frames, duration_base)
async def translate(text: str, lang_from: str = "auto", lang_to: str = "zh") -> str:
    """Translate ``text`` through the Baidu translation HTTP API.

    The request is signed with the MD5 hex digest of
    ``appid + text + salt + apikey``, where ``salt`` is the current time in
    milliseconds.

    Args:
        text: Text to translate.
        lang_from: Source language code sent as ``from``.
        lang_to: Target language code sent as ``to``.

    Returns:
        The ``dst`` field of the first entry of ``trans_result`` in the
        JSON response.

    Raises:
        MemeGeneratorException: If the app id or API key is not configured.
    """
    appid = meme_config.translate.baidu_trans_appid
    apikey = meme_config.translate.baidu_trans_apikey
    if not appid or not apikey:
        # The two literals are concatenated, so the first needs a trailing
        # space to keep the sentences apart.
        raise MemeGeneratorException(
            "The `baidu_trans_appid` or `baidu_trans_apikey` is not set. "
            "Please check your config file!"
        )

    salt = str(round(time.time() * 1000))
    sign_raw = appid + text + salt + apikey
    sign = hashlib.md5(sign_raw.encode("utf8")).hexdigest()
    params = {
        "q": text,
        "from": lang_from,
        "to": lang_to,
        "appid": appid,
        "salt": salt,
        "sign": sign,
    }
    url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
    async with httpx.AsyncClient() as client:
        resp = await client.get(url, params=params)
        result = resp.json()
    # NOTE(review): a response without "trans_result" (e.g. an API error
    # payload) surfaces here as a KeyError.
    return result["trans_result"][0]["dst"]
async def translate_microsoft(text: str, lang_from: str = "zh-CN", lang_to: str = "ja") -> str:
    """Translate ``text`` through the ``translate.ikechan8370.com`` endpoint.

    Args:
        text: Text to translate.
        lang_from: Accepted for interface compatibility; it is not sent
            with the request.
        lang_to: Target language code; ``"jp"`` is normalised to ``"ja"``.

    Returns:
        The nested ``translation.translation`` field of the JSON response.
    """
    target_lang = "ja" if lang_to == "jp" else lang_to
    query = {"text": text, "toLang": target_lang}
    endpoint = "http://translate.ikechan8370.com/translate"
    async with httpx.AsyncClient() as client:
        response = await client.get(endpoint, params=query)
        payload = response.json()
    return payload["translation"]["translation"]
def random_text() -> str:
    """Return one of ten fixed placeholder names, chosen at random."""
    names = (
        "刘一",
        "陈二",
        "张三",
        "李四",
        "王五",
        "赵六",
        "孙七",
        "周八",
        "吴九",
        "郑十",
    )
    return random.choice(names)
def random_image() -> BytesIO:
    """Return a PNG of a randomly chosen emoji drawn on a white 500x500 canvas."""
    emoji = random.choice(["😂", "😅", "🤗", "🤤", "🥵", "🥰", "😍", "😭", "😋", "😏"])
    canvas = BuildImage.new("RGBA", (500, 500), "white")
    # The text box spans the whole canvas; the font size is capped at 400.
    drawn = canvas.draw_text((0, 0, 500, 500), emoji, max_fontsize=400)
    return drawn.save_png()
@dataclass
class TextProperties:
    """Text styling applied to one entry rendered by ``render_meme_list``.

    Every field has a default, so ``TextProperties()`` describes plain black
    text and individual fields can be overridden by keyword.
    """

    # Text colour.
    fill: ColorType = "black"
    # Font style passed to the text renderer.
    style: FontStyle = "normal"
    # Font weight passed to the text renderer.
    weight: FontWeight = "normal"
    # Outline width; 0 means no outline.
    stroke_width: int = 0
    # Outline colour; None leaves the choice to the text renderer.
    stroke_fill: Optional[ColorType] = None
def default_template(meme: "Meme", number: int) -> str:
    """Format a list entry as ``"<number>. <keyword>/<keyword>/..."``."""
    joined_keywords = "/".join(meme.keywords)
    return f"{number}. {joined_keywords}"
def render_meme_list(
    meme_list: List[Tuple["Meme", TextProperties]],
    *,
    template: Callable[["Meme", int], str] = default_template,
    order_direction: Literal["row", "column"] = "column",
    columns: int = 4,
    column_align: Literal["left", "center", "right"] = "left",
    item_padding: Tuple[int, int] = (15, 6),
    image_padding: Tuple[int, int] = (50, 50),
    bg_color: ColorType = "white",
    fontsize: int = 30,
    fontname: str = "",
    fallback_fonts: Optional[List[str]] = None,
) -> BytesIO:
    """Render a numbered, multi-column list of memes as a JPEG image.

    Args:
        meme_list: Memes to list, each paired with the text styling to use.
        template: Builds the text of an entry from the meme and its 1-based
            position.
        order_direction: ``"column"`` fills one column after another with
            consecutive entries; ``"row"`` deals the entries out across the
            columns one row at a time.
        columns: Number of columns to lay the entries out in.
        column_align: Horizontal alignment of entries within a column.
        item_padding: ``(horizontal, vertical)`` padding around each entry.
        image_padding: ``(horizontal, vertical)`` padding around the list.
        bg_color: Background colour.
        fontsize: Font size of the entries.
        fontname: Font name passed to the text renderer.
        fallback_fonts: Fallback font names passed to the text renderer;
            ``None`` means no fallback fonts.

    Returns:
        An in-memory buffer holding the rendered JPEG. Columns that receive
        no entry are left out; an empty ``meme_list`` yields an image made
        of the outer padding only.
    """
    # Avoid sharing one mutable default list between calls.
    fonts: List[str] = [] if fallback_fonts is None else fallback_fonts

    item_images: List[Text2Image] = [
        Text2Image.from_text(
            template(meme, number),
            fontsize=fontsize,
            style=properties.style,
            weight=properties.weight,
            fill=properties.fill,
            stroke_width=properties.stroke_width,
            stroke_fill=properties.stroke_fill,
            fontname=fontname,
            fallback_fonts=fonts,
        )
        for number, (meme, properties) in enumerate(meme_list, start=1)
    ]

    # Metrics of a reference glyph give every row the same height,
    # whatever characters the entries contain.
    char_A = (
        Text2Image.from_text("A", fontsize=fontsize, fontname=fontname, fallback_fonts=fonts)
        .lines[0]
        .chars[0]
    )
    row_height = char_A.ascent + item_padding[1] * 2

    num_per_col = math.ceil(len(item_images) / columns)
    column_images: List[BuildImage] = []
    for col in range(columns):
        if order_direction == "column":
            images = item_images[col * num_per_col : (col + 1) * num_per_col]
        else:
            # Every ``columns``-th entry, starting at this column's offset.
            images = item_images[col::columns]

        if not images:
            # The entries do not always fill every column (e.g. 9 entries in
            # 4 columns); ``max`` below would fail on an empty column.
            continue

        col_w = max(t2m.width for t2m in images) + item_padding[0] * 2
        col_h = row_height * len(images) + char_A.descent
        column_image = BuildImage.new("RGB", (col_w, col_h), bg_color)
        y = item_padding[1]
        for t2m in images:
            if column_align == "left":
                x = 0
            elif column_align == "center":
                x = (col_w - t2m.width - item_padding[0] * 2) // 2
            else:
                x = col_w - t2m.width - item_padding[0] * 2
            t2m.draw_on_image(column_image.image, (x, y))
            y += row_height
        column_images.append(column_image)

    total_w = sum(col_img.width for col_img in column_images) + image_padding[0] * 2
    total_h = (
        max((col_img.height for col_img in column_images), default=0)
        + image_padding[1] * 2
    )
    canvas = BuildImage.new("RGB", (total_w, total_h), bg_color)
    x, y = image_padding
    for col_img in column_images:
        canvas.paste(col_img, (x, y))
        x += col_img.width
    return canvas.save_jpg()