Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import tempfile | |
| import sys | |
| import time | |
| from inferencer import Inferencer | |
| from accelerate.utils import set_seed | |
| from huggingface_hub import snapshot_download | |
# Fetch the UniPic 1.5B repository snapshot from the Hugging Face Hub and
# build the single inferencer shared by every request handler below.
# snapshot_download returns the local folder holding the repository files.
model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir, "pytorch_model.bin")
ckpt_name = "UniPic"  # NOTE(review): not referenced elsewhere in this file
inferencer = Inferencer(
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    model_path=model_path,
    image_size=1024,
    # cfg_prompt="Generate an image.",
)

# Scratch directory for uploaded / generated images.  A TemporaryDirectory
# object (instead of a bare mkdtemp path) is kept referenced at module level
# so it lives for the whole process; its finalizer deletes the directory and
# its contents at interpreter exit, so repeated runs no longer leak folders
# full of images in the system temp dir.  TEMP_DIR remains a plain str path.
_TEMP_DIR_HANDLE = tempfile.TemporaryDirectory()
TEMP_DIR = _TEMP_DIR_HANDLE.name
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img, directory=None):
    """Write *pil_img* as a PNG file with a unique name and return its path.

    Args:
        pil_img: Image to store; only its ``save(path, format="PNG")`` method
            is used.
        directory: Destination folder. Defaults to the module-level
            ``TEMP_DIR`` (backward-compatible with the one-argument call).

    Returns:
        Path of the newly written ``temp_*.png`` file.

    Raises:
        Whatever ``pil_img.save`` raises; the empty placeholder file is
        removed first so no zero-byte PNGs accumulate.
    """
    # Original author's note: editing only supports 512 -> 1024; a resize to
    # 512x512 was considered and left disabled:
    #     img_resized = pil_img.resize((512, 512))
    if directory is None:
        directory = TEMP_DIR
    # mkstemp creates the file atomically under a unique name.  The previous
    # "temp_{int(time.time())}.png" scheme had one-second resolution, so all
    # images of a multi-image generation overwrote one another.
    fd, path = tempfile.mkstemp(prefix="temp_", suffix=".png", dir=directory)
    os.close(fd)  # the image writer reopens the file by path
    try:
        pil_img.save(path, format="PNG")
    except Exception:
        os.remove(path)
        raise
    return path
def handle_image_upload(file, history):
    """Copy an uploaded image into the scratch directory and echo it in chat.

    Args:
        file: Value produced by the upload button -- a filepath string (the
            button is configured with ``type="filepath"``) or an object
            exposing the path as ``.name``; ``None`` when nothing was uploaded.
        history: Current chatbot history (list of message pairs); ``None`` is
            treated as an empty history.

    Returns:
        ``(saved_path, new_history)`` for the ``[img_state, chatbot]``
        outputs.  The new history gets a user row ``((saved_path,), None)``,
        the tuple form Gradio's tuples-format Chatbot uses for file messages.
        If *file* is ``None``, returns ``(None, history)`` unchanged.
    """
    if file is None:
        return None, history
    file_path = getattr(file, "name", file)
    # Use Image.open as a context manager so the uploaded file's handle is
    # released once the PNG copy has been written (it previously leaked).
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)
    return saved_path, list(history or []) + [((saved_path,), None)]
def clear_all():
    """Wipe the scratch images and reset the chat UI.

    Every regular file directly inside ``TEMP_DIR`` is deleted
    (sub-directories are skipped).  A failure on one file is printed and the
    loop moves on to the next one.

    Returns:
        ``([], None, "Understand Image")`` -- an empty chat history, no
        working image, and the mode value to put back in the mode radio.
    """
    for name in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, name)
        try:
            if not os.path.isfile(path):
                continue
            os.remove(path)
        except Exception as e:
            # Best-effort cleanup: report and keep going.
            print(f"Failed to delete temp file: {path}, error: {e}")
    empty_history, no_image, reset_mode = [], None, "Understand Image"
    return empty_history, no_image, reset_mode
