"""Gradio demo that turns a hand-drawn sketch into a Chinese-painting-style image.

The UI offers four actions on a sketchpad:

* "Sketch" asks the image model to add one object to the current sketch.
* "Paint" renders the sketch as a painting in the selected style.
* "Undo" restores the sketch that was on the canvas before the last "Sketch".
* "Delete" clears the canvas, the generated image and the undo history.
"""
import base64
import os
from io import BytesIO
from typing import List, Optional, Tuple

import gradio as gr
from openai import OpenAI
from PIL import Image

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Model and output size used for every image-edit request. The size uses the
# same numbers as the sketchpad's canvas_size below.
_IMAGE_MODEL = "gpt-image-1"
_IMAGE_SIZE = "1024x1536"

# Button captions that are swapped while a request is running.
_SKETCH_LABEL = "✏️ Sketch"
_PAINT_LABEL = "🎨 Paint"
_PROCESSING_LABEL = "⏳ Processing"

# Positions of the labelled buttons inside the
# [prev_sketch_btn, complete_btn, generate_btn, clear_btn] output list.
_COMPLETE_BTN_INDEX = 1
_GENERATE_BTN_INDEX = 2
_BUTTON_COUNT = 4

# Main painting-style prompts, keyed by the dropdown's internal value.
style_prompts = {
    "meticulous_painting": "Render the image in the meticulous 'gongbi' style: fine lines, ornate details, soft yet rich colors, elegant figures or court scenes, traditional Chinese painting texture.",
    "bluegreen_shanshui": "Apply the blue-and-green shanshui style: bold composition, layered mountain ranges, vibrant mineral pigments, poetic river landscapes, classic Chinese landscape painting.",
    "ink_wash": "Use the literati ink wash style: monochrome or subtle hues, expressive brushwork, poetic mood, atmospheric negative space, scholarly aesthetics.",
    "flower_bird": "Paint in the 'gongbi flower-and-bird' style: detailed flora and fauna, precise outlines, balanced composition, natural and symbolic elements, elegant and vibrant.",
}


def _require_sketch_path(sketch: Optional[dict]) -> str:
    """Return the composite file path of *sketch*.

    Raises:
        ValueError: If *sketch* is empty/None or has no composite image.
    """
    sketch_path = sketch.get("composite") if sketch else None
    if not sketch_path:
        raise ValueError("No sketch provided.")
    return sketch_path


def _edit_sketch(sketch_path: str, prompt: str) -> Image.Image:
    """Send the sketch file and *prompt* to the image-edit API and decode the result."""
    # The context manager closes the file handle even if the request fails.
    with open(sketch_path, "rb") as sketch_file:
        result = client.images.edit(
            model=_IMAGE_MODEL,
            image=[sketch_file],
            prompt=prompt,
            size=_IMAGE_SIZE,
        )

    # Decode the base64 response into an image.
    image_bytes = base64.b64decode(result.data[0].b64_json)
    return Image.open(BytesIO(image_bytes))


def _editor_value(layer=None) -> dict:
    """Build a sketchpad value dict holding *layer* as its only layer (or no layers)."""
    return {
        "composite": None,
        "layers": [] if layer is None else [layer],
        "background": None,
    }


def _button_updates(interactive: bool, labelled_index: int, label: str) -> tuple:
    """Return updates for the four action buttons.

    Every button gets the given *interactive* state; the button at
    *labelled_index* additionally gets *label* as its caption.
    """
    updates = [gr.Button(interactive=interactive) for _ in range(_BUTTON_COUNT)]
    updates[labelled_index] = gr.Button(label, interactive=interactive)
    return tuple(updates)


def generate_image_from_sketch(sketch: dict, selected_style: str) -> Image.Image:
    """Render the sketch as a painting in the selected style.

    Args:
        sketch: Sketchpad value; its "composite" entry is opened as a file path.
        selected_style: A key of ``style_prompts``.

    Returns:
        The generated image decoded from the API response.

    Raises:
        ValueError: If no sketch is available.
        KeyError: If *selected_style* is not a key of ``style_prompts``.
    """
    sketch_path = _require_sketch_path(sketch)

    # Base prompt shared by all styles.
    base_prompt = (
        "Create a painting-style image based on the input sketch. "
        "The image should follow the structure of the sketch closely. "
        "Use soft brush strokes, traditional composition, and an aesthetic rooted in Chinese classical painting. "
        "The final result should look like a vivid traditional Chinese artwork. "
    )
    # Append the prompt of the selected style.
    final_prompt = base_prompt + style_prompts[selected_style]

    return _edit_sketch(sketch_path, final_prompt)


