| import io |
| import json |
| import base64 |
| from pathlib import Path |
|
|
| import gradio as gr |
| import plotly.graph_objects as go |
| from PIL import Image |
| from openai import OpenAI |
|
|
|
|
# System prompt sent as the first chat message of every request (see
# build_messages). It asks the model to reply with a single JSON object whose
# "type" field ("text", "image" or "chart") selects the branch respond() takes.
# NOTE: route_request uses OpenAI JSON mode (response_format json_object),
# which requires the word "JSON" to appear in the messages -- keep it here.
SYSTEM_PROMPT = """
You are OmniGen AI.

Choose exactly one response mode and return valid JSON only.

For normal requests:
{"type":"text","content":"<helpful answer>"}

For image-worthy requests:
{"type":"image","prompt":"<clear, detailed image prompt>"}

For chart/data-viz requests:
{"type":"chart","title":"<chart title>","data":[{"x":[...],"y":[...],"label":"<series name>"}]}

Rules:
- Return JSON only.
- Choose exactly one type.
- Prefer text unless an image or chart is clearly better.
- For charts, provide concise but usable demo data when exact data is not supplied.
- Do not use markdown fences.
"""
|
|
|
|
def load_css(path="style.css"):
    """Return the text of the optional custom stylesheet, or "" if none is found.

    A relative *path* is looked up in the current working directory first and
    then next to this script, so the app keeps its styling when it is launched
    from a different directory. Absolute paths are used as given.

    Args:
        path: Stylesheet location. Defaults to ``"style.css"``.

    Returns:
        The stylesheet decoded as UTF-8, or an empty string when no regular
        file exists at any candidate location.
    """
    css_path = Path(path)
    candidates = [css_path]

    # __file__ is absent in notebooks/REPLs; only add the script-relative
    # fallback when it exists and the path is relative.
    script_file = globals().get("__file__")
    if script_file and not css_path.is_absolute():
        candidates.append(Path(script_file).resolve().parent / css_path)

    for candidate in candidates:
        # is_file() (not exists()) so a directory with this name is skipped
        # instead of crashing read_text().
        if candidate.is_file():
            return candidate.read_text(encoding="utf-8")
    return ""
|
|
|
|
# Stylesheet text read once at import time and passed to gr.Blocks(css=...)
# below; it is an empty string when load_css() finds no stylesheet.
CUSTOM_CSS = load_css()
|
|
|
|
def build_messages(history, user_msg, max_turns=None):
    """Assemble the chat-completion message list for one request.

    Args:
        history: Iterable of ``(user_text, assistant_text)`` pairs from earlier
            turns. ``None`` is treated as no history.
        user_msg: The new user prompt, appended as the final message.
        max_turns: If given, only the most recent *max_turns* pairs are sent
            (``0`` sends none). ``None`` (the default) sends the full history.
            The full history grows with every turn and can eventually exceed
            the model's context window.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. It starts with the
        system prompt and ends with *user_msg*.
    """
    turns = list(history or [])
    if max_turns is not None:
        turns = turns[-max_turns:] if max_turns > 0 else []

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_text, assistant_text in turns:
        # Skip empty/None halves (e.g. a turn without a reply). A message
        # with null content is not a valid chat message.
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": user_msg})
    return messages
|
|
|
|
def make_client(api_key):
    """Build an OpenAI client from the key typed into the UI.

    Surrounding whitespace is stripped from the key before it is used.

    Raises:
        gr.Error: If the key is missing or consists only of whitespace.
    """
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        raise gr.Error("Please enter your OpenAI API key.")
    return OpenAI(api_key=cleaned_key)
