Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import re | |
| import json | |
| import uuid | |
| import shutil | |
| import logging | |
| import base64 | |
| from concurrent.futures import ThreadPoolExecutor | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import rarfile | |
| import zipfile | |
| from google import genai | |
| from google.genai import types | |
# Root logging at INFO and above (basicConfig is a no-op if the root logger
# already has handlers); the module logger below is used throughout the file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
app = FastAPI()

# Wide-open CORS policy: every origin, method and header is allowed.
# NOTE(review): allow_credentials=True together with a wildcard origin lets any
# site issue credentialed requests to this API; confirm that is intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Working area for analysis tasks: each task gets a sub-folder named after its
# task_id, which is kept on disk after a "partial" result so it can be retried.
TASKS_DIR = "data_tasks"
os.makedirs(name=TASKS_DIR, exist_ok=True)
| # --- Utility Functions --- | |
def parse_srt_time_to_ms(time_str):
    """Convert an SRT timestamp such as ``"01:02:03,456"`` to milliseconds.

    Both ``,`` (SRT) and ``.`` are accepted as the fractional separator. The
    fractional part is read as a decimal fraction of a second, so
    ``"00:00:01,5"`` is 1500 ms, and a missing fractional part means 0 ms.

    Args:
        time_str: The ``HH:MM:SS[,mmm]`` timestamp text.

    Returns:
        The timestamp in milliseconds, or 0 if ``time_str`` is empty/``None``
        or cannot be parsed.
    """
    if not time_str:
        return 0
    try:
        clock, _, frac = time_str.strip().replace(',', '.').partition('.')
        hours, minutes, seconds = map(int, clock.split(':'))
        if frac and not frac.isdecimal():
            return 0
        # Right-pad so ".5" -> 500 ms; digits finer than 1 ms are truncated.
        millis = int(frac[:3].ljust(3, '0')) if frac else 0
    except (AttributeError, TypeError, ValueError):
        # Malformed input keeps the historical "0 on failure" contract.
        return 0
    return hours * 3_600_000 + minutes * 60_000 + seconds * 1_000 + millis
def parse_filename_to_ms(filename):
    """Read an ``H_MM_SS_mmm`` timestamp embedded anywhere in *filename*.

    Returns:
        The timestamp in milliseconds, or ``None`` when the name contains no
        such pattern.
    """
    found = re.search(r'(\d{1,2})_(\d{2})_(\d{2})_(\d{3})', filename)
    if found is None:
        return None
    hours, minutes, seconds, millis = (int(part) for part in found.groups())
    return ((hours * 60 + minutes) * 60 + seconds) * 1000 + millis
| # In app.py | |
def parse_srt(content: str):
    """Parse SRT text into one dict per recognised cue header.

    Headers (an index line followed by a ``start --> end`` timestamp line)
    are located first; each cue's text is whatever lies between its header
    and the next header (or the end of input), stripped. A cue with no text
    still yields an item with ``text == ""``, so the number of items always
    equals the number of recognised headers.

    Args:
        content: Full SRT contents. CRLF and lone CR line endings are
            normalised to LF before parsing.

    Returns:
        A list, in file order, of dicts with keys ``id`` (index line as a
        string), ``time`` (raw ``start --> end`` text), ``startTimeMs``
        (start time via ``parse_srt_time_to_ms``) and ``text``.
    """
    content = content.replace('\r\n', '\n').replace('\r', '\n')

    # Only the headers are matched; the text is sliced out between them.
    # [ \t]* tolerates trailing whitespace after the index, which would
    # otherwise make the header unrecognised and merge it into the previous cue.
    header_pattern = re.compile(
        r'(?P<id>\d+)[ \t]*\n'
        r'(?P<time>(?P<start>\d{2}:\d{2}:\d{2}[,.]\d{3})\s*-->\s*\d{2}:\d{2}:\d{2}[,.]\d{3})'
    )
    matches = list(header_pattern.finditer(content))

    parsed = []
    for i, match in enumerate(matches):
        # A cue's text ends where the next header begins (or at EOF).
        end_index = matches[i + 1].start() if i + 1 < len(matches) else len(content)
        parsed.append({
            "id": match.group('id'),
            "time": match.group('time'),
            "startTimeMs": parse_srt_time_to_ms(match.group('start')),
            "text": content[match.end():end_index].strip(),
        })
    return parsed
