bigbossmonster committed on
Commit
e74e277
·
verified ·
1 Parent(s): c24bff4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -199
app.py CHANGED
@@ -2,12 +2,12 @@ import os
2
  import io
3
  import re
4
  import json
5
- import tempfile
6
  import shutil
7
  import logging
8
  import base64
9
  from concurrent.futures import ThreadPoolExecutor
10
- from PIL import Image, ImageOps
11
 
12
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
13
  from fastapi.staticfiles import StaticFiles
@@ -15,7 +15,6 @@ from fastapi.middleware.cors import CORSMiddleware
15
  import rarfile
16
  import zipfile
17
 
18
- # --- MIGRATION: New SDK Imports ---
19
  from google import genai
20
  from google.genai import types
21
 
@@ -33,6 +32,10 @@ app.add_middleware(
33
  allow_headers=["*"],
34
  )
35
 
 
 
 
 
36
  # --- Utility Functions ---
37
 
38
  def parse_srt_time_to_ms(time_str):
@@ -52,28 +55,33 @@ def parse_filename_to_ms(filename):
52
  return (h * 3600000) + (m * 60000) + (s * 1000) + ms
53
 
54
  def parse_srt(content: str):
 
 
 
 
 
55
  # Normalize line endings
56
  content = content.replace('\r\n', '\n').replace('\r', '\n')
57
- # Split by double newline (standard SRT block separator)
58
- blocks = re.split(r'\n\n+', content.strip())
 
 
 
 
 
 
 
 
 
59
 
60
  parsed = []
61
- for block in blocks:
62
- lines = [l.strip() for l in block.split('\n') if l.strip()]
63
- if len(lines) < 2:
64
- continue
65
-
66
- srt_id = lines[0]
67
- time_range = lines[1]
68
 
69
- # Check if there is actually text after the timestamp
70
- # If there are lines after index 1, join them; otherwise, it's a blank sub
71
- if len(lines) > 2:
72
- text = "\n".join(lines[2:])
73
- else:
74
- text = "" # Explicitly blank
75
-
76
  try:
 
77
  start_time_str = time_range.split('-->')[0].strip()
78
  start_ms = parse_srt_time_to_ms(start_time_str)
79
 
@@ -84,117 +92,52 @@ def parse_srt(content: str):
84
  "text": text
85
  })
86
  except Exception as e:
87
- logger.warning(f"Skipping malformed SRT block: {block[:50]}... Error: {e}")
88
 
89
  return parsed
90
 
91
-
92
  def compress_image(image_bytes, max_width=800, quality=80):
93
- """
94
- Compresses an image to WebP (best) or optimized JPEG.
95
- """
96
  try:
97
  img = Image.open(io.BytesIO(image_bytes))
98
-
99
- # 1. Efficient Resize (Using thumbnail prevents upscaling artifacts)
100
  img.thumbnail((max_width, max_width), Image.Resampling.LANCZOS)
101
-
102
  buffer = io.BytesIO()
103
-
104
- # 2. Try WebP first (Best quality/size ratio)
105
- use_webp = True
106
-
107
- if use_webp:
108
- # method=6 is the strongest compression algo for WebP
109
- img.save(buffer, format="WEBP", quality=quality, method=6)
110
- else:
111
- # Fallback: Optimized JPEG
112
- # Handle transparency (paste on white)
113
- if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
114
- background = Image.new('RGB', img.size, (255, 255, 255))
115
- # Handle paletted images with transparency
116
- if img.mode == 'P':
117
- img = img.convert('RGBA')
118
- background.paste(img, mask=img.split()[3])
119
- img = background
120
- elif img.mode != 'RGB':
121
- img = img.convert('RGB')
122
-
123
- img.save(
124
- buffer,
125
- format="JPEG",
126
- quality=quality,
127
- optimize=True,
128
- progressive=True,
129
- subsampling=0
130
- )
131
-
132
  return buffer.getvalue()
133
-
134
  except Exception as e:
135
- logger.error(f"Image compression failed: {e}")
136
  return None
137
 
138
- # --- MIGRATION: Updated Gemini Processing Function ---
139
  def process_batch_gemini(api_key, items, model_name):
140
  try:
141
- # 1. Instantiate the Client (New SDK pattern)
142
- # This replaces genai.configure()
143
  client = genai.Client(api_key=api_key)
144
-
145
  prompt_parts = [
146
  "You are a Subtitle Quality Control (QC) bot.",
147
  f"I will provide {len(items)} images and the EXPECTED subtitle text for each.",
148
- "Return a JSON array strictly following this schema:",
149
  '[{"index": <int>, "detected_text": "<string>", "match": <bool>, "reason": "<string>"}, ...]',
150
  "Return ONLY the JSON. No markdown."
151
  ]
152
 
153
  for item in items:
 
 
154
  prompt_parts.append(f"\n--- Item {item['index']} ---")
155
- prompt_parts.append(f"Index: {item['index']}")
156
- prompt_parts.append(f"Expected Text: \"{item['expected_text']}\"")
157
- prompt_parts.append(f"Image:")
158
-
159
- # The new SDK handles PIL images directly in the contents list just like the old one
160
- img = Image.open(io.BytesIO(item['image_data']))
161
- prompt_parts.append(img)
162
 
163
- # 2. Call generate_content via the client
164
  response = client.models.generate_content(
165
  model=model_name,
166
  contents=prompt_parts,
167
- config=types.GenerateContentConfig(
168
- response_mime_type="application/json"
169
- )
170
  )
171
 
172
  text = response.text.replace("```json", "").replace("```", "").strip()
173
-
174
- try:
175
- return json.loads(text)
176
- except json.JSONDecodeError as e:
177
- # Handle Truncated JSON (Output Token Limit Exceeded)
178
- logger.warning(f"JSON Parse Error (likely truncated response): {e}. Attempting repair...")
179
-
180
- # Repair Strategy: Find the last closing brace '}', discard everything after, and close the array ']'
181
- last_object_idx = text.rfind("}")
182
- if last_object_idx != -1:
183
- repaired_text = text[:last_object_idx+1] + "]"
184
- try:
185
- repaired_data = json.loads(repaired_text)
186
- logger.info(f"Successfully repaired JSON. Recovered {len(repaired_data)}/{len(items)} items.")
187
- return repaired_data
188
- except json.JSONDecodeError:
189
- logger.error("JSON repair failed.")
190
-
191
- return None # Fail gracefully if repair is impossible
192
-
193
  except Exception as e:
194
- logger.error(f"Gemini API Error with key ...{api_key[-4:]}: {e}")
195
  return None
196
 
197
- # --- Main Endpoint ---
198
 
199
  @app.post("/api/analyze")