def complete_sketch(
    sketch: dict, recorded_sketches: List[str]
) -> Tuple[dict, List[str]]:
    """Ask the model to add one object to the sketch and record the old sketch.

    Args:
        sketch: Sketchpad value; its "composite" entry is opened as a file path.
        recorded_sketches: Undo history of sketch file paths; appended to in place.

    Returns:
        A sketchpad value containing the completed sketch as its single layer,
        and the updated undo history.

    Raises:
        ValueError: If no sketch is available.
    """
    sketch_path = _require_sketch_path(sketch)

    prompt = (
        "Preserve the original drawing style and composition. Add one semantically valid and salient object to the sketch. "
        "Only make minimal, sensible additions that enhance the sketch without changing its overall layout. "
        "Do not fill in large areas or apply shading. The result should remain a clean, line-based black-and-white sketch."
    )
    completed_sketch = _edit_sketch(sketch_path, prompt)

    # Record the previous sketch only once the request has succeeded, so a
    # failed request does not leave a stray entry in the undo history.
    recorded_sketches.append(sketch_path)
    return _editor_value(completed_sketch), recorded_sketches


def prev_sketch(recorded_sketches: List[str]) -> Tuple[dict, List[str]]:
    """Pop the most recent sketch from the undo history.

    Returns:
        A sketchpad value holding the popped sketch (or an empty one when the
        history is empty), and the remaining history.
    """
    if not recorded_sketches:
        return _editor_value(), recorded_sketches

    last_sketch = recorded_sketches.pop()
    return _editor_value(last_sketch), recorded_sketches


with gr.Blocks() as demo:
    # Per-session undo history of sketch file paths.
    recorded_sketches = gr.State(value=[])

    with gr.Row(max_height=1000):
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                # Style dropdown: (displayed label, internal value) pairs.
                style_dropdown = gr.Dropdown(
                    label="Select style",
                    choices=[
                        ("Meticulous painting (Gongbi)", "meticulous_painting"),
                        ("Blue-green shanshui", "bluegreen_shanshui"),
                        ("Ink wash", "ink_wash"),
                        ("Flower-and-bird", "flower_bird"),
                    ],
                    value="ink_wash",
                    scale=1,
                )
                complete_btn = gr.Button(_SKETCH_LABEL, min_width=20)
                generate_btn = gr.Button(_PAINT_LABEL, min_width=20)
                prev_sketch_btn = gr.Button("🔙 Undo", min_width=20)
                clear_btn = gr.Button("🗑️ Delete", min_width=20)

            sketchpad = gr.Sketchpad(
                label="Draw something",
                type="filepath",
                layers=False,
                height=576,
                width=576,
                canvas_size=(1024, 1536),
                brush=gr.Brush(
                    colors=["#000000"],
                    color_mode="fixed",
                    default_size=4,
                ),
            )

        with gr.Column(scale=1):
            output = gr.Image(
                label="Generated Image",
                type="filepath",
                height=768,
                width=384,
                interactive=False,
            )

    # Order matters: _button_updates() returns updates in exactly this order.
    action_buttons = [prev_sketch_btn, complete_btn, generate_btn, clear_btn]

    # Paint: disable the buttons, run the generation, then re-enable them.
    generate_btn.click(
        fn=lambda: _button_updates(False, _GENERATE_BTN_INDEX, _PROCESSING_LABEL),
        inputs=None,
        outputs=action_buttons,
    ).then(
        fn=generate_image_from_sketch,
        inputs=[sketchpad, style_dropdown],
        outputs=output,
    ).then(
        lambda: _button_updates(True, _GENERATE_BTN_INDEX, _PAINT_LABEL),
        inputs=None,
        outputs=action_buttons,
    )

    # Sketch: disable the buttons, run the completion, then re-enable them.
    complete_btn.click(
        fn=lambda: _button_updates(False, _COMPLETE_BTN_INDEX, _PROCESSING_LABEL),
        inputs=None,
        outputs=action_buttons,
    ).then(
        fn=complete_sketch,
        inputs=[sketchpad, recorded_sketches],
        outputs=[sketchpad, recorded_sketches],
    ).then(
        lambda: _button_updates(True, _COMPLETE_BTN_INDEX, _SKETCH_LABEL),
        inputs=None,
        outputs=action_buttons,
    )

    # Undo: put the previously recorded sketch back on the canvas.
    prev_sketch_btn.click(
        fn=prev_sketch,
        inputs=[recorded_sketches],
        outputs=[sketchpad, recorded_sketches],
    )

    # Delete: reset the canvas, the generated image and the undo history.
    clear_btn.click(
        fn=lambda: (_editor_value(), None, []),
        inputs=None,
        outputs=[sketchpad, output, recorded_sketches],
    )

if __name__ == "__main__":
    demo.launch()