ZENLLC committed on
Commit
3c66ead
·
verified ·
1 Parent(s): e281c43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -33
app.py CHANGED
@@ -37,7 +37,7 @@ BLOCKLIST = ["<script", "</script>", "{{", "}}"]
37
 
38
 
39
  # -----------------------------
40
- # Lazy Imports (so the app boots even if a provider SDK isn't installed)
41
  # -----------------------------
42
  def _lazy_import_openai():
43
  try:
@@ -56,7 +56,7 @@ def _lazy_import_gemini():
56
 
57
 
58
  # -----------------------------
59
- # Utility helpers
60
  # -----------------------------
61
  def is_blocked(text: str) -> bool:
62
  if not text:
@@ -66,7 +66,6 @@ def is_blocked(text: str) -> bool:
66
 
67
 
68
  def pil_to_base64(image) -> str:
69
- """Convert PIL image to base64 JPEG for potential API usage."""
70
  buffer = BytesIO()
71
  image.convert("RGB").save(buffer, format="JPEG", quality=92)
72
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
@@ -77,16 +76,10 @@ def approx_tokens_from_chars(text: str) -> int:
77
 
78
 
79
  def estimate_cost(provider_label: str, model: str, prompt: str, reply: str) -> float:
80
- """
81
- Super rough cost estimator. Tune to your account reality.
82
- Using blended, illustrative CPMs for demo purposes only.
83
- """
84
  toks = approx_tokens_from_chars(prompt) + approx_tokens_from_chars(reply)
85
  if provider_label.startswith("OpenAI"):
86
- # Example: blended $7.5 / 1M tokens
87
- return round(toks / 1_000_000.0 * 7.5, 4)
88
- # Google/Gemini placeholder: $5 / 1M tokens
89
- return round(toks / 1_000_000.0 * 5.0, 4)
90
 
91
 
92
  # -----------------------------
@@ -102,21 +95,36 @@ def call_openai_chat(
102
  max_tokens: int,
103
  ) -> str:
104
  """
105
- Calls OpenAI Chat Completions. History uses OpenAI role/content format.
 
106
  """
107
  OpenAI = _lazy_import_openai()
108
  client = OpenAI(api_key=api_key)
109
 
110
- messages = [{"role": "system", "content": system_prompt.strip() or SYSTEM_DEFAULT}]
111
- messages.extend(history_messages)
112
  messages.append({"role": "user", "content": user_message})
113
 
