Update app.py

app.py CHANGED
@@ -37,7 +37,7 @@ BLOCKLIST = ["<script", "</script>", "{{", "}}"]
 
 
 # -----------------------------
-# Lazy Imports
+# Lazy Imports
 # -----------------------------
 def _lazy_import_openai():
     try:
@@ -56,7 +56,7 @@ def _lazy_import_gemini():
 
 
 # -----------------------------
-#
+# Utils
 # -----------------------------
 def is_blocked(text: str) -> bool:
     if not text:
@@ -66,7 +66,6 @@ def is_blocked(text: str) -> bool:
 
 
 def pil_to_base64(image) -> str:
-    """Convert PIL image to base64 JPEG for potential API usage."""
     buffer = BytesIO()
     image.convert("RGB").save(buffer, format="JPEG", quality=92)
     return base64.b64encode(buffer.getvalue()).decode("utf-8")
@@ -77,16 +76,10 @@ def approx_tokens_from_chars(text: str) -> int:
 
 
 def estimate_cost(provider_label: str, model: str, prompt: str, reply: str) -> float:
-    """
-    Super rough cost estimator. Tune to your account reality.
-    Using blended, illustrative CPMs for demo purposes only.
-    """
     toks = approx_tokens_from_chars(prompt) + approx_tokens_from_chars(reply)
     if provider_label.startswith("OpenAI"):
-
-
-        # Google/Gemini placeholder: $5 / 1M tokens
-        return round(toks / 1_000_000.0 * 5.0, 4)
+        return round(toks / 1_000_000.0 * 7.5, 4)  # illustrative
+    return round(toks / 1_000_000.0 * 5.0, 4)  # illustrative
 
 
 # -----------------------------
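For scale, the estimator above is deliberately coarse: it sums approximate token counts for the prompt and the reply, then applies a flat per-million-token rate. A quick sanity check, assuming approx_tokens_from_chars uses the common len(text) // 4 heuristic (that helper is not shown in this diff, and the provider labels below are illustrative stand-ins):

# Hypothetical check; assumes approx_tokens_from_chars(text) == len(text) // 4
prompt, reply = "x" * 2_000, "y" * 2_000  # ~500 tokens each, ~1_000 total
assert estimate_cost("OpenAI (GPT)", "gpt-4o", prompt, reply) == 0.0075  # 1_000 / 1e6 * 7.5
assert estimate_cost("Gemini", "gemini-1.5-pro", prompt, reply) == 0.005  # 1_000 / 1e6 * 5.0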
@@ -102,21 +95,36 @@ def call_openai_chat(
     max_tokens: int,
 ) -> str:
     """
-    Calls OpenAI Chat Completions.
+    Calls OpenAI Chat Completions. Auto-switches from `max_tokens` to
+    `max_completion_tokens` if the model requires it.
     """
     OpenAI = _lazy_import_openai()
     client = OpenAI(api_key=api_key)
 
-    messages = [{"role": "system", "content": system_prompt.strip() or SYSTEM_DEFAULT}]
-    messages.extend(history_messages)
+    messages = [{"role": "system", "content": (system_prompt.strip() or SYSTEM_DEFAULT)}]
+    messages.extend(history_messages or [])
     messages.append({"role": "user", "content": user_message})
 
-    resp = client.chat.completions.create(
+    kwargs = dict(
         model=(model.strip() or DEFAULT_OPENAI_MODEL),
         messages=messages,
         temperature=float(temperature),
-        max_tokens=int(max_tokens),
     )
+
+    # First try with legacy param
+    try:
+        kwargs["max_tokens"] = int(max_tokens)
+        resp = client.chat.completions.create(**kwargs)
+    except Exception as e:
+        msg = str(e)
+        # Auto-retry with new param when model rejects max_tokens
+        if "max_tokens" in msg and ("max_completion_tokens" in msg or "Unsupported parameter" in msg):
+            kwargs.pop("max_tokens", None)
+            kwargs["max_completion_tokens"] = int(max_tokens)
+            resp = client.chat.completions.create(**kwargs)
+        else:
+            raise
+
     return resp.choices[0].message.content
 
 
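The try/except shape in the hunk above generalizes beyond this app. A minimal standalone sketch of the same pattern, assuming the openai v1 client; the helper name and the error-string matching are illustrative, not part of this commit:

from openai import OpenAI

def create_with_token_fallback(client: OpenAI, **kwargs):
    # Try the legacy max_tokens parameter first; some newer models reject it
    # and require max_completion_tokens, so retry once under the new name.
    try:
        return client.chat.completions.create(**kwargs)
    except Exception as e:
        if "max_tokens" in kwargs and "max_completion_tokens" in str(e):
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
            return client.chat.completions.create(**kwargs)
        raise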
@@ -128,9 +136,6 @@ def call_gemini_generate(
     image=None,
     temperature: float = 0.4,
 ) -> str:
-    """
-    Calls Gemini/Nano-Banana. Supports optional image (PIL) as part.
-    """
     genai = _lazy_import_gemini()
     genai.configure(api_key=api_key)
 
@@ -150,22 +155,17 @@ def call_gemini_generate(
 
     parts: List[Any] = [user_message or ""]
     if image is not None:
-        # google-generativeai accepts PIL image directly as a part
        parts.append(image)
 
     resp = model_obj.generate_content(parts)
 
-    # Prefer .text if available
     if hasattr(resp, "text") and resp.text:
         return resp.text
-
-    # Fallback: candidates/parts
     cand = getattr(resp, "candidates", None)
     if cand and getattr(cand[0], "content", None):
         parts = getattr(cand[0].content, "parts", None)
         if parts and hasattr(parts[0], "text"):
             return parts[0].text
-
     return "(No response text returned.)"
 
 
@@ -173,9 +173,6 @@ def call_gemini_generate(
 # Orchestration
 # -----------------------------
 def to_openai_history(gradio_history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
-    """
-    Convert Gradio Chatbot history ([(user, assistant), ...]) to OpenAI messages.
-    """
     oai: List[Dict[str, str]] = []
     for user_msg, ai_msg in gradio_history or []:
         if user_msg:
@@ -197,9 +194,6 @@ def infer(
     max_tokens: int,
     history: List[Tuple[str, str]],
 ):
-    """
-    Main entry: routes to the chosen provider, returns updated chat, latency, cost.
-    """
     if not (user_message and user_message.strip()):
         raise gr.Error("Please enter a prompt (or pick a starter prompt).")
     if is_blocked(user_message):
@@ -296,7 +290,7 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
         )
         with gr.Row():
             temperature = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Temperature")
-            max_tokens = gr.Slider(128, 4096, value=1024, step=64, label="Max tokens (OpenAI path)")
+            max_tokens = gr.Slider(128, 4096, value=1024, step=64, label="Max completion tokens (OpenAI path)")
 
         with gr.Row():
             send = gr.Button("🚀 Generate", variant="primary")
@@ -306,7 +300,7 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
         chat = gr.Chatbot(
             label="Conversation",
             height=420,
-            type="messages",
+            type="messages",
             avatar_images=(None, None),
         )
 
@@ -358,6 +352,6 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
         return [], 0, 0.0, None, ""
     clear.click(on_clear, outputs=[chat, latency, cost, image, user_message])
 
-# Main
+# Main
 if __name__ == "__main__":
     demo.queue(max_size=64).launch()