|
|
|
|
def route_request(client, messages, model="gpt-4.1", temperature=0.5):
    """Ask the chat model to choose a response mode and return its JSON payload.

    Args:
        client: OpenAI client (anything exposing ``chat.completions.create``).
        messages: Chat messages, e.g. as produced by ``build_messages``.
        model: Chat model used for routing. Defaults to ``"gpt-4.1"``.
        temperature: Sampling temperature. Defaults to ``0.5``.

    Returns:
        The decoded reply as a ``dict``.

    Raises:
        gr.Error: If the reply is not valid JSON, or is valid JSON but not an
            object.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        response_format={"type": "json_object"},
    )
    content = response.choices[0].message.content or ""
    try:
        payload = json.loads(content)
    except json.JSONDecodeError as exc:
        raise gr.Error(f"Model returned invalid JSON: {exc}") from exc

    # Callers read the payload with .get(). A JSON array, string or number
    # would otherwise surface later as an opaque AttributeError.
    if not isinstance(payload, dict):
        raise gr.Error("Model returned JSON that is not an object.")
    return payload
|
|
|
|
def generate_image(client, prompt, size="1024x1024", model="gpt-image-1"):
    """Generate one image for *prompt* and return it as a PIL image.

    Args:
        client: OpenAI client used for the Images API call.
        prompt: Image prompt produced by the routing model.
        size: Requested image size. Defaults to ``"1024x1024"``.
        model: Image model name. Defaults to ``"gpt-image-1"``.

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Raises:
        gr.Error: If the API returns no image data, or data that cannot be
            decoded as an image.
    """
    image_response = client.images.generate(
        model=model,
        prompt=prompt,
        size=size,
    )

    # Handle an empty/missing ``data`` list as well as a missing payload.
    # Indexing data[0] blindly would raise a bare IndexError.
    results = image_response.data or []
    image_b64 = results[0].b64_json if results else None
    if not image_b64:
        raise gr.Error("Image generation failed: no image data returned.")

    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_b64)))
        # Image.open is lazy. Decode now so corrupt data fails here with a
        # clear message instead of later inside Gradio's serialization.
        image.load()
    except (ValueError, OSError) as exc:
        # ValueError covers bad base64 (binascii.Error). OSError covers
        # unreadable or unrecognised image bytes.
        raise gr.Error(f"Image generation failed: could not decode image data ({exc}).") from exc
    return image
|
|
|
|
def build_chart(chart_payload):
    """Build a Plotly line chart from the routing model's chart payload.

    The payload is expected in the shape requested by SYSTEM_PROMPT:
    ``{"title": str, "data": [{"x": [...], "y": [...], "label": str}, ...]}``.
    Each series becomes one ``lines+markers`` trace. If ``x`` is omitted, the
    points are placed at 0, 1, 2, ...

    Args:
        chart_payload: Decoded JSON dict returned by the routing model.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        gr.Error: If ``data`` is missing or empty, or a series is malformed
            (not an object, has no ``y`` values, or has ``x`` values that do
            not match its ``y`` values).
    """
    data = chart_payload.get("data", [])
    if not isinstance(data, list) or not data:
        raise gr.Error("Chart generation failed: missing or invalid data series.")

    fig = go.Figure()
    for index, series in enumerate(data, start=1):
        if not isinstance(series, dict):
            raise gr.Error(f"Chart generation failed: series {index} is not an object.")

        y_vals = series.get("y")
        if not isinstance(y_vals, list) or not y_vals:
            raise gr.Error(f"Chart generation failed: series {index} has no y values.")

        x_vals = series.get("x")
        if not x_vals:
            # No x values supplied: number the points so the series still
            # renders. An empty x list would give plotly no coordinates.
            x_vals = list(range(len(y_vals)))
        elif not isinstance(x_vals, list) or len(x_vals) != len(y_vals):
            raise gr.Error(
                f"Chart generation failed: series {index} has x values that do not match its y values."
            )

        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                mode="lines+markers",
                # Number unnamed series so their legend entries stay distinct.
                name=series.get("label") or f"Series {index}",
            )
        )

    fig.update_layout(
        # ``or`` also covers an explicit null/empty title from the model.
        title=chart_payload.get("title") or "Generated Chart",
        template="plotly_white",
        margin=dict(l=40, r=40, t=60, b=40),
    )
    return fig
|
|
|
|
def respond(api_key, user_msg, history):
    """Handle one prompt submission from the UI.

    The prompt is routed through the chat model, which picks a text, image or
    chart response. The matching output is produced and the turn is appended
    to the conversation history.

    Args:
        api_key: OpenAI API key typed into the UI.
        user_msg: The prompt text.
        history: List of ``[user_text, assistant_text]`` pairs from the session
            state. ``None`` is treated as empty.

    Returns:
        A 5-tuple ``(chatbot_history, state_history, image, figure, "")``:

        - the updated history twice (chat display and session state);
        - a PIL image or ``None``;
        - a Plotly figure or ``None``;
        - an empty string that clears the prompt box.

    Raises:
        gr.Error: On an empty prompt or API key, a malformed model reply, or
            an unsupported response type.
    """
    if not user_msg or not user_msg.strip():
        raise gr.Error("Please enter a prompt.")

    clean_user_msg = user_msg.strip()
    # Work on a copy rather than mutating the session-state list in place.
    # The new list is returned to both the chatbot and the state.
    history = list(history or [])

    client = make_client(api_key)
    routed = route_request(client, build_messages(history, clean_user_msg))
    if not isinstance(routed, dict):
        # A JSON array or scalar has no .get(); fail with a readable message.
        raise gr.Error("Model returned an unexpected response format.")

    raw_type = routed.get("type")
    # Tolerate harmless variations such as "Text" or " chart ".
    result_type = raw_type.strip().lower() if isinstance(raw_type, str) else raw_type

    image_value = None
    chart_value = None

    if result_type == "text":
        reply = str(routed.get("content") or "").strip() or "No text was generated."
    elif result_type == "image":
        prompt = str(routed.get("prompt") or "").strip()
        if not prompt:
            raise gr.Error("Image mode was selected, but no prompt was returned.")
        image_value = generate_image(client, prompt)
        reply = f"Generated an image for: {prompt}"
    elif result_type == "chart":
        chart_value = build_chart(routed)
        # Fall back on a missing *or* null/empty title so the chat never
        # shows "Generated chart: None".
        title = routed.get("title") or "Generated Chart"
        reply = f"Generated chart: {title}"
    else:
        raise gr.Error(f"Unsupported response type: {raw_type}")

    # Recorded only after the output was produced successfully.
    history.append([clean_user_msg, reply])
    return history, history, image_value, chart_value, ""
|
|
|
|
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Page header.
    gr.Markdown(
        """
        # 🧠 OmniGen AI Studio
        Generate text, images, or charts from a single prompt using GPT-4.1.
        """
    )

    # The key lives only in this password box and is passed to respond().
    api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")

    chatbot = gr.Chatbot(label="Conversation", height=420)

    # Prompt entry row: a wide text box plus the submit button.
    with gr.Row():
        user_msg = gr.Textbox(
            label="Your Prompt",
            placeholder="Ask for a product description, an image concept, or a simple chart...",
            scale=5,
        )
        send_btn = gr.Button("Generate", variant="primary")

    # Output row: whichever of these respond() fills depends on the routed mode.
    with gr.Row():
        image_out = gr.Image(label="Generated Image", type="pil")
        chart_out = gr.Plot(label="Generated Chart")

    # Per-session list of [user_text, assistant_text] pairs.
    history_state = gr.State([])

    # Clicking "Generate" and pressing Enter in the prompt box trigger the
    # same handler with identical inputs and outputs.
    for trigger in (send_btn.click, user_msg.submit):
        trigger(
            fn=respond,
            inputs=[api_key, user_msg, history_state],
            outputs=[chatbot, history_state, image_out, chart_out, user_msg],
        )
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's event queue on the Blocks app, then launch it.
    demo.queue().launch()