# NOTE(review): the three lines that were here ("Spaces:" / "Runtime error" /
# "Runtime error") are page chrome from a scraped Hugging Face Spaces listing,
# not application source; kept as a comment so the file remains valid Python.
| import cv2 | |
| import torch | |
| import gradio as gr | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM | |
| from config import settings | |
| from modelscope.outputs import OutputKeys | |
| from modelscope.pipelines import pipeline | |
| from modelscope.utils.constant import Tasks | |
| from diffusers import ( | |
| StableDiffusionXLPipeline, | |
| DPMSolverMultistepScheduler, | |
| DDIMScheduler, | |
| HeunDiscreteScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| EulerDiscreteScheduler, | |
| PNDMScheduler | |
| ) | |
class KarrasDPM:
    """Scheduler factory: DPMSolverMultistep with the Karras sigma schedule.

    Mimics the diffusers scheduler `.from_config` class interface so it can sit
    in the SCHEDULERS table next to real scheduler classes.
    """

    @staticmethod
    def from_config(config):
        # BUG FIX (latent): without @staticmethod this only worked when called
        # via the class object (as SCHEDULERS does); calling it on an instance
        # would have bound `config` to self and crashed.
        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
# UI name -> diffusers scheduler class (or the KarrasDPM factory above).
# Keys appear verbatim in the text-to-image tab's scheduler dropdown, and the
# selected entry's .from_config(...) is applied to the SDXL pipeline.
SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KarrasDPM": KarrasDPM,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}
# Module-level LLM wrappers, one per supported provider; API keys come from
# the project `settings` object (see config module — presumably env-backed).
deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
open_router_llm = OpenRouterLLM(api_key=settings.open_router_api_key)
tongyi_llm = TongYiLLM(api_key=settings.tongyi_api_key)
def init_chat():
    """Build the default chat engine (DeepSeek with its default settings)."""
    engine = deep_seek_llm.get_chat_engine()
    return engine
def predict(message, history, chat):
    """Stream a chat completion for *message*, replaying *history* as context.

    Yields the accumulated response text after each streamed chunk. `chat` is
    the engine held in gr.State; a default DeepSeek engine is created lazily
    the first time no engine has been configured yet.
    """
    if chat is None:
        chat = init_chat()
    # Rebuild the LangChain message list from gradio's (user, assistant) pairs.
    messages = []
    for user_text, assistant_text in history:
        messages.extend((HumanMessage(content=user_text), AIMessage(content=assistant_text)))
    # NOTE(review): assumes the multimodal message object exposes `.text` —
    # confirm against the installed gradio version's MultimodalTextbox payload.
    messages.append(HumanMessage(content=message.text))
    partial = ''
    for chunk in chat.stream(messages):
        partial += chunk.content
        yield partial
def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
    """Return a chat engine reconfigured for the selected provider.

    Unknown provider names leave `_chat` unchanged, exactly as the original
    if-chain did.
    """
    print('?????', _provider, _chat, _model, _temperature, _max_tokens)
    llm_by_provider = {
        'DeepSeek': deep_seek_llm,
        'OpenRouter': open_router_llm,
        'Tongyi': tongyi_llm,
    }
    llm = llm_by_provider.get(_provider)
    if llm is not None:
        _chat = llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    return _chat
