import atexit
import os
import re
import shutil
import sys
import tempfile
import time

import gradio as gr
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
from PIL import Image

from inferencer import Inferencer
|
|
# Fetch the UniPic checkpoint snapshot from the Hugging Face Hub and locate
# the weights file inside it.
model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir, "pytorch_model.bin")
ckpt_name = "UniPic"  # NOTE(review): not referenced elsewhere in this file.

# One shared model instance, used by every request handler below.
inferencer = Inferencer(
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    model_path=model_path,
    image_size=1024,
)

# Scratch directory for uploaded and generated images.  mkdtemp() never
# deletes it on its own, so remove it (best-effort) when the interpreter exits
# instead of leaking a directory of images per run.
TEMP_DIR = tempfile.mkdtemp()
atexit.register(shutil.rmtree, TEMP_DIR, ignore_errors=True)
print(f"Temporary directory created at: {TEMP_DIR}")
|
|
def save_temp_image(pil_img, directory=None):
    """Save an image as a uniquely named PNG and return the file path.

    Args:
        pil_img: Object with a PIL-style ``save(path, format=...)`` method.
        directory: Directory to write into; defaults to the module-level
            ``TEMP_DIR`` when omitted.

    Returns:
        Path of the written ``temp_*.png`` file inside ``directory``.

    Raises:
        Whatever ``pil_img.save`` raises; the placeholder file is removed
        first so no empty file is left behind.
    """
    if directory is None:
        directory = TEMP_DIR
    # mkstemp atomically creates a file with a unique name.  The previous
    # ``temp_{int(time.time())}.png`` scheme gave every image saved within the
    # same second the same name, so multi-image results (and concurrent users)
    # overwrote each other.
    fd, path = tempfile.mkstemp(prefix="temp_", suffix=".png", dir=directory)
    os.close(fd)  # PIL reopens the file by path; we only need the name.
    try:
        pil_img.save(path, format="PNG")
    except Exception:
        os.remove(path)
        raise
    return path
|
|
def handle_image_upload(file, history):
    """Copy an uploaded image into the scratch directory and show it in chat.

    Args:
        file: Value from the upload button, either a path string or an object
            exposing the path as ``.name``.
        history: Current chatbot history (list of ``[user, bot]`` pairs), or
            None.

    Returns:
        ``(saved_path, new_history)``, where the uploaded image is appended
        as a user-side file message.  Returns ``(None, history)`` unchanged
        when no file was given.
    """
    if file is None:
        return None, history
    file_path = getattr(file, "name", file)
    # Image.open is lazy and keeps the source file open until garbage
    # collection; the context manager releases it once the copy is written.
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)
    # A one-element tuple marks the message as a file rather than text.
    return saved_path, list(history or []) + [((saved_path,), None)]
|
|
def clear_all():
    """Delete every regular file in TEMP_DIR and reset the UI values.

    Subdirectories are left alone.  A file that cannot be deleted is
    reported with ``print`` and skipped.

    Returns:
        ``([], None, "Understand Image")``: an empty chat history, no current
        image, and the mode name to select.

    NOTE(review): the Radio's initial value is "Generate Image", but clearing
    resets it to "Understand Image". Confirm that this is intended.
    """
    for entry_name in os.listdir(TEMP_DIR):
        target = os.path.join(TEMP_DIR, entry_name)
        try:
            if not os.path.isfile(target):
                continue
            os.remove(target)
        except Exception as exc:
            print(f"Failed to delete temp file: {target}, error: {exc}")
    empty_history, no_image, reset_mode = [], None, "Understand Image"
    return empty_history, no_image, reset_mode
|
|
# Assistant role markers: the Qwen-style "<|im_start|>assistant" token pair
# anywhere, or a line that begins with the word "assistant".  Matching on a
# line start (rather than on the word anywhere) keeps replies such as
# "I am an AI assistant that ..." intact.
_ASSISTANT_MARKER_RE = re.compile(
    r"<\|im_start\|>assistant|^[ \t]*assistant\b", re.MULTILINE
)


