"""OmniGen AI Studio.

A small Gradio app that sends the conversation to an OpenAI chat model,
asks it to pick one response mode (text, image, or chart) as JSON, and
then renders the chosen result.
"""

import base64
import io
import json
from pathlib import Path

import gradio as gr
import plotly.graph_objects as go
from openai import OpenAI, OpenAIError
from PIL import Image

SYSTEM_PROMPT = """
You are OmniGen AI.

Choose exactly one response mode and return valid JSON only.

For normal requests:
{"type":"text","content":""}

For image-worthy requests:
{"type":"image","prompt":""}

For chart/data-viz requests:
{"type":"chart","title":"","data":[{"x":[...],"y":[...],"label":""}]}

Rules:
- Return JSON only.
- Choose exactly one type.
- Prefer text unless an image or chart is clearly better.
- For charts, provide concise but usable demo data when exact data is not supplied.
- Do not use markdown fences.
"""

DEFAULT_CHART_TITLE = "Generated Chart"
INVALID_SERIES_MESSAGE = "Chart generation failed: missing or invalid data series."


def load_css():
    """Return the contents of ``style.css`` if it exists, else an empty string."""
    css_file = Path("style.css")
    if css_file.exists():
        return css_file.read_text(encoding="utf-8")
    return ""


CUSTOM_CSS = load_css()


def build_messages(history, user_msg):
    """Build the chat-completions message list for one request.

    Args:
        history: Iterable of ``(user_text, assistant_text)`` pairs from
            earlier turns.
        user_msg: The new user prompt to append as the final message.

    Returns:
        A list of ``{"role", "content"}`` dicts starting with the system
        prompt, followed by the prior turns and the new user message.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": user_msg})
    return messages


def make_client(api_key):
    """Create an OpenAI client from the key typed into the UI.

    Raises:
        gr.Error: If the key is missing or only whitespace.
    """
    if not api_key or not api_key.strip():
        raise gr.Error("Please enter your OpenAI API key.")
    return OpenAI(api_key=api_key.strip())


def route_request(client, messages):
    """Ask the chat model which response mode to use and parse its JSON reply.

    Args:
        client: An ``OpenAI`` client.
        messages: Message list as produced by :func:`build_messages`.

    Returns:
        The decoded JSON object (a ``dict``) returned by the model.

    Raises:
        gr.Error: If the API call fails, the reply is not valid JSON, or
            the JSON is not an object.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4.1",
            messages=messages,
            temperature=0.5,
            response_format={"type": "json_object"},
        )
    except OpenAIError as exc:
        # Surface auth/rate-limit/network failures in the UI instead of an
        # opaque server-side traceback.
        raise gr.Error(f"OpenAI request failed: {exc}") from exc

    choices = response.choices or []
    content = (choices[0].message.content if choices else "") or ""
    try:
        routed = json.loads(content)
    except json.JSONDecodeError as exc:
        raise gr.Error(f"Model returned invalid JSON: {exc}") from exc

    # Callers use ``.get`` on the result, so anything but an object is unusable.
    if not isinstance(routed, dict):
        raise gr.Error("Model returned invalid JSON: expected a JSON object.")
    return routed


def generate_image(client, prompt):
    """Generate a 1024x1024 image for ``prompt`` and return it as a PIL image.

    Raises:
        gr.Error: If the API call fails, no image data is returned, or the
            returned data cannot be decoded as an image.
    """
    try:
        image_response = client.images.generate(
            model="gpt-image-1",
            prompt=prompt,
            size="1024x1024",
        )
    except OpenAIError as exc:
        raise gr.Error(f"Image generation failed: {exc}") from exc

    results = image_response.data or []
    image_b64 = results[0].b64_json if results else None
    if not image_b64:
        raise gr.Error("Image generation failed: no image data returned.")

    try:
        image_bytes = base64.b64decode(image_b64)
        return Image.open(io.BytesIO(image_bytes))
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        raise gr.Error(
            "Image generation failed: returned data is not a valid image."
        ) from exc