def object_remove(_image, _refined: bool):
    """Erase the user-masked region of the editor image with LaMa inpainting.

    Returns (inpainted RGB image, grayscale mask that was applied). `_refined`
    enables the model's refined mode (GPU only, per the UI hint).
    """
    # The first editor layer carries the user's brush strokes; collapse to 'L'.
    mask = _image['layers'][0].convert('L')
    inpainting = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama', refined=_refined)
    result = inpainting({
        'img': _image['background'].convert('RGB'),
        'mask': mask,
    })
    # ModelScope emits BGR; gradio previews expect RGB.
    restored = cv2.cvtColor(result[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB)
    return restored, mask
def bg_remove(_image, _type):
    """Remove the image background; '人像' selects the portrait matting model,
    anything else the universal one. Returns an RGBA image with alpha cutout."""
    rgb_input = _image['background'].convert('RGB')
    if _type == '人像':
        task, model_id = Tasks.portrait_matting, 'damo/cv_unet_image-matting'
    else:
        task, model_id = Tasks.universal_matting, 'damo/cv_unet_universal-matting'
    matting = pipeline(task, model=model_id)
    output = matting(rgb_input)[OutputKeys.OUTPUT_IMG]
    # BGRA -> RGBA preserves the alpha channel produced by the matting model.
    return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def text_to_image(_prompt: str, _n_prompt: str, _scheduler: str, _inference_steps: int, _w: int, _h: int, _guidance_scale: float):
    """Generate one SDXL image on the GPU.

    Args:
        _prompt: positive prompt text.
        _n_prompt: negative prompt text.
        _scheduler: key into the module-level SCHEDULERS table.
        _inference_steps: denoising step count.
        _w, _h: output resolution.
        _guidance_scale: classifier-free guidance strength.
    Returns the generated PIL image.
    """
    # BUG FIX: the debug line previously dropped _n_prompt even though the
    # function receives and uses it.
    print('????????', _prompt, _n_prompt, _scheduler, _inference_steps, _w, _h, _guidance_scale)
    # NOTE(review): the full SDXL pipeline is re-downloaded/loaded on every
    # call; consider a module-level cache if VRAM permits.
    t2i_pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    # Swap in the user-selected scheduler, keeping the pipeline's config.
    t2i_pipeline.scheduler = SCHEDULERS[_scheduler].from_config(t2i_pipeline.scheduler.config)
    t2i_pipeline.enable_xformers_memory_efficient_attention()
    with torch.inference_mode():
        result = t2i_pipeline(
            prompt=_prompt,
            negative_prompt=_n_prompt,
            num_inference_steps=_inference_steps,
            width=_w,
            height=_h,
            guidance_scale=_guidance_scale,
        ).images[0]
    return result
def image_upscale(_image, _size: str):
    """Upscale the editor's background image with the RRDB super-resolution model.

    NOTE(review): `_size` ("X2"/"X4") is accepted but never used — the model
    runs at its fixed native scale; confirm whether the radio should pick a
    different model or a post-resize.
    """
    sr = pipeline(Tasks.image_super_resolution, model='damo/cv_rrdb_image-super-resolution')
    upscaled = sr(_image['background'].convert('RGB'))[OutputKeys.OUTPUT_IMG]
    # ModelScope emits BGR; convert for gradio display.
    return cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGB)