def compress_image(image_bytes, max_width=800, quality=80):
    """Downscale an encoded image and re-encode it as WebP.

    ``Image.thumbnail`` shrinks the image in place, preserving aspect ratio,
    to fit inside a ``max_width`` x ``max_width`` box. The bound therefore
    applies to the height as well as the width.

    Args:
        image_bytes: Encoded image data in any format Pillow can open.
        max_width: Maximum width *and* height of the result, in pixels.
        quality: WebP quality passed to Pillow's encoder.

    Returns:
        The WebP-encoded bytes, or ``None`` if decoding or encoding failed
        (the error is logged).
    """
    try:
        # Context manager releases the decoder as soon as encoding is done.
        with Image.open(io.BytesIO(image_bytes)) as img:
            img.thumbnail((max_width, max_width), Image.Resampling.LANCZOS)
            buffer = io.BytesIO()
            img.save(buffer, format="WEBP", quality=quality, method=6)
        return buffer.getvalue()
    except Exception as e:
        # Deliberate best-effort: callers treat None as "skip this image".
        logger.error("Compression error: %s", e)
        return None
def process_batch_gemini(api_key, items, model_name):
    """Ask Gemini to compare a batch of screenshots with their expected subtitles.

    Args:
        api_key: Gemini API key used for this request.
        items: Dicts with at least ``index``, ``expected_text`` (str) and
            ``image_data`` (encoded image bytes).
        model_name: Gemini model identifier for ``generate_content``.

    Returns:
        The model's JSON array, keeping only entries that are dicts with both
        an ``index`` and a ``match`` key (other keys requested in the prompt:
        ``detected_text``, ``reason``). Returns ``None`` if the request or JSON
        parsing failed, or the reply was not a JSON array; the error is logged.
    """
    opened_images = []
    try:
        client = genai.Client(api_key=api_key)
        prompt_parts = [
            "You are a Subtitle Quality Control (QC) bot.",
            f"I will provide {len(items)} images and the EXPECTED subtitle text for each.",
            "Return a JSON array strictly: "
            '[{"index": <int>, "detected_text": "<string>", "match": <bool>, "reason": "<string>"}, ...]',
            "Return ONLY the JSON. No markdown."
        ]
        for item in items:
            # Give the model an explicit marker instead of an empty quoted string.
            exp_text = item['expected_text'] if item['expected_text'].strip() else "[BLANK/EMPTY]"
            prompt_parts.append(f"\n--- Item {item['index']} ---")
            prompt_parts.append(f"Expected Text: \"{exp_text}\"")
            image = Image.open(io.BytesIO(item['image_data']))
            opened_images.append(image)
            prompt_parts.append(image)
        response = client.models.generate_content(
            model=model_name,
            contents=prompt_parts,
            config=types.GenerateContentConfig(response_mime_type="application/json")
        )
        # response.text may be None (no text parts); treat it as an empty reply.
        # Markdown fences are stripped in case the model adds them anyway.
        text = (response.text or "").replace("```json", "").replace("```", "").strip()
        parsed = json.loads(text)
    except Exception as e:
        logger.error(f"Gemini API Error: {e}")
        return None
    finally:
        # The request is finished (or failed); release the decoded images.
        for image in opened_images:
            image.close()

    if not isinstance(parsed, list):
        logger.error(f"Gemini API Error: expected a JSON array, got {type(parsed).__name__}")
        return None
    # Drop malformed entries so callers can index 'index'/'match' safely;
    # the affected pairs simply stay unanswered.
    return [entry for entry in parsed if isinstance(entry, dict) and 'index' in entry and 'match' in entry]