200
  async def analyze_subtitles(
@@ -202,129 +145,119 @@ async def analyze_subtitles(
202
  media_files: list[UploadFile] = File(...),
203
  api_keys: str = Form(...),
204
  batch_size: int = Form(20),
205
- model_name: str = Form("gemini-2.0-flash"), # Updated default model hint
206
  compression_quality: float = Form(0.7)
207
  ):
208
- temp_dir = tempfile.mkdtemp()
 
 
 
 
 
209
  try:
210
- # Convert float quality (0.1-1.0) to integer (10-100) for PIL
211
  pil_quality = max(10, min(100, int(compression_quality * 100)))
212
 
213
- # 1. Read SRT
214
- srt_content = (await srt_file.read()).decode('utf-8', errors='ignore')
215
- srt_data = parse_srt(srt_content)
 
 
 
 
216
  srt_data.sort(key=lambda x: x['startTimeMs'])
217
 
218
- # 2. Process Media
219
- images = []
220
  for file in media_files:
221
- file_path = os.path.join(temp_dir, file.filename)
222
  with open(file_path, "wb") as f:
223
  shutil.copyfileobj(file.file, f)
224
 
225
  if file.filename.lower().endswith('.rar'):
226
- try:
227
- with rarfile.RarFile(file_path) as rf:
228
- rf.extractall(temp_dir)
229
- except rarfile.RarCannotExec:
230
- raise HTTPException(status_code=500, detail="Unrar executable not found in container.")
231
  elif file.filename.lower().endswith('.zip'):
232
  with zipfile.ZipFile(file_path, 'r') as zf:
233
- zf.extractall(temp_dir)
234
-
235
- for root, _, files in os.walk(temp_dir):
236
- for filename in files:
237
- if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', '.bmp')):
238
- full_path = os.path.join(root, filename)
239
- ms = parse_filename_to_ms(filename)
240
- if ms is not None:
241
- with open(full_path, "rb") as f:
242
- raw_bytes = f.read()
243
- compressed = compress_image(raw_bytes, quality=pil_quality)
244
- if compressed:
245
- images.append({
246
- "filename": filename,
247
- "timeMs": ms,
248
- "data": compressed
249
- })
250
-
251
- images.sort(key=lambda x: x['timeMs'])
252
-
253
- # 3. Pair
254
- pairs = []
255
- for i in range(len(images)):
256
- img = images[i]
257
- srt = srt_data[i] if i < len(srt_data) else None
258
-
259
- if srt:
260
- # Create Thumbnail (lower quality for UI speed)
261
- thumb_bytes = compress_image(img['data'], quality=50, max_width=300)
262
- thumb_b64 = base64.b64encode(thumb_bytes).decode('utf-8')
263
-
264
- pairs.append({
265
- "index": i,
266
- "image_data": img['data'],
267
- "expected_text": srt['text'],
268
- "srt_id": srt['id'],
269
- "srt_time": srt['time'],
270
- "filename": img['filename'],
271
- "thumb": f"data:image/jpeg;base64,{thumb_b64}",
272
- "status": "pending"
273
- })
274
-
275
- if not pairs:
276
- return {"status": "error", "message": "No valid image/subtitle pairs found."}
277
-
278
- # 4. Process Gemini
279
- keys = [k.strip() for k in api_keys.split('\n') if k.strip()]
280
- if not keys:
281
- raise HTTPException(status_code=400, detail="No API Keys provided")
282
-
283
- results_map = {}
284
- batches = [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]
285
 
286
- def worker(batch_idx, batch):
287
- key = keys[batch_idx % len(keys)]
288
- return process_batch_gemini(key, batch, model_name)
289
 
290
- with ThreadPoolExecutor(max_workers=len(keys)) as executor:
291
- futures = [executor.submit(worker, i, b) for i, b in enumerate(batches)]
292
- for future in futures:
293
- res = future.result()
294
- if res:
295
- for item in res:
296
- results_map[item['index']] = item
297
 
298
- # 5. Build Output
299
- final_output = []
300
- for p in pairs:
301
- analysis = results_map.get(p['index'])
302
- status = "pending"
303
- reason = ""
304
- detected = ""
305
- if analysis:
306
- status = "match" if analysis['match'] else "mismatch"
307
- reason = analysis.get('reason', '')
308
- detected = analysis.get('detected_text', '')
309
 
310
- final_output.append({
311
- "id": p['index'],
312
- "filename": p['filename'],
313
- "thumb": p['thumb'],
314
- "expected": p['expected_text'],
315
- "detected": detected,
316
- "status": status,
317
- "reason": reason,
318
- "srt_id": p['srt_id'],
319
- "srt_time": p['srt_time']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  })
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  return {"status": "success", "results": final_output}
323
-
324
- except Exception as e:
325
- logger.error(f"Server Error: {e}")
326
- raise HTTPException(status_code=500, detail=str(e))
327
- finally:
328
- shutil.rmtree(temp_dir)
329
 
330
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
2
  import io
3
  import re
4
  import json
5
+ import uuid
6
  import shutil
7
  import logging
8
  import base64
9
  from concurrent.futures import ThreadPoolExecutor
10
+ from PIL import Image
11
 
12
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
13
  from fastapi.staticfiles import StaticFiles
 
15
  import rarfile
16
  import zipfile
17
 
 
18
  from google import genai
19
  from google.genai import types
20
 
 
32
  allow_headers=["*"],
33
  )