with gr.Blocks() as app:
    # --- Tab 1: LLM chat -----------------------------------------------------
    with gr.Tab('聊天'):
        # Holds the currently configured LangChain chat engine; None until the
        # user sends a message (lazy DeepSeek default) or changes a setting.
        chat_engine = gr.State(value=None)
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                chatbot = gr.ChatInterface(
                    predict,
                    multimodal=True,
                    chatbot=gr.Chatbot(elem_id="chatbot", height=600, show_share_button=False),
                    textbox=gr.MultimodalTextbox(lines=1),
                    additional_inputs=[chat_engine]
                )
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion('参数设置', open=True):
                    with gr.Column():
                        provider = gr.Dropdown(
                            label='模型厂商',
                            choices=['DeepSeek', 'OpenRouter', 'Tongyi'],
                            value='DeepSeek',
                            info='不同模型厂商参数,效果和价格略有不同,请先设置好对应模型厂商的 API Key。',
                        )

                        # Per-provider config: (llm wrapper, max-tokens slider
                        # min/max/step). Tongyi exposes a much smaller token
                        # window than the other two providers.
                        provider_config = {
                            'DeepSeek': (deep_seek_llm, 1024, 1024 * 20, 128),
                            'OpenRouter': (open_router_llm, 1024, 1024 * 20, 128),
                            'Tongyi': (tongyi_llm, 1000, 2000, 100),
                        }

                        # BUG FIX: this panel builder was defined but never
                        # wired to the provider dropdown, so the model controls
                        # never appeared and update_chat never fired.
                        # @gr.render re-runs it whenever `provider` changes.
                        @gr.render(inputs=provider)
                        def show_model_config_panel(_provider):
                            """Render model/temperature/max-tokens controls for the selected provider."""
                            llm, tokens_min, tokens_max, tokens_step = provider_config[_provider]
                            with gr.Column():
                                model = gr.Dropdown(
                                    label='模型',
                                    choices=llm.support_models,
                                    value=llm.default_model
                                )
                                temperature = gr.Slider(
                                    minimum=0.0,
                                    maximum=1.0,
                                    step=0.1,
                                    value=llm.default_temperature,
                                    label="Temperature",
                                    key="temperature",
                                )
                                max_tokens = gr.Slider(
                                    minimum=tokens_min,
                                    maximum=tokens_max,
                                    step=tokens_step,
                                    value=llm.default_max_tokens,
                                    label="Max Tokens",
                                    key="max_tokens",
                                )
                                # Rebuild the chat engine whenever any setting
                                # changes (deduplicates the original's three
                                # identical .change wirings per provider).
                                for control in (model, temperature, max_tokens):
                                    control.change(
                                        fn=update_chat,
                                        inputs=[provider, chat_engine, model, temperature, max_tokens],
                                        outputs=[chat_engine],
                                    )
    # --- Tab 2: image editing (inpainting / matting / super-resolution) ------
    with gr.Tab('图像编辑'):
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                # ImageMask hands {'background': ..., 'layers': [...]} dicts to
                # the handlers above.
                edit_image = gr.ImageMask(
                    type='pil',
                    brush=gr.Brush(colors=["rgba(255, 255, 255, 0.9)"]),
                )
                with gr.Row():
                    mask_preview = gr.Image(label='蒙板预览')
                    image_preview = gr.Image(label='图片预览')
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion(label="物体移除"):
                    object_remove_refined = gr.Checkbox(label="Refined(GPU)", info="只支持 GPU, 开启将获得更好的效果")
                    object_remove_btn = gr.Button('物体移除', variant='primary')
                with gr.Accordion(label="背景移除"):
                    bg_remove_type = gr.Radio(["人像", "通用"], label="类型", value='人像')
                    bg_remove_btn = gr.Button('背景移除', variant='primary')
                with gr.Accordion(label="高清放大"):
                    upscale_size = gr.Radio(["X2", "X4"], label="放大倍数", value='X2')
                    upscale_btn = gr.Button('高清放大', variant='primary')
        object_remove_btn.click(fn=object_remove, inputs=[edit_image, object_remove_refined], outputs=[image_preview, mask_preview])
        bg_remove_btn.click(fn=bg_remove, inputs=[edit_image, bg_remove_type], outputs=[image_preview])
        upscale_btn.click(fn=image_upscale, inputs=[edit_image, upscale_size], outputs=[image_preview])
    # --- Tab 3: SDXL text-to-image (GPU only) --------------------------------
    with gr.Tab('画图(GPU)'):
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                # Renamed from `image` to stop shadowing the editor component
                # created in the previous tab.
                t2i_image = gr.Image()
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion(label="提示词", open=True):
                    prompt = gr.Textbox(label="提示语", value="", lines=3)
                    negative_prompt = gr.Textbox(label="负提示语", value="ugly", lines=2)
                with gr.Accordion(label="参数设置", open=False):
                    scheduler = gr.Dropdown(label='scheduler', choices=list(SCHEDULERS.keys()), value='KarrasDPM')
                    inference_steps = gr.Number(label='inference steps', value=22, minimum=1, maximum=100)
                    width = gr.Dropdown(label='width', choices=[512, 768, 832, 896, 1024, 1152], value=1024)
                    height = gr.Dropdown(label='height', choices=[512, 768, 832, 896, 1024, 1152], value=1024)
                    guidance_scale = gr.Number(label='guidance scale', value=7.0, minimum=1.0, maximum=10.0)
                with gr.Row(variant='panel'):
                    t2i_btn = gr.Button('🪄生成', variant='primary')
        t2i_btn.click(fn=text_to_image, inputs=[prompt, negative_prompt, scheduler, inference_steps, width, height, guidance_scale], outputs=[t2i_image])

app.launch(debug=settings.debug, show_api=False)