| # --- Endpoints --- | |
async def analyze_subtitles(
    srt_file: UploadFile = File(...),
    media_files: list[UploadFile] = File(...),
    api_keys: str = Form(...),
    batch_size: int = Form(20),
    model_name: str = Form("gemini-2.0-flash"),
    compression_quality: float = Form(0.7)
):
    """Start a new QC task from an uploaded SRT file and screenshot uploads.

    NOTE(review): no route decorator (e.g. ``@app.post``) is visible for this
    function here; confirm how it is registered on ``app``.

    The SRT and every media upload are stored under a fresh
    ``TASKS_DIR/<task_id>`` directory. ``.rar`` and ``.zip`` uploads are
    extracted there. Cues are sorted by start time and passed to
    ``run_core_analysis``, whose result is returned.

    Args:
        srt_file: Subtitle file, decoded as UTF-8 with undecodable bytes dropped.
        media_files: Screenshots and/or ``.zip`` / ``.rar`` archives of them.
        api_keys: Newline-separated Gemini API keys.
        batch_size: Number of image/cue pairs per Gemini request.
        model_name: Gemini model identifier.
        compression_quality: Factor mapped to a Pillow quality clamped to 10-100.

    Raises:
        HTTPException: Raised during processing and re-raised unchanged. Any
            other error becomes a 500 whose detail is the error message. In
            both cases the task directory is removed, because the client never
            receives its ``task_id`` and so cannot retry it.
    """
    task_id = str(uuid.uuid4())
    task_dir = os.path.join(TASKS_DIR, task_id)
    os.makedirs(task_dir, exist_ok=True)
    try:
        pil_quality = max(10, min(100, int(compression_quality * 100)))

        # 1. Save and parse the SRT (the saved copy is what retry_analysis reads).
        srt_path = os.path.join(task_dir, "input.srt")
        srt_bytes = await srt_file.read()
        with open(srt_path, "wb") as f:
            f.write(srt_bytes)
        srt_data = parse_srt(srt_bytes.decode('utf-8', errors='ignore'))
        srt_data.sort(key=lambda x: x['startTimeMs'])

        # 2. Store media uploads and unpack archives.
        for file in media_files:
            # Client-supplied name: keep only the final path component so that
            # "../x" or "/abs/x" cannot write outside task_dir.
            safe_name = os.path.basename(file.filename or "")
            if safe_name in ("", ".", ".."):
                logger.warning("Skipping upload with unusable filename: %r", file.filename)
                continue
            file_path = os.path.join(task_dir, safe_name)
            with open(file_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            lower_name = safe_name.lower()
            if lower_name.endswith('.rar'):
                # NOTE(review): confirm the installed rarfile version refuses
                # archive members whose paths escape task_dir.
                with rarfile.RarFile(file_path) as rf:
                    rf.extractall(task_dir)
            elif lower_name.endswith('.zip'):
                with zipfile.ZipFile(file_path, 'r') as zf:
                    zf.extractall(task_dir)

        # 3. Pair images with cues and run the Gemini check (shared logic).
        return await run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id)
    except HTTPException:
        shutil.rmtree(task_dir, ignore_errors=True)
        raise
    except Exception as e:
        logger.exception("Server Error: %s", e)
        shutil.rmtree(task_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def retry_analysis(
    task_id: str = Form(...),
    api_keys: str = Form(...),
    batch_size: int = Form(20),
    model_name: str = Form("gemini-2.0-flash"),
    compression_quality: float = Form(0.7)
):
    """Re-run the analysis for a task whose files are still on disk.

    NOTE(review): no route decorator is visible for this function here;
    confirm how it is registered on ``app``.

    Re-reads ``TASKS_DIR/<task_id>/input.srt``, sorts the cues by start time
    (the same ordering ``analyze_subtitles`` uses) and calls
    ``run_core_analysis`` on the stored images.

    Raises:
        HTTPException: 404 if ``task_id`` is not a UUID, or if the task
            directory or its saved SRT is missing.
    """
    not_found_detail = "Task files not found."
    # task_id is client-supplied, and run_core_analysis may rmtree the
    # directory, so only accept a real UUID; it can never leave TASKS_DIR.
    try:
        task_id = str(uuid.UUID(task_id))
    except ValueError:
        raise HTTPException(status_code=404, detail=not_found_detail) from None
    task_dir = os.path.join(TASKS_DIR, task_id)
    if not os.path.isdir(task_dir):
        raise HTTPException(status_code=404, detail=not_found_detail)

    srt_path = os.path.join(task_dir, "input.srt")
    try:
        with open(srt_path, "r", encoding="utf-8", errors="ignore") as f:
            srt_data = parse_srt(f.read())
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=not_found_detail) from None
    # Same ordering as the initial run, so cue i pairs with the same image.
    srt_data.sort(key=lambda x: x['startTimeMs'])

    pil_quality = max(10, min(100, int(compression_quality * 100)))
    return await run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id)
async def run_core_analysis(task_dir, srt_data, api_keys, batch_size, model_name, pil_quality, task_id):
    """Pair task screenshots with SRT cues, QC them with Gemini, and build the response.

    Images anywhere under ``task_dir`` whose filename carries an
    ``H_MM_SS_mmm`` timestamp are ordered by that timestamp. They are then
    paired one-to-one, in order, with ``srt_data``; surplus images or cues are
    ignored. Pairs are sent to ``process_batch_gemini`` in batches of
    ``batch_size``, with batch ``i`` using key ``i % len(keys)``.

    Args:
        task_dir: Directory holding the task's images (searched recursively).
        srt_data: Parsed cues (see ``parse_srt``), already in pairing order.
        api_keys: Newline-separated Gemini API keys.
        batch_size: Pairs per Gemini request; values below 1 are treated as 1.
        model_name: Gemini model identifier.
        pil_quality: WebP quality used when compressing the screenshots.
        task_id: Returned to the client when a retry is needed.

    Returns:
        ``{"status": "success", "results": [...]}`` if every pair got a verdict
        (``task_dir`` is then deleted). Otherwise
        ``{"status": "partial", "task_id": task_id, "results": [...]}``.

    Raises:
        HTTPException: 400 if ``api_keys`` contains no key.
    """
    import asyncio

    keys = [k.strip() for k in api_keys.split('\n') if k.strip()]
    if not keys:
        # ThreadPoolExecutor(max_workers=0) would otherwise raise ValueError.
        raise HTTPException(status_code=400, detail="No API keys provided.")
    # A non-positive step would make range() raise.
    batch_size = max(1, batch_size)
    # Disk I/O, image compression and the Gemini calls all block; run them in
    # a worker thread so the event loop keeps serving other requests.
    return await asyncio.to_thread(
        _run_core_analysis_blocking,
        task_dir, srt_data, keys, batch_size, model_name, pil_quality, task_id,
    )


