mie237 commited on
Commit
d8d5fca
·
1 Parent(s): 2a4fc2a

update LLM api

Browse files
Files changed (1) hide show
  1. app.py +125 -35
app.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import os
4
  import json
 
5
  import requests
6
  import gradio as gr
7
  from openai import OpenAI
@@ -18,9 +19,15 @@ def _require_env(name: str) -> str:
18
  return value
19
 
20
  AUDIOGEN_API_URL = _require_env("AUDIOGEN_API_URL")
21
- LLM_BASE_URL = _require_env("LLM_BASE_URL")
 
22
  PROMPT_REFINER_MAX_RETRIES = 3
23
 
 
 
 
 
 
24
  # Special token order and mapping
25
  SPECIAL_TOKEN_ORDER = ["caption", "speech", "sfx", "music", "env", "asr"]
26
  SPECIAL_TOKEN_MAP = {
@@ -82,31 +89,112 @@ def call_audiogen(structured_prompt):
82
  return None, f"Error: {str(e)}"
83
 
84
 
85
- def call_prompt_refiner(user_input, max_retries=PROMPT_REFINER_MAX_RETRIES):
86
- """Call the Prompt Refiner via LLM API.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- Returns a dict with lowercase keys: caption, speech, sfx, music, env, asr.
89
- Retries up to `max_retries` times on JSON / validation errors.
90
- Raises EnvironmentError if required env vars are missing.
91
- Raises RuntimeError on unrecoverable API or repeated failures.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  api_key = os.environ.get("API_KEY")
94
  model_name = os.environ.get("MODEL_NAME")
95
 
96
  if not api_key:
97
  raise EnvironmentError(
98
  "API_KEY environment variable is not set. "
99
- "Please set it before using Auto Mode."
100
  )
101
  if not model_name:
102
  raise EnvironmentError(
103
  "MODEL_NAME environment variable is not set. "
104
- "Please set it before using Auto Mode."
 
 
 
 
 
105
  )
106
 
107
  client = OpenAI(api_key=api_key, base_url=LLM_BASE_URL)
108
-
109
- # Remove the {{user_input}} placeholder line; user text goes as user message
110
  system_content = _PROMPT_REFINER_SYSTEM.replace("{{user_input}}", "").strip()
111
 
112
  last_error = None
@@ -123,43 +211,45 @@ def call_prompt_refiner(user_input, max_retries=PROMPT_REFINER_MAX_RETRIES):
123
  )
124
  raw_content = completion.choices[0].message.content
125
 
126
- # ── JSON decode validation ────────────────────────────────────────
127
- try:
128
- parsed = json.loads(raw_content)
129
- except json.JSONDecodeError as e:
130
- last_error = f"Invalid JSON on attempt {attempt}: {e}"
131
- continue # retry
132
-
133
- # ── Field validation ──────────────────────────────────────────────
134
- # Normalize all keys to lowercase; drop null values
135
- normalized = {
136
- k.lower(): v
137
- for k, v in parsed.items()
138
- if v is not None and str(v).strip()
139
- }
140
- if not normalized.get("caption"):
141
- last_error = f"Missing required 'Caption' field on attempt {attempt}."
142
- continue # retry
143
-
144
- return normalized
145
 
146
  except EnvironmentError:
147
- raise # propagate immediately
148
  except Exception as e:
149
  err_str = str(e).lower()
150
- # Don't retry on auth / key errors
151
  if any(kw in err_str for kw in ("authentication", "api_key", "unauthorized", "403", "401")):
152
  raise RuntimeError(f"Prompt Refiner auth error: {e}") from e
153
  last_error = f"API error on attempt {attempt}: {e}"
154
- if attempt == max_retries:
155
- break
156
 
157
  raise RuntimeError(
158
- f"Prompt Refiner failed after {max_retries} attempt(s). "
159
  f"Last error: {last_error}"
160
  )
161
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def build_caption_from_refined(refined: dict) -> str:
164
  """Build the full structured prompt string from a refined dict.
165
  This is a convenience wrapper around build_structured_prompt."""
 
2
 
3
  import os
4
  import json
5
+ import gzip
6
  import requests
7
  import gradio as gr
8
  from openai import OpenAI
 
19
  return value
20
 
21
  AUDIOGEN_API_URL = _require_env("AUDIOGEN_API_URL")
22
+ LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "")
23
+ CLAW_API_URL = os.environ.get("CLAW_API_URL", "")
24
  PROMPT_REFINER_MAX_RETRIES = 3
25
 
26
+ # Prompt Refiner calling mode:
27
+ # "claw" — POST plain text to CLAW_API_URL, no auth required (default)
28
+ # "openai" — OpenAI-compatible chat completions via LLM_BASE_URL + API_KEY
29
+ PROMPT_REFINER_MODE = os.environ.get("PROMPT_REFINER_MODE", "claw")
30
+
31
  # Special token order and mapping
32
  SPECIAL_TOKEN_ORDER = ["caption", "speech", "sfx", "music", "env", "asr"]
33
  SPECIAL_TOKEN_MAP = {
 
89
  return None, f"Error: {str(e)}"
90
 
91
 
92
+ def _parse_and_validate(raw_content: str, attempt: int):
93
+ """Parse JSON and validate required 'caption' field. Returns (dict|None, error_str)."""
94
+ try:
95
+ parsed = json.loads(raw_content)
96
+ except json.JSONDecodeError as e:
97
+ return None, f"Invalid JSON on attempt {attempt}: {e}"
98
+
99
+ normalized = {
100
+ k.lower(): v
101
+ for k, v in parsed.items()
102
+ if v is not None and str(v).strip()
103
+ }
104
+ if not normalized.get("caption"):
105
+ return None, f"Missing required 'Caption' field on attempt {attempt}."
106
+ return normalized, None
107
+
108
 
109
+ def _decode_claw_response_json(response: requests.Response) -> dict:
110
+ """Decode CLAW response robustly, including mis-labeled gzip responses."""
111
+ raw_bytes = response.raw.read(decode_content=False)
112
+
113
+ candidates = []
114
+ # 1) Treat as plain utf-8 text first (some responses are plain text but mislabeled)
115
+ candidates.append(raw_bytes.decode("utf-8", errors="replace"))
116
+ # 2) Try gzip decode as fallback when content-encoding is incorrect/mixed
117
+ try:
118
+ candidates.append(gzip.decompress(raw_bytes).decode("utf-8", errors="replace"))
119
+ except Exception:
120
+ pass
121
+
122
+ last_err = None
123
+ for text in candidates:
124
+ try:
125
+ return json.loads(text)
126
+ except Exception as e:
127
+ last_err = e
128
+ raise ValueError(f"Unable to decode CLAW JSON response: {last_err}")
129
+
130
+
131
+ def _call_prompt_refiner_claw(user_input: str, max_retries: int) -> dict:
132
+ """Call Prompt Refiner via CLAW endpoint (no auth required).
133
+ Sends the full prompt template with user input substituted as plain text.
134
  """
135
+ # Substitute user input into the prompt template
136
+ full_prompt = _PROMPT_REFINER_SYSTEM.replace("{{user_input}}", user_input).strip()
137
+
138
+ last_error = None
139
+ for attempt in range(1, max_retries + 1):
140
+ try:
141
+ response = requests.post(
142
+ CLAW_API_URL,
143
+ headers={"Content-Type": "text/plain"},
144
+ data=full_prompt.encode("utf-8"),
145
+ timeout=60,
146
+ stream=True,
147
+ )
148
+ response.raise_for_status()
149
+
150
+ # Response has same structure as OpenAI: choices[0].message.content
151
+ resp_json = _decode_claw_response_json(response)
152
+ raw_content = resp_json["choices"][0]["message"]["content"]
153
+
154
+ result, err = _parse_and_validate(raw_content, attempt)
155
+ if err:
156
+ last_error = err
157
+ continue
158
+ return result
159
+
160
+ except requests.exceptions.HTTPError as e:
161
+ code = e.response.status_code
162
+ raise RuntimeError(f"CLAW API HTTP error {code}: {e.response.reason}") from e
163
+ except requests.exceptions.ConnectionError as e:
164
+ raise RuntimeError(f"CLAW API connection error: {e}") from e
165
+ except requests.exceptions.Timeout:
166
+ last_error = f"CLAW API timed out on attempt {attempt}."
167
+ except Exception as e:
168
+ last_error = f"CLAW API error on attempt {attempt}: {e}"
169
+
170
+ raise RuntimeError(
171
+ f"Prompt Refiner (claw) failed after {max_retries} attempt(s). "
172
+ f"Last error: {last_error}"
173
+ )
174
+
175
+
176
+ def _call_prompt_refiner_openai(user_input: str, max_retries: int) -> dict:
177
+ """Call Prompt Refiner via OpenAI-compatible chat completions endpoint."""
178
  api_key = os.environ.get("API_KEY")
179
  model_name = os.environ.get("MODEL_NAME")
180
 
181
  if not api_key:
182
  raise EnvironmentError(
183
  "API_KEY environment variable is not set. "
184
+ "Please set it before using Auto Mode (openai mode)."
185
  )
186
  if not model_name:
187
  raise EnvironmentError(
188
  "MODEL_NAME environment variable is not set. "
189
+ "Please set it before using Auto Mode (openai mode)."
190
+ )
191
+ if not LLM_BASE_URL:
192
+ raise EnvironmentError(
193
+ "LLM_BASE_URL environment variable is not set. "
194
+ "Please set it before using Auto Mode (openai mode)."
195
  )
196
 
197
  client = OpenAI(api_key=api_key, base_url=LLM_BASE_URL)
 
 
198
  system_content = _PROMPT_REFINER_SYSTEM.replace("{{user_input}}", "").strip()
199
 
200
  last_error = None
 
211
  )
212
  raw_content = completion.choices[0].message.content
213
 
214
+ result, err = _parse_and_validate(raw_content, attempt)
215
+ if err:
216
+ last_error = err
217
+ continue
218
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  except EnvironmentError:
221
+ raise
222
  except Exception as e:
223
  err_str = str(e).lower()
 
224
  if any(kw in err_str for kw in ("authentication", "api_key", "unauthorized", "403", "401")):
225
  raise RuntimeError(f"Prompt Refiner auth error: {e}") from e
226
  last_error = f"API error on attempt {attempt}: {e}"
 
 
227
 
228
  raise RuntimeError(
229
+ f"Prompt Refiner (openai) failed after {max_retries} attempt(s). "
230
  f"Last error: {last_error}"
231
  )
232
 
233
 
234
+ def call_prompt_refiner(user_input, max_retries=PROMPT_REFINER_MAX_RETRIES):
235
+ """Dispatch to the configured Prompt Refiner backend.
236
+
237
+ Mode is controlled by the PROMPT_REFINER_MODE environment variable:
238
+ 'claw' — CLAW plain-text endpoint, no auth required (default)
239
+ 'openai' — OpenAI-compatible chat completions endpoint
240
+ """
241
+ mode = PROMPT_REFINER_MODE.lower()
242
+ if mode == "openai":
243
+ return _call_prompt_refiner_openai(user_input, max_retries)
244
+ elif mode == "claw":
245
+ return _call_prompt_refiner_claw(user_input, max_retries)
246
+ else:
247
+ raise ValueError(
248
+ f"Unknown PROMPT_REFINER_MODE '{mode}'. "
249
+ "Valid values: 'claw' (default), 'openai'."
250
+ )
251
+
252
+
253
  def build_caption_from_refined(refined: dict) -> str:
254
  """Build the full structured prompt string from a refined dict.
255
  This is a convenience wrapper around build_structured_prompt."""