def extract_assistant_reply(full_text):
    """Return only the assistant's reply from raw model output.

    If the text contains an assistant role marker, everything up to and
    including the last marker is dropped, along with a leading ':'.  Any
    ``<|im_end|>`` token is removed and surrounding whitespace is stripped in
    every case.

    Args:
        full_text: Decoded model output, possibly including the prompt
            template.

    Returns:
        The cleaned reply text.
    """
    segments = _ASSISTANT_MARKER_RE.split(full_text)
    reply = segments[-1]
    if len(segments) > 1:
        # Role markers may be written as "assistant: ..." as well.
        reply = reply.lstrip().lstrip(":")
    return reply.replace("<|im_end|>", "").strip()
|
|
def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Handle one chat submission in the selected mode.

    Args:
        history: Current chatbot history (list of ``[user, bot]`` pairs), or
            None.
        user_msg: Raw text from the input box (None is treated as empty).
        img_path: Path of the current image (uploaded or last generated), or
            None.
        mode: "Understand Image", "Generate Image" or "Edit Image".
        grid_size: Generate requests ``grid_size**2`` images; Edit forwards
            the value to ``inferencer.edit_image``.

    Returns:
        ``(new_history, "", current_image_path)``.  The empty string clears
        the textbox.  The path is the last image produced, or the incoming
        ``img_path`` when no new image was made.  Failures are reported as a
        bot message rather than raised.
    """
    updated_history = [list(item) for item in (history or [])]
    user_msg = (user_msg or "").strip()
    updated_history.append([user_msg, None])

    def reply(message, path=img_path):
        # Append a bot-side message and build the 3-value event result.
        updated_history.append([None, message])
        return updated_history, "", path

    try:
        if mode == "Understand Image":
            if not img_path:
                return reply("⚠️ Please upload or generate an image first.")
            # The context manager releases the file handle that Image.open
            # would otherwise keep open until garbage collection.
            with Image.open(img_path) as image:
                raw = inferencer.query_image(image, user_msg)
            return reply(extract_assistant_reply(raw))

        if mode == "Generate Image":
            if not user_msg:
                return reply("⚠️ Please enter a prompt.")
            images = inferencer.gen_image(
                raw_prompt=user_msg,
                images_to_generate=grid_size**2,
                cfg=3.0,
                num_iter=48,
                cfg_schedule="constant",
                temperature=1.0,
            )
        elif mode == "Edit Image":
            if not img_path:
                return reply("⚠️ Please upload or generate an image first.")
            if not user_msg:
                return reply("⚠️ Please enter an edit instruction.")
            with Image.open(img_path) as source:
                images = inferencer.edit_image(
                    source_image=source,
                    prompt=user_msg,
                    cfg=3.0,
                    cfg_prompt="repeat this image.",
                    cfg_schedule="constant",
                    temperature=0.85,
                    grid_size=grid_size,
                    num_iter=48,
                )
        else:
            # Without this branch the handler returned None, which breaks an
            # event that expects three outputs.
            return reply(f"⚠️ Unknown mode: {mode}")

        # One bubble per image, in the same ``(path,)`` file-message form that
        # handle_image_upload uses.  A single list of paths would be shown as
        # one file (the first path) with the others treated as alt text.
        last_path = None
        for image in images:
            last_path = save_temp_image(image)
            updated_history.append([None, (last_path,)])
        if last_path is None:
            return reply("⚠️ The model returned no images.")
        return updated_history, "", last_path

    except Exception as e:
        # UI boundary: surface any failure in the chat instead of crashing.
        return reply(f"⚠️ Failed to process: {e}")
|
|
# Stylesheet passed to gr.Blocks(css=...).  In order, it:
#   - makes the page container a full-viewport-height column;
#   - lets the tab area and the chat tab grow to fill it, with overflow
#     hidden to avoid double scrollbars;
#   - caps the chat window at 66vh and scrolls on overflow;
#   - widens images inside chat messages (max 80vw);
#   - styles the bottom input row: the textbox takes 5 parts, the upload and
#     clear buttons 1 part each;
#   - rounds the textbox, button and upload styling.
# The comments inside the string are in Chinese and are left untouched
# because they are part of the runtime CSS text.
# NOTE(review): #tab_item_5 and #chatbot2 have no matching elem_id in this
# file, and the .gr-* class selectors may not exist in every Gradio version.
# Confirm these selectors still match.
CSS = """
/* 整体布局:上下两块 */
.gradio-container {
    display: flex !important;
    flex-direction: column;
    height: 100vh;
    margin: 0;
    padding: 0;
}
.gr-tabs { /* ✅ 新增:确保 tab 能继承高度 */
    flex: 1 1 auto;
    display: flex;
    flex-direction: column;
}