34
 
35
+ # Persistent storage directory
36
+ TASKS_DIR = "data_tasks"
37
+ os.makedirs(TASKS_DIR, exist_ok=True)
38
+
39
  # --- Utility Functions ---
40
 
41
  def parse_srt_time_to_ms(time_str):
 
55
  return (h * 3600000) + (m * 60000) + (s * 1000) + ms
56
 
57
def parse_srt(content: str):
    """
    Robust regex-based SRT parser.

    Searches for `ID / timestamp / text` patterns instead of splitting on
    blank lines, so it tolerates blank subtitles and inconsistent newlines.

    Returns a list of dicts: {"id", "time", "startTimeMs", "text"},
    where "text" is "" for blank subtitles.
    """
    # Normalize line endings
    content = content.replace('\r\n', '\n').replace('\r', '\n')

    # Pattern explanation:
    #   (\d+)              -> Group 1: block ID
    #   \n                 -> Newline
    #   (...-->...)        -> Group 2: timestamp line. Hours accepted as 1-2
    #                         digits (some files omit the leading zero) and
    #                         the millisecond separator may be ',' or '.'
    #   (.*?)              -> Group 3: subtitle text (non-greedy, may be empty)
    #   (?=\n\d+\n|\Z)     -> Lookahead: stop at the next ID or end of file
    pattern = re.compile(
        r'(\d+)\n'
        r'(\d{1,2}:\d{2}:\d{2}[,.]\d{3}\s*-->\s*\d{1,2}:\d{2}:\d{2}[,.]\d{3})\n'
        r'(.*?)(?=\n\d+\n|\Z)',
        re.DOTALL
    )

    parsed = []
    for raw_id, raw_time, raw_text in pattern.findall(content):
        srt_id = raw_id.strip()
        time_range = raw_time.strip()
        text = raw_text.strip()  # correctly "" if the subtitle is blank

        try:
            # Extract the start time so callers can sort chronologically
            start_time_str = time_range.split('-->')[0].strip()
            start_ms = parse_srt_time_to_ms(start_time_str)

            parsed.append({
                "id": srt_id,
                "time": time_range,
                "startTimeMs": start_ms,
                "text": text
            })
        except Exception as e:
            logger.warning(f"Error parsing block {srt_id}: {e}")

    return parsed
98
 
 
99
def compress_image(image_bytes, max_width=800, quality=80):
    """
    Downscale an image to fit within `max_width` and re-encode it as WebP.

    Returns the compressed bytes, or None if decoding/encoding fails.
    """
    try:
        picture = Image.open(io.BytesIO(image_bytes))
        # thumbnail() resizes in place and never upscales
        picture.thumbnail((max_width, max_width), Image.Resampling.LANCZOS)
        out = io.BytesIO()
        # method=6 is the slowest/strongest WebP encoder setting
        picture.save(out, format="WEBP", quality=quality, method=6)
    except Exception as e:
        logger.error(f"Compression error: {e}")
        return None
    else:
        return out.getvalue()