def build_chart(chart_payload):
    """Build a Plotly line chart from the model's chart payload.

    Args:
        chart_payload: Dict with an optional ``"title"`` and a ``"data"``
            list of series dicts, each holding ``"x"``, ``"y"`` and an
            optional ``"label"``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one lines+markers trace per
        series.

    Raises:
        gr.Error: If ``"data"`` is missing, empty, or contains a series
            that is not a dict with list-valued ``x`` and ``y``.
    """
    data = chart_payload.get("data", [])
    if not isinstance(data, list) or not data:
        raise gr.Error(INVALID_SERIES_MESSAGE)

    fig = go.Figure()
    for series in data:
        if not isinstance(series, dict):
            raise gr.Error(INVALID_SERIES_MESSAGE)
        x_vals = series.get("x", [])
        y_vals = series.get("y", [])
        if not isinstance(x_vals, list) or not isinstance(y_vals, list):
            raise gr.Error(INVALID_SERIES_MESSAGE)
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                mode="lines+markers",
                # ``or`` also covers an empty-string label from the model.
                name=series.get("label") or "Series",
            )
        )

    fig.update_layout(
        title=chart_payload.get("title") or DEFAULT_CHART_TITLE,
        template="plotly_white",
        margin=dict(l=40, r=40, t=60, b=40),
    )
    return fig


def respond(api_key, user_msg, history):
    """Handle one prompt: route it, produce the output, and extend the history.

    Args:
        api_key: OpenAI API key from the password textbox.
        user_msg: The prompt typed by the user.
        history: Prior ``[user, assistant]`` turns (may be ``None``).

    Returns:
        A 5-tuple ``(chatbot_value, history_state, image, chart, "")``; the
        trailing empty string clears the prompt textbox.

    Raises:
        gr.Error: On missing input, API failures, or an unusable model reply.
    """
    # Work on a copy so the caller's state object is never mutated in place.
    history = list(history or [])

    if not user_msg or not user_msg.strip():
        raise gr.Error("Please enter a prompt.")

    clean_user_msg = user_msg.strip()
    client = make_client(api_key)
    messages = build_messages(history, clean_user_msg)
    routed = route_request(client, messages)

    result_type = routed.get("type")
    image_value = None
    chart_value = None

    if result_type == "text":
        text = (routed.get("content") or "").strip()
        if not text:
            text = "No text was generated."
        history.append([clean_user_msg, text])
    elif result_type == "image":
        prompt = (routed.get("prompt") or "").strip()
        if not prompt:
            raise gr.Error("Image mode was selected, but no prompt was returned.")
        image_value = generate_image(client, prompt)
        history.append([clean_user_msg, f"Generated an image for: {prompt}"])
    elif result_type == "chart":
        chart_value = build_chart(routed)
        # Same fallback as build_chart so the chat line matches the figure title.
        title = routed.get("title") or DEFAULT_CHART_TITLE
        history.append([clean_user_msg, f"Generated chart: {title}"])
    else:
        raise gr.Error(f"Unsupported response type: {result_type}")

    return history, history, image_value, chart_value, ""


with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(
        """
        # 🧠 OmniGen AI Studio
        Generate text, images, or charts from a single prompt using GPT-4.1.
        """
    )

    api_key = gr.Textbox(
        label="OpenAI API Key",
        type="password",
        placeholder="sk-...",
    )

    chatbot = gr.Chatbot(label="Conversation", height=420)

    with gr.Row():
        user_msg = gr.Textbox(
            label="Your Prompt",
            placeholder="Ask for a product description, an image concept, or a simple chart...",
            scale=5,
        )
        send_btn = gr.Button("Generate", variant="primary")

    with gr.Row():
        image_out = gr.Image(label="Generated Image", type="pil")
        chart_out = gr.Plot(label="Generated Chart")

    history_state = gr.State([])

    # The button click and pressing Enter in the textbox run the same handler.
    for register in (send_btn.click, user_msg.submit):
        register(
            fn=respond,
            inputs=[api_key, user_msg, history_state],
            outputs=[chatbot, history_state, image_out, chart_out, user_msg],
        )

if __name__ == "__main__":
    demo.queue().launch()