114
- resp = client.chat.completions.create(
115
  model=(model.strip() or DEFAULT_OPENAI_MODEL),
116
  messages=messages,
117
  temperature=float(temperature),
118
- max_tokens=int(max_tokens),
119
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  return resp.choices[0].message.content
121
 
122
 
@@ -128,9 +136,6 @@ def call_gemini_generate(
128
  image=None,
129
  temperature: float = 0.4,
130
  ) -> str:
131
- """
132
- Calls Gemini/Nano-Banana. Supports optional image (PIL) as part.
133
- """
134
  genai = _lazy_import_gemini()
135
  genai.configure(api_key=api_key)
136
 
@@ -150,22 +155,17 @@ def call_gemini_generate(
150
 
151
  parts: List[Any] = [user_message or ""]
152
  if image is not None:
153
- # google-generativeai accepts PIL image directly as a part
154
  parts.append(image)
155
 
156
  resp = model_obj.generate_content(parts)
157
 
158
- # Prefer .text if available
159
  if hasattr(resp, "text") and resp.text:
160
  return resp.text
161
-
162
- # Fallback: candidates/parts
163
  cand = getattr(resp, "candidates", None)
164
  if cand and getattr(cand[0], "content", None):
165
  parts = getattr(cand[0].content, "parts", None)
166
  if parts and hasattr(parts[0], "text"):
167
  return parts[0].text
168
-
169
  return "(No response text returned.)"
170
 
171
 
@@ -173,9 +173,6 @@ def call_gemini_generate(
173
  # Orchestration
174
  # -----------------------------
175
  def to_openai_history(gradio_history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
176
- """
177
- Convert Gradio Chatbot history ([(user, assistant), ...]) to OpenAI messages.
178
- """
179
  oai: List[Dict[str, str]] = []
180
  for user_msg, ai_msg in gradio_history or []:
181
  if user_msg:
@@ -197,9 +194,6 @@ def infer(
197
  max_tokens: int,
198
  history: List[Tuple[str, str]],
199
  ):
200
- """
201
- Main entry: routes to the chosen provider, returns updated chat, latency, cost.
202
- """
203
  if not (user_message and user_message.strip()):
204
  raise gr.Error("Please enter a prompt (or pick a starter prompt).")
205
  if is_blocked(user_message):
@@ -296,7 +290,7 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
296
  )
297
  with gr.Row():
298
  temperature = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Temperature")
299
- max_tokens = gr.Slider(128, 4096, value=1024, step=64, label="Max tokens (OpenAI path)")
300
 
301
  with gr.Row():
302
  send = gr.Button("🚀 Generate", variant="primary")
@@ -306,7 +300,7 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
306
  chat = gr.Chatbot(
307
  label="Conversation",
308
  height=420,
309
- type="messages", # avoids deprecation in Gradio 5.x
310
  avatar_images=(None, None),
311
  )
312
 
@@ -358,6 +352,6 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft(), title="ZEN Dual-Engine
358
  return [], 0, 0.0, None, ""
359
  clear.click(on_clear, outputs=[chat, latency, cost, image, user_message])
360
 
361
- # Main (patched: no concurrency_count in Gradio 5.49.1)
362
  if __name__ == "__main__":
363
  demo.queue(max_size=64).launch()
 
37
 
38
 
39
  # -----------------------------
40
+ # Lazy Imports
41
  # -----------------------------
42
  def _lazy_import_openai():
43
  try:
 
56
 
57
 
58
  # -----------------------------
59
+ # Utils
60
  # -----------------------------
61
  def is_blocked(text: str) -> bool:
62
  if not text:
 
66
 
67
 
68
  def pil_to_base64(image) -> str:
 
69
  buffer = BytesIO()
70
  image.convert("RGB").save(buffer, format="JPEG", quality=92)
71
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
76
 
77
 
78
  def estimate_cost(provider_label: str, model: str, prompt: str, reply: str) -> float:
 
 
 
 
79
  toks = approx_tokens_from_chars(prompt) + approx_tokens_from_chars(reply)
80
  if provider_label.startswith("OpenAI"):
81
+ return round(toks / 1_000_000.0 * 7.5, 4) # illustrative
82
+ return round(toks / 1_000_000.0 * 5.0, 4) # illustrative
 
 
83
 
84
 
85
  # -----------------------------
 
95
  max_tokens: int,
96
  ) -> str:
97
  """
98
+ Calls OpenAI Chat Completions. Auto-switches from `max_tokens` to
99
+ `max_completion_tokens` if the model requires it.
100
  """
101
  OpenAI = _lazy_import_openai()
102
  client = OpenAI(api_key=api_key)
103
 
104
+ messages = [{"role": "system", "content": (system_prompt.strip() or SYSTEM_DEFAULT)}]
105
+ messages.extend(history_messages or [])
106
  messages.append({"role": "user", "content": user_message})
107
 
108
+ kwargs = dict(
109
  model=(model.strip() or DEFAULT_OPENAI_MODEL),
110
  messages=messages,
111
  temperature=float(temperature),
 
112
  )
113
+
114
+ # First try with legacy param
115
+ try:
116
+ kwargs["max_tokens"] = int(max_tokens)
117
+ resp = client.chat.completions.create(**kwargs)
118
+ except Exception as e:
119
+ msg = str(e)
120
+ # Auto-retry with new param when model rejects max_tokens
121
+ if "max_tokens" in msg and ("max_completion_tokens" in msg or "Unsupported parameter" in msg):
122
+ kwargs.pop("max_tokens", None)
123
+ kwargs["max_completion_tokens"] = int(max_tokens)
124
+ resp = client.chat.completions.create(**kwargs)
125
+ else:
126
+ raise
127
+
128
  return resp.choices[0].message.content
129
 
130
 
 
136
  image=None,
137
  temperature: float = 0.4,
138
  ) -> str:
 
 
 
139
  genai = _lazy_import_gemini()
140
  genai.configure(api_key=api_key)
141
 
 
155
 
156
  parts: List[Any] = [user_message or ""]
157
  if image is not None:
 
158
  parts.append(image)
159
 
160
  resp = model_obj.generate_content(parts)
161
 
 
162
  if hasattr(resp, "text") and resp.text:
163
  return resp.text
 
 
164
  cand = getattr(resp, "candidates", None)
165
  if cand and getattr(cand[0], "content", None):
166
  parts = getattr(cand[0].content, "parts", None)
167
  if parts and hasattr(parts[0], "text"):
168
  return parts[0].text
 
169
  return "(No response text returned.)"
170
 
171
 
 
173
  # Orchestration
174
  # -----------------------------
175
  def to_openai_history(gradio_history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
 
 
 
176
  oai: List[Dict[str, str]] = []
177
  for user_msg, ai_msg in gradio_history or []:
178
  if user_msg:
 
194
  max_tokens: int,
195
  history: List[Tuple[str, str]],
196
  ):
 
 
 
197
  if not (user_message and user_message.strip()):
198
  raise gr.Error("Please enter a prompt (or pick a starter prompt).")
199
  if is_blocked(user_message):
 
290
  )
291
  with gr.Row():
292
  temperature = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Temperature")
293
+ max_tokens = gr.Slider(128, 4096, value=1024, step=64, label="Max completion tokens (OpenAI path)")
294
 
295
  with gr.Row():
296
  send = gr.Button("🚀 Generate", variant="primary")
 
300
  chat = gr.Chatbot(
301
  label="Conversation",
302
  height=420,
303
+ type="messages",
304
  avatar_images=(None, None),
305
  )
306
 
 
352
  return [], 0, 0.0, None, ""
353
  clear.click(on_clear, outputs=[chat, latency, cost, image, user_message])
354
 
355
+ # Main
356
  if __name__ == "__main__":
357
  demo.queue(max_size=64).launch()