109
 
 
110
def process_batch_gemini(api_key, items, model_name):
    """
    Send one batch of (image, expected-subtitle) pairs to Gemini for OCR/QC.

    Each item dict must provide 'index', 'expected_text', and 'image_data'
    (compressed image bytes).  Returns the parsed JSON list of verdicts,
    or None on any failure (API error, unrepairable response).
    """
    try:
        client = genai.Client(api_key=api_key)
        prompt_parts = [
            "You are a Subtitle Quality Control (QC) bot.",
            f"I will provide {len(items)} images and the EXPECTED subtitle text for each.",
            "Return a JSON array strictly: "
            '[{"index": <int>, "detected_text": "<string>", "match": <bool>, "reason": "<string>"}, ...]',
            "Return ONLY the JSON. No markdown."
        ]

        for item in items:
            # Handle empty expected text explicitly for the AI
            exp_text = item['expected_text'] if item['expected_text'].strip() else "[BLANK/EMPTY]"
            prompt_parts.append(f"\n--- Item {item['index']} ---")
            prompt_parts.append(f"Expected Text: \"{exp_text}\"")
            prompt_parts.append(Image.open(io.BytesIO(item['image_data'])))

        response = client.models.generate_content(
            model=model_name,
            contents=prompt_parts,
            config=types.GenerateContentConfig(response_mime_type="application/json")
        )

        text = response.text.replace("```json", "").replace("```", "").strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError as e:
            # The model can exceed its output-token limit and truncate the JSON.
            # Repair strategy: keep everything up to the last complete object
            # ('}'), discard the partial tail, and close the array.
            logger.warning(f"JSON Parse Error (likely truncated response): {e}. Attempting repair...")
            last_object_idx = text.rfind("}")
            if last_object_idx != -1:
                try:
                    repaired = json.loads(text[:last_object_idx + 1] + "]")
                    logger.info(f"Successfully repaired JSON. Recovered {len(repaired)}/{len(items)} items.")
                    return repaired
                except json.JSONDecodeError:
                    logger.error("JSON repair failed.")
            return None  # fail gracefully if repair is impossible
    except Exception as e:
        # Log only the key suffix so full keys never land in the logs
        logger.error(f"Gemini API Error with key ...{api_key[-4:]}: {e}")
        return None
139
 
140
+ # --- Endpoints ---
141
 
142
  @app.post("/api/analyze")
143
  async def analyze_subtitles(
 
145
  media_files: list[UploadFile] = File(...),
146
  api_keys: str = Form(...),
147
  batch_size: int = Form(20),
148
+ model_name: str = Form("gemini-2.0-flash"),
149
  compression_quality: float = Form(0.7)
150
  ):
151
+ task_id = str(uuid.uuid4())
152
+ task_dir = os.path.join(TASKS_DIR, task_id)
153
+ os.makedirs(task_dir, exist_ok=True)
154
+
155
+ should_cleanup = False
156
+
157
  try:
 
158
  pil_quality = max(10, min(100, int(compression_quality * 100)))
159
 
160
+ # 1. Save and Parse SRT
161
+ srt_path = os.path.join(task_dir, "input.srt")
162
+ srt_bytes = await srt_file.read()
163
+ with open(srt_path, "wb") as f:
164
+ f.write(srt_bytes)
165
+
166
+ srt_data = parse_srt(srt_bytes.decode('utf-8', errors='ignore'))
167
  srt_data.sort(key=lambda x: x['startTimeMs'])
168
 
169
+ # 2. Extract Media
 
170
  for file in media_files:
171
+ file_path = os.path.join(task_dir, file.filename)
172
  with open(file_path, "wb") as f:
173
  shutil.copyfileobj(file.file, f)
174
 
175
  if file.filename.lower().endswith('.rar'):
176
+ with rarfile.RarFile(file_path) as rf:
177
+ rf.extractall(task_dir)
 
 
 
178
  elif file.filename.lower().endswith('.zip'):
179
  with zipfile.ZipFile(file_path, 'r') as zf:
180
+ zf.extractall(task_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ # 3. Pair and Process (shared logic)
183
+ return await run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id)
 
184
 
185
+ except Exception as e:
186
+ logger.error(f"Server Error: {e}")
187
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
188
 
189
+ @app.post("/api/retry")
190
+ async def retry_analysis(
191
+ task_id: str = Form(...),
192
+ api_keys: str = Form(...),
193
+ batch_size: int = Form(20),
194
+ model_name: str = Form("gemini-2.0-flash"),
195
+ compression_quality: float = Form(0.7)
196
+ ):
197
+ task_dir = os.path.join(TASKS_DIR, task_id)
198
+ if not os.path.exists(task_dir):
199
+ raise HTTPException(status_code=404, detail="Task files not found.")
200
 