/* 聊天 tab */
#tab_item_4, #tab_item_5 {
    display: flex;
    flex-direction: column;
    flex: 1 1 auto;
    overflow: hidden; /* 防止出现双滚动条 */
    padding: 8px;
}

/* Chatbot 撑满 */
#chatbot1, #chatbot2{
    flex-grow: 1 !important;
    max-height: 66vh !important; /* 限制聊天框最大高度为屏幕的2/3 */
    overflow-y: auto !important; /* 当内容溢出时显示滚动条 */
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
}

/* 图片消息放大 */
#chatbot1 img, #chatbot2 img {
    max-width: 80vw !important;
    height: auto !important;
    border-radius: 4px;
}

/* 底部输入区:固定高度 */
.input-row {
    flex: 0 0 auto;
    display: flex;
    align-items: center;
    padding: 8px;
    border-top: 1px solid #eee;
    background: #fafafa;
}

/* 文本框和按钮排布 */
.input-row .textbox-col { flex: 5; }
.input-row .upload-col, .input-row .clear-col { flex: 1; margin-left: 8px; }

/* 文本框样式 */
.gr-text-input {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    border: 1px solid #ddd !important;
    font-size: 16px !important;
}

/* 按钮和上传组件样式 */
.gr-button, .gr-upload {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    font-size: 16px !important;
}
"""
|
|
# Build the single-tab chat UI, wire its events, then start the server.
with gr.Blocks(css=CSS) as demo:
    # Per-session path of the image the next Understand/Edit request uses.
    img_state = gr.State(value=None)
    # NOTE(review): mode_state is never read or written by the events below;
    # the Radio value is passed to on_submit directly.
    mode_state = gr.State(value="Understand Image")

    with gr.Tabs():
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=("user.png", "ai.png"),
            )

            with gr.Row():
                mode_selector = gr.Radio(
                    label="Mode",
                    choices=["Generate Image", "Edit Image", "Understand Image"],
                    value="Generate Image",
                    interactive=True,
                )

            # Bottom input row: message box, upload button, clear button.
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        lines=1,
                        show_label=False,
                        placeholder="Type your message here...",
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        type="filepath",
                        file_count="single",
                        file_types=["image"],
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")

            # Enter runs the selected mode.  The outputs refresh the chat,
            # clear the textbox and update the current image path.
            user_input.submit(
                fn=on_submit,
                inputs=[chatbot, user_input, img_state, mode_selector],
                outputs=[chatbot, user_input, img_state],
            )
            # An upload becomes the current image and is echoed into the chat.
            image_input.upload(
                fn=handle_image_upload,
                inputs=[image_input, chatbot],
                outputs=[img_state, chatbot],
            )
            # Clearing resets the history, the current image and the mode.
            clear_btn.click(
                fn=clear_all,
                outputs=[chatbot, img_state, mode_selector],
            )

demo.launch()