def extract_assistant_reply(full_text):
    """Extract the assistant's answer from raw decoded model output.

    ``<|im_end|>`` tokens are removed first.  Then the text after the last
    ``<|im_start|>assistant`` marker is returned if that marker is present;
    otherwise the text after the last bare ``assistant`` (with a leading
    ``:`` dropped); otherwise the whole text.  The result is whitespace
    stripped.

    NOTE(review): with no chat-template marker, a reply that itself contains
    the word "assistant" is still cut at that word -- same as before.

    Args:
        full_text: Decoded text returned by the inferencer.

    Returns:
        The reply text.
    """
    # Strip end-of-turn tokens up front: previously they were only removed
    # when no "assistant" marker was found, so they leaked into replies.
    text = full_text.replace("<|im_end|>", "")
    # Try the full chat-template marker first so a bare "assistant" inside
    # the answer does not truncate it.
    for marker in ("<|im_start|>assistant", "assistant"):
        if marker in text:
            return text.rpartition(marker)[2].lstrip(":").strip()
    return text.strip()
def _understand_image(user_msg, img_path):
    """Ask the model *user_msg* about the image at *img_path*; return the reply."""
    # Close the source image file once the model has consumed it.
    with Image.open(img_path) as source:
        raw = inferencer.query_image(source, user_msg)
    return extract_assistant_reply(raw)


def _generate_images(user_msg, grid_size):
    """Generate ``grid_size**2`` images from *user_msg*; return their saved paths."""
    imgs = inferencer.gen_image(
        raw_prompt=user_msg,
        images_to_generate=grid_size**2,
        cfg=3.0,
        num_iter=48,
        cfg_schedule="constant",
        temperature=1.0,
    )
    return [save_temp_image(img) for img in imgs]


def _edit_image(user_msg, img_path, grid_size):
    """Edit the image at *img_path* per *user_msg*; return the saved result paths."""
    # Results are saved inside the ``with`` so they are written before the
    # source image is closed.
    with Image.open(img_path) as source:
        imgs = inferencer.edit_image(
            source_image=source,
            prompt=user_msg,
            cfg=3.0,
            cfg_prompt="repeat this image.",
            cfg_schedule="constant",
            temperature=0.85,
            grid_size=grid_size,
            num_iter=48,
        )
        return [save_temp_image(img) for img in imgs]