def _load_task_images(task_dir, limit, pil_quality):
    """Return up to ``limit`` compressed screenshots from ``task_dir``, ordered by filename timestamp."""
    candidates = []
    for root, _, files in os.walk(task_dir):
        for filename in files:
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', '.bmp')):
                continue
            ms = parse_filename_to_ms(filename)
            if ms is not None:
                candidates.append((ms, filename, os.path.join(root, filename)))
    # Stable sort on the timestamp only: ties keep os.walk order.
    candidates.sort(key=lambda candidate: candidate[0])

    images = []
    for ms, filename, path in candidates:
        if len(images) >= limit:
            break  # no cue left to pair with; skip the remaining compression work
        with open(path, "rb") as f:
            data = compress_image(f.read(), quality=pil_quality)
        # An image that fails to compress is skipped; the next one takes its slot.
        if data:
            images.append({"filename": filename, "timeMs": ms, "data": data})
    return images


def _build_pairs(images, srt_data):
    """Zip images with cues, in order, into the per-pair dicts sent to Gemini."""
    pairs = []
    for index, (img, cue) in enumerate(zip(images, srt_data)):
        # Fall back to the full-size WebP rather than crash if thumbnailing fails.
        thumb = compress_image(img['data'], quality=40, max_width=200) or img['data']
        pairs.append({
            "index": index, "image_data": img['data'], "expected_text": cue['text'],
            "srt_id": cue['id'], "srt_time": cue['time'], "filename": img['filename'],
            "thumb": f"data:image/webp;base64,{base64.b64encode(thumb).decode()}",
            "status": "pending",
        })
    return pairs


def _query_gemini_batches(pairs, keys, batch_size, model_name):
    """Send ``pairs`` to Gemini in concurrent batches and map pair index -> verdict dict.

    Entries without a usable ``index`` or without ``match`` are ignored, so
    those pairs stay "pending" instead of failing the whole request.
    """
    results_map = {}
    batches = [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]
    if not batches:
        return results_map
    with ThreadPoolExecutor(max_workers=min(len(keys), len(batches))) as executor:
        futures = [
            executor.submit(process_batch_gemini, keys[i % len(keys)], batch, model_name)
            for i, batch in enumerate(batches)
        ]
        for future in futures:
            batch_result = future.result()
            if not isinstance(batch_result, list):
                continue  # None signals a failed batch
            for item in batch_result:
                if not isinstance(item, dict) or 'match' not in item:
                    continue
                try:
                    results_map[int(item['index'])] = item
                except (KeyError, TypeError, ValueError):
                    continue
    return results_map


def _run_core_analysis_blocking(task_dir, srt_data, keys, batch_size, model_name, pil_quality, task_id):
    """Synchronous body of ``run_core_analysis``; ``keys`` is the parsed, non-empty key list."""
    images = _load_task_images(task_dir, len(srt_data), pil_quality)
    pairs = _build_pairs(images, srt_data)
    results_map = _query_gemini_batches(pairs, keys, batch_size, model_name)

    final_output = []
    for pair in pairs:
        verdict = results_map.get(pair['index'])
        if verdict is None:
            status, detected, reason = "pending", "", ""
        else:
            status = "match" if verdict['match'] else "mismatch"
            detected = verdict.get('detected_text', '')
            reason = verdict.get('reason', '')
        final_output.append({
            "id": pair['index'], "status": status, "expected": pair['expected_text'],
            "detected": detected, "reason": reason,
            "thumb": pair['thumb'], "filename": pair['filename'], "srt_id": pair['srt_id'],
        })

    if all(entry['status'] != "pending" for entry in final_output):
        # Task complete: its files are no longer needed. A cleanup failure
        # must not discard the results that were just computed.
        shutil.rmtree(task_dir, ignore_errors=True)
        return {"status": "success", "results": final_output}
    return {"status": "partial", "task_id": task_id, "results": final_output}
# Serve the ./static directory at the site root; html=True makes directory
# requests (such as "/") return that directory's index.html.
app.mount(
    "/",
    StaticFiles(html=True, directory="static"),
    name="static",
)