201
+ srt_path = os.path.join(task_dir, "input.srt")
202
+ with open(srt_path, "r", encoding="utf-8", errors="ignore") as f:
203
+ srt_data = parse_srt(f.read())
204
+
205
+ pil_quality = max(10, min(100, int(compression_quality * 100)))
206
+ return await run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id)
207
+
208
async def run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id):
    """
    Shared pipeline for /api/analyze and /api/retry.

    Walks task_dir for timestamp-named screenshots, pairs them positionally
    with the (chronologically sorted) SRT entries, fans batches out to Gemini
    across the provided API keys, and assembles per-subtitle verdicts.

    Deletes task_dir once no item is left pending; otherwise keeps it and
    returns "partial" with the task_id so the client can /api/retry.
    """
    # Collect timestamped screenshots from the (possibly nested) extraction dir
    images = []
    for root, _, files in os.walk(task_dir):
        for filename in files:
            if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', '.bmp')):
                ms = parse_filename_to_ms(filename)
                if ms is not None:
                    with open(os.path.join(root, filename), "rb") as f:
                        comp = compress_image(f.read(), quality=pil_quality)
                    if comp:
                        images.append({"filename": filename, "timeMs": ms, "data": comp})

    images.sort(key=lambda x: x['timeMs'])

    # Pair the i-th image with the i-th subtitle (both sorted chronologically)
    pairs = []
    for i, img in enumerate(images):
        srt = srt_data[i] if i < len(srt_data) else None
        if srt:
            thumb = compress_image(img['data'], quality=40, max_width=200)
            if thumb is None:
                # Skip rather than crash: b64encode(None) would raise TypeError
                continue
            pairs.append({
                "index": i, "image_data": img['data'], "expected_text": srt['text'],
                "srt_id": srt['id'], "srt_time": srt['time'], "filename": img['filename'],
                "thumb": f"data:image/webp;base64,{base64.b64encode(thumb).decode()}",
                "status": "pending"
            })

    if not pairs:
        return {"status": "error", "message": "No valid image/subtitle pairs found."}

    keys = [k.strip() for k in api_keys.split('\n') if k.strip()]
    if not keys:
        # ThreadPoolExecutor(max_workers=0) raises ValueError and
        # keys[i % 0] raises ZeroDivisionError: fail with a clear message
        raise HTTPException(status_code=400, detail="No API Keys provided")

    results_map = {}
    batches = [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]

    # One worker per key; batches are round-robined across the keys
    with ThreadPoolExecutor(max_workers=len(keys)) as executor:
        futures = [
            executor.submit(process_batch_gemini, keys[i % len(keys)], b, model_name)
            for i, b in enumerate(batches)
        ]
        for fut in futures:
            res = fut.result()
            if res:
                for item in res:
                    results_map[item['index']] = item

    final_output = []
    any_pending = False
    for p in pairs:
        res = results_map.get(p['index'])
        status = ("match" if res['match'] else "mismatch") if res else "pending"
        if status == "pending":
            any_pending = True
        final_output.append({
            "id": p['index'], "status": status, "expected": p['expected_text'],
            "detected": res.get('detected_text', '') if res else "",
            "reason": res.get('reason', '') if res else "",
            "thumb": p['thumb'], "filename": p['filename'],
            "srt_id": p['srt_id'], "srt_time": p['srt_time']
        })

    if not any_pending:
        # Everything analyzed: no retry needed, reclaim disk space
        shutil.rmtree(task_dir)
        return {"status": "success", "results": final_output}

    return {"status": "partial", "task_id": task_id, "results": final_output}
 
 
 
 
262
 
263
  app.mount("/", StaticFiles(directory="static", html=True), name="static")