def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Run one chat turn in the selected mode and produce the new UI state.

    Args:
        history: Chatbot history as a sequence of ``(user, bot)`` pairs;
            ``None`` is treated as empty.
        user_msg: Text typed by the user (question, prompt or edit
            instruction); ``None`` is treated as "".
        img_path: Path of the current working image, or ``None``.
        mode: "Understand Image", "Generate Image" or "Edit Image".
        grid_size: Forwarded to the model -- generation requests
            ``grid_size**2`` images, editing passes it as ``grid_size``.

    Returns:
        ``(history, "", img_path)``: the updated history (the user row
        followed by one bot row), "" to clear the textbox, and the working
        image path -- the last generated/edited image on success, otherwise
        the incoming *img_path*.
    """
    # Rebuild the history as lists (Gradio may hand back tuples, or None) so
    # new rows can be appended.
    updated_history = [list(item) for item in (history or [])]
    user_msg = (user_msg or "").strip()
    updated_history.append([user_msg, None])
    # set_seed(42)  # uncomment to fix the random seed

    def respond(message, new_img_path=img_path):
        """Append a bot row with *message* and build the three outputs."""
        updated_history.append([None, message])
        return updated_history, "", new_img_path

    try:
        if mode == "Understand Image":
            # The old text-only fallback (query_text) sat behind this same
            # guard and was unreachable, so it has been dropped.
            if not img_path:
                return respond("⚠️ Please upload or generate an image first.")
            return respond(_understand_image(user_msg, img_path))
        if mode == "Generate Image":
            if not user_msg:
                return respond("⚠️ Please enter a prompt.")
            paths = _generate_images(user_msg, grid_size)
        elif mode == "Edit Image":
            if not img_path:
                return respond("⚠️ Please upload or generate an image first.")
            if not user_msg:
                return respond("⚠️ Please enter an edit instruction.")
            paths = _edit_image(user_msg, img_path, grid_size)
        else:
            # Previously fell through and returned None, which breaks an
            # event that expects three outputs.
            return respond(f"⚠️ Unknown mode: {mode}")
    except Exception as e:
        import traceback

        # UI boundary: show the error in chat, but keep the traceback in the
        # server log instead of discarding it.
        traceback.print_exc()
        return respond(f"⚠️ Failed to process: {e}")

    if not paths:
        return respond("⚠️ The model returned no images.")
    # Multi-image replies are passed as a list of paths; the last image
    # becomes the working image for the next request.
    return respond(paths, paths[-1])
# Stylesheet passed to gr.Blocks(css=...). It makes .gradio-container a
# full-height (100vh) flex column, caps the chat panes (#chatbot1/#chatbot2)
# at 66vh with their own vertical scrollbar, limits images in chat to 80vw,
# and lays out the bottom .input-row (textbox / upload / clear columns).
# The CSS comments inside the string are runtime text and are kept as-is.
# English translations, in order: overall layout (top and bottom sections);
# new: make sure the tab can inherit height; chat tab; prevent double
# scrollbars; chatbot fills the space; cap the chat box at 2/3 of the screen
# height; show a scrollbar when content overflows; enlarge image messages;
# bottom input area with fixed height; textbox and button layout; textbox
# style; button and upload component styles.
# NOTE(review): #tab_item_5 and #chatbot2 are not used by the UI in this
# file, and the .gr-* class selectors may not match the DOM rendered by
# newer Gradio versions -- confirm against the installed version.
CSS = """
/* 整体布局:上下两块 */
.gradio-container {
display: flex !important;
flex-direction: column;
height: 100vh;
margin: 0;
padding: 0;
}
.gr-tabs { /* ✅ 新增:确保 tab 能继承高度 */
flex: 1 1 auto;
display: flex;
flex-direction: column;
}
/* 聊天 tab */
#tab_item_4, #tab_item_5 {
display: flex;
flex-direction: column;
flex: 1 1 auto;
overflow: hidden; /* 防止出现双滚动条 */
padding: 8px;
}
/* Chatbot 撑满 */
#chatbot1, #chatbot2{
flex-grow: 1 !important;
max-height: 66vh !important; /* 限制聊天框最大高度为屏幕的2/3 */
overflow-y: auto !important; /* 当内容溢出时显示滚动条 */
border: 1px solid #ddd;
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
}
/* 图片消息放大 */
#chatbot1 img, #chatbot2 img {
max-width: 80vw !important;
height: auto !important;
border-radius: 4px;
}
/* 底部输入区:固定高度 */
.input-row {
flex: 0 0 auto;
display: flex;
align-items: center;
padding: 8px;
border-top: 1px solid #eee;
background: #fafafa;
}
/* 文本框和按钮排布 */
.input-row .textbox-col { flex: 5; }
.input-row .upload-col, .input-row .clear-col { flex: 1; margin-left: 8px; }
/* 文本框样式 */
.gr-text-input {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
border: 1px solid #ddd !important;
font-size: 16px !important;
}
/* 按钮和上传组件样式 */
.gr-button, .gr-upload {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
font-size: 16px !important;
}
"""
# Build the single-tab chat UI. The chat history, the working image
# (img_state) and the selected mode feed on_submit; uploads and the
# "Clear History" button update the same components.
with gr.Blocks(css=CSS) as demo:
    # Path of the image the next "Understand"/"Edit" request operates on.
    img_state = gr.State(value=None)
    # NOTE(review): mode_state is not wired to any event; the Radio below
    # holds the active mode.
    mode_state = gr.State(value="Understand Image")
    with gr.Tabs():
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=(
                    "user.png",
                    "ai.png",
                ),
            )
            with gr.Row():
                mode_selector = gr.Radio(
                    choices=["Generate Image","Edit Image","Understand Image"],
                    value="Generate Image",
                    label="Mode",
                    interactive=True,
                )
            # Bottom input bar; class names match the .input-row rules in CSS.
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")
            # Enter in the textbox: run the selected mode, clear the textbox,
            # and update the working image (on_submit returns the last
            # generated/edited image on success).
            user_input.submit(
                on_submit,
                [chatbot, user_input, img_state, mode_selector],
                [chatbot, user_input, img_state],
            )
            # An upload is copied to a PNG, becomes the working image and is
            # echoed in the chat.
            image_input.upload(
                handle_image_upload, [image_input, chatbot], [img_state, chatbot]
            )
            # Reset history, working image and mode radio; clear_all also
            # deletes the files in TEMP_DIR.
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
# The app is launched unconditionally at import time. For local debugging the
# author used instead:
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", share=True, debug=True, server_port=7689)
demo.launch()