# OmniGen / app.py
# Committed by ZENLLC: "Update app.py" (commit 4c814f5, verified)
import io
import json
import base64
from pathlib import Path
import gradio as gr
import plotly.graph_objects as go
from PIL import Image
from openai import OpenAI
# System prompt for the routing call made in route_request().
# The model must pick one of three response shapes: "text", "image" or "chart".
# respond() dispatches on the "type" key; build_chart() reads "title" and "data"
# (a list of {"x", "y", "label"} series) from chart replies.
# NOTE: OpenAI's JSON mode (response_format={"type": "json_object"}, used in
# route_request) expects the word "JSON" to appear in the messages. This prompt
# satisfies that, so keep it when editing the wording.
SYSTEM_PROMPT = """
You are OmniGen AI.
Choose exactly one response mode and return valid JSON only.
For normal requests:
{"type":"text","content":"<helpful answer>"}
For image-worthy requests:
{"type":"image","prompt":"<clear, detailed image prompt>"}
For chart/data-viz requests:
{"type":"chart","title":"<chart title>","data":[{"x":[...],"y":[...],"label":"<series name>"}]}
Rules:
- Return JSON only.
- Choose exactly one type.
- Prefer text unless an image or chart is clearly better.
- For charts, provide concise but usable demo data when exact data is not supplied.
- Do not use markdown fences.
"""
def load_css():
    """Return the text of ``style.css`` from the current working directory.

    Returns:
        The stylesheet contents decoded as UTF-8, or ``""`` when the file
        does not exist.
    """
    stylesheet = Path("style.css")
    return stylesheet.read_text(encoding="utf-8") if stylesheet.exists() else ""


# Stylesheet handed to gr.Blocks; empty string when no style.css is present.
CUSTOM_CSS = load_css()
def build_messages(history, user_msg):
    """Assemble the chat-completions message list for one turn.

    Args:
        history: Iterable of ``(user_text, assistant_text)`` pairs from
            earlier turns, replayed in order.
        user_msg: The new user prompt, appended last.

    Returns:
        A list starting with the system prompt, followed by alternating
        user/assistant messages, and ending with ``user_msg``.
    """
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for prior_prompt, prior_reply in history:
        conversation.extend(
            (
                {"role": "user", "content": prior_prompt},
                {"role": "assistant", "content": prior_reply},
            )
        )
    conversation.append({"role": "user", "content": user_msg})
    return conversation
def make_client(api_key):
    """Create an OpenAI client from a user-supplied key.

    Surrounding whitespace in the key is removed before use.

    Raises:
        gr.Error: If the key is missing or blank.
    """
    key = (api_key or "").strip()
    if not key:
        raise gr.Error("Please enter your OpenAI API key.")
    return OpenAI(api_key=key)
def route_request(client, messages, *, model="gpt-4.1", temperature=0.5):
    """Ask the chat model which response mode to use and return its JSON payload.

    Args:
        client: An ``OpenAI`` client, or anything exposing
            ``chat.completions.create``.
        messages: Chat-completions message list, e.g. from ``build_messages``.
        model: Chat model used for routing (default ``"gpt-4.1"``).
        temperature: Sampling temperature for the routing call (default 0.5).

    Returns:
        The decoded reply as a ``dict``, expected to carry a ``"type"`` key.

    Raises:
        gr.Error: If the reply is not valid JSON or is not a JSON object.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        # JSON mode: asks the API for a single JSON object as the reply.
        response_format={"type": "json_object"},
    )
    content = response.choices[0].message.content or ""
    try:
        payload = json.loads(content)
    except json.JSONDecodeError as exc:
        raise gr.Error(f"Model returned invalid JSON: {exc}") from exc
    # Callers use .get() on the result. Reject lists, strings and numbers here
    # instead of letting them fail later with an opaque AttributeError.
    if not isinstance(payload, dict):
        raise gr.Error(
            f"Model returned JSON of type {type(payload).__name__}, "
            "expected an object."
        )
    return payload
def generate_image(client, prompt):
    """Render ``prompt`` with gpt-image-1 and return it as a PIL image.

    The API reply is expected to carry base64 data in ``data[0].b64_json``.
    That data is decoded and opened with Pillow at 1024x1024.

    Raises:
        gr.Error: If the reply contains no image data.
    """
    result = client.images.generate(
        model="gpt-image-1",
        prompt=prompt,
        size="1024x1024",
    )
    encoded = result.data[0].b64_json
    if not encoded:
        raise gr.Error("Image generation failed: no image data returned.")
    raw_bytes = base64.b64decode(encoded)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
def build_chart(chart_payload):
    """Build a Plotly line chart from the router's ``"chart"`` payload.

    Args:
        chart_payload: Dict with an optional ``"title"`` and a ``"data"``
            list. Each entry of ``"data"`` is a series dict with optional
            ``"x"``, ``"y"`` and ``"label"`` keys.

    Returns:
        A ``plotly.graph_objects.Figure`` with one lines+markers trace per
        series.

    Raises:
        gr.Error: If ``"data"`` is missing, empty or not a list, or if any
            series is not an object.
    """
    data = chart_payload.get("data", [])
    if not isinstance(data, list) or not data:
        raise gr.Error("Chart generation failed: missing or invalid data series.")

    fig = go.Figure()
    for index, series in enumerate(data, start=1):
        # The payload is model-generated, so a series might be a bare list or
        # number. Report that clearly instead of crashing on .get().
        if not isinstance(series, dict):
            raise gr.Error(
                f"Chart generation failed: series {index} is not an object."
            )
        fig.add_trace(
            go.Scatter(
                x=series.get("x", []),
                y=series.get("y", []),
                mode="lines+markers",
                name=series.get("label", "Series"),
            )
        )
    fig.update_layout(
        title=chart_payload.get("title", "Generated Chart"),
        template="plotly_white",
        margin=dict(l=40, r=40, t=60, b=40),
    )
    return fig
def respond(api_key, user_msg, history):
    """Handle one chat turn: route the prompt, then produce text, an image or a chart.

    Args:
        api_key: OpenAI API key from the UI textbox.
        user_msg: The user's prompt; surrounding whitespace is ignored.
        history: Prior ``[user_text, assistant_text]`` pairs, or ``None``.

    Returns:
        A 5-tuple matching the UI outputs
        ``(chatbot, history_state, image_out, chart_out, user_msg)``:
        - the updated history, twice;
        - the generated PIL image, or ``None``;
        - the Plotly figure, or ``None``;
        - ``""``, which clears the prompt box.

    Raises:
        gr.Error: On empty input, a missing key, an unusable model reply, or
            any OpenAI API failure.
    """
    # Imported locally so SDK errors can be surfaced without editing the
    # module's import block.
    from openai import OpenAIError

    # Work on a copy so the incoming state list is never mutated in place.
    history = list(history or [])
    if not user_msg or not user_msg.strip():
        raise gr.Error("Please enter a prompt.")
    clean_user_msg = user_msg.strip()
    client = make_client(api_key)
    messages = build_messages(history, clean_user_msg)

    # Convert SDK failures (bad key, rate limit, network) into gr.Error so the
    # UI shows the actual reason instead of a generic error.
    try:
        routed = route_request(client, messages)
    except OpenAIError as exc:
        raise gr.Error(f"OpenAI request failed: {exc}") from exc
    if not isinstance(routed, dict):
        raise gr.Error("Model returned JSON that is not an object.")

    result_type = routed.get("type")
    image_value = None
    chart_value = None
    if result_type == "text":
        # str() guards against a non-string "content" value from the model.
        text = str(routed.get("content") or "").strip()
        if not text:
            text = "No text was generated."
        history.append([clean_user_msg, text])
    elif result_type == "image":
        prompt = str(routed.get("prompt") or "").strip()
        if not prompt:
            raise gr.Error("Image mode was selected, but no prompt was returned.")
        try:
            image_value = generate_image(client, prompt)
        except OpenAIError as exc:
            raise gr.Error(f"Image generation failed: {exc}") from exc
        history.append([clean_user_msg, f"Generated an image for: {prompt}"])
    elif result_type == "chart":
        chart_value = build_chart(routed)
        # `or` also covers an explicit null title, which would otherwise print "None".
        title = routed.get("title") or "Generated Chart"
        history.append([clean_user_msg, f"Generated chart: {title}"])
    else:
        raise gr.Error(f"Unsupported response type: {result_type}")

    # Returning None for the image/chart slots clears any output from a previous turn.
    return history, history, image_value, chart_value, ""
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Page header.
    gr.Markdown(
        "\n"
        "# 🧠 OmniGen AI Studio\n"
        "Generate text, images, or charts from a single prompt using GPT-4.1.\n"
    )

    # The key is read from this box on every request (it is an input to respond).
    api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")

    chatbot = gr.Chatbot(label="Conversation", height=420)

    # Prompt box and submit button share one row.
    with gr.Row():
        user_msg = gr.Textbox(
            label="Your Prompt",
            placeholder="Ask for a product description, an image concept, or a simple chart...",
            scale=5,
        )
        send_btn = gr.Button("Generate", variant="primary")

    # Side-by-side output panes; respond fills at most one of them per turn.
    with gr.Row():
        image_out = gr.Image(label="Generated Image", type="pil")
        chart_out = gr.Plot(label="Generated Chart")

    # Per-session list of [user_text, assistant_text] pairs.
    history_state = gr.State([])

    # Clicking "Generate" and pressing Enter in the prompt box run the same handler.
    for trigger in (send_btn.click, user_msg.submit):
        trigger(
            fn=respond,
            inputs=[api_key, user_msg, history_state],
            outputs=[chatbot, history_state, image_out, chart_out, user_msg],
        )

if __name__ == "__main__":
    demo.queue()
    demo.launch()