Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
|
|
| 10 |
|
| 11 |
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 12 |
from fastapi.staticfiles import StaticFiles
|
| 13 |
-
from fastapi.responses import FileResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
from PIL import Image
|
| 16 |
import rarfile
|
|
@@ -23,7 +22,6 @@ logger = logging.getLogger(__name__)
|
|
| 23 |
|
| 24 |
app = FastAPI()
|
| 25 |
|
| 26 |
-
# 1. Enable CORS (good practice, though less critical if serving static from same origin)
|
| 27 |
app.add_middleware(
|
| 28 |
CORSMiddleware,
|
| 29 |
allow_origins=["*"],
|
|
@@ -32,7 +30,8 @@ app.add_middleware(
|
|
| 32 |
allow_headers=["*"],
|
| 33 |
)
|
| 34 |
|
| 35 |
-
#
|
|
|
|
| 36 |
def parse_srt_time_to_ms(time_str):
|
| 37 |
try:
|
| 38 |
if not time_str: return 0
|
|
@@ -80,16 +79,17 @@ def compress_image(image_bytes, quality=70, max_width=800):
|
|
| 80 |
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
| 81 |
|
| 82 |
buffer = io.BytesIO()
|
| 83 |
-
|
|
|
|
| 84 |
return buffer.getvalue()
|
| 85 |
except Exception as e:
|
| 86 |
logger.error(f"Image compression failed: {e}")
|
| 87 |
return None
|
| 88 |
|
| 89 |
-
def process_batch_gemini(api_key, items):
|
| 90 |
try:
|
| 91 |
genai.configure(api_key=api_key)
|
| 92 |
-
model = genai.GenerativeModel(
|
| 93 |
|
| 94 |
prompt_parts = [
|
| 95 |
"You are a Subtitle Quality Control (QC) bot.",
|
|
@@ -114,31 +114,35 @@ def process_batch_gemini(api_key, items):
|
|
| 114 |
logger.error(f"Gemini API Error with key ...{api_key[-4:]}: {e}")
|
| 115 |
return None
|
| 116 |
|
| 117 |
-
#
|
|
|
|
| 118 |
@app.post("/api/analyze")
|
| 119 |
async def analyze_subtitles(
|
| 120 |
srt_file: UploadFile = File(...),
|
| 121 |
media_files: list[UploadFile] = File(...),
|
| 122 |
api_keys: str = Form(...),
|
| 123 |
-
batch_size: int = Form(20)
|
|
|
|
|
|
|
| 124 |
):
|
| 125 |
temp_dir = tempfile.mkdtemp()
|
| 126 |
try:
|
| 127 |
-
#
|
|
|
|
|
|
|
|
|
|
| 128 |
srt_content = (await srt_file.read()).decode('utf-8', errors='ignore')
|
| 129 |
srt_data = parse_srt(srt_content)
|
| 130 |
srt_data.sort(key=lambda x: x['startTimeMs'])
|
| 131 |
|
| 132 |
-
# Process Media
|
| 133 |
images = []
|
| 134 |
-
|
| 135 |
for file in media_files:
|
| 136 |
file_path = os.path.join(temp_dir, file.filename)
|
| 137 |
with open(file_path, "wb") as f:
|
| 138 |
shutil.copyfileobj(file.file, f)
|
| 139 |
|
| 140 |
if file.filename.lower().endswith('.rar'):
|
| 141 |
-
# Docker guarantees 'unrar' is installed
|
| 142 |
try:
|
| 143 |
with rarfile.RarFile(file_path) as rf:
|
| 144 |
rf.extractall(temp_dir)
|
|
@@ -156,7 +160,7 @@ async def analyze_subtitles(
|
|
| 156 |
if ms is not None:
|
| 157 |
with open(full_path, "rb") as f:
|
| 158 |
raw_bytes = f.read()
|
| 159 |
-
compressed = compress_image(raw_bytes)
|
| 160 |
if compressed:
|
| 161 |
images.append({
|
| 162 |
"filename": filename,
|
|
@@ -166,12 +170,14 @@ async def analyze_subtitles(
|
|
| 166 |
|
| 167 |
images.sort(key=lambda x: x['timeMs'])
|
| 168 |
|
|
|
|
| 169 |
pairs = []
|
| 170 |
for i in range(len(images)):
|
| 171 |
img = images[i]
|
| 172 |
srt = srt_data[i] if i < len(srt_data) else None
|
| 173 |
|
| 174 |
if srt:
|
|
|
|
| 175 |
thumb_bytes = compress_image(img['data'], quality=50, max_width=300)
|
| 176 |
thumb_b64 = base64.b64encode(thumb_bytes).decode('utf-8')
|
| 177 |
|
|
@@ -189,7 +195,7 @@ async def analyze_subtitles(
|
|
| 189 |
if not pairs:
|
| 190 |
return {"status": "error", "message": "No valid image/subtitle pairs found."}
|
| 191 |
|
| 192 |
-
# Process Gemini
|
| 193 |
keys = [k.strip() for k in api_keys.split('\n') if k.strip()]
|
| 194 |
if not keys:
|
| 195 |
raise HTTPException(status_code=400, detail="No API Keys provided")
|
|
@@ -199,7 +205,7 @@ async def analyze_subtitles(
|
|
| 199 |
|
| 200 |
def worker(batch_idx, batch):
|
| 201 |
key = keys[batch_idx % len(keys)]
|
| 202 |
-
return process_batch_gemini(key, batch)
|
| 203 |
|
| 204 |
with ThreadPoolExecutor(max_workers=len(keys)) as executor:
|
| 205 |
futures = [executor.submit(worker, i, b) for i, b in enumerate(batches)]
|
|
@@ -209,6 +215,7 @@ async def analyze_subtitles(
|
|
| 209 |
for item in res:
|
| 210 |
results_map[item['index']] = item
|
| 211 |
|
|
|
|
| 212 |
final_output = []
|
| 213 |
for p in pairs:
|
| 214 |
analysis = results_map.get(p['index'])
|
|
@@ -240,5 +247,4 @@ async def analyze_subtitles(
|
|
| 240 |
finally:
|
| 241 |
shutil.rmtree(temp_dir)
|
| 242 |
|
| 243 |
-
# 4. Serve Static Files (Frontend)
|
| 244 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
|
| 10 |
|
| 11 |
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 12 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from PIL import Image
|
| 15 |
import rarfile
|
|
|
|
| 22 |
|
| 23 |
app = FastAPI()
|
| 24 |
|
|
|
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
| 27 |
allow_origins=["*"],
|
|
|
|
| 30 |
allow_headers=["*"],
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# --- Utility Functions ---
|
| 34 |
+
|
| 35 |
def parse_srt_time_to_ms(time_str):
|
| 36 |
try:
|
| 37 |
if not time_str: return 0
|
|
|
|
| 79 |
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
| 80 |
|
| 81 |
buffer = io.BytesIO()
|
| 82 |
+
# Pillow's JPEG "quality" must be an int; the useful range is 1-95 (values above 95 bloat the file with little visual gain)
|
| 83 |
+
img.save(buffer, format="JPEG", quality=int(quality))
|
| 84 |
return buffer.getvalue()
|
| 85 |
except Exception as e:
|
| 86 |
logger.error(f"Image compression failed: {e}")
|
| 87 |
return None
|
| 88 |
|
| 89 |
+
def process_batch_gemini(api_key, items, model_name):
|
| 90 |
try:
|
| 91 |
genai.configure(api_key=api_key)
|
| 92 |
+
model = genai.GenerativeModel(model_name)
|
| 93 |
|
| 94 |
prompt_parts = [
|
| 95 |
"You are a Subtitle Quality Control (QC) bot.",
|
|
|
|
| 114 |
logger.error(f"Gemini API Error with key ...{api_key[-4:]}: {e}")
|
| 115 |
return None
|
| 116 |
|
| 117 |
+
# --- Main Endpoint ---
|
| 118 |
+
|
| 119 |
@app.post("/api/analyze")
|
| 120 |
async def analyze_subtitles(
|
| 121 |
srt_file: UploadFile = File(...),
|
| 122 |
media_files: list[UploadFile] = File(...),
|
| 123 |
api_keys: str = Form(...),
|
| 124 |
+
batch_size: int = Form(20),
|
| 125 |
+
model_name: str = Form("gemini-2.0-flash"),
|
| 126 |
+
compression_quality: float = Form(0.7)
|
| 127 |
):
|
| 128 |
temp_dir = tempfile.mkdtemp()
|
| 129 |
try:
|
| 130 |
+
# Map the fractional quality from the form (expected 0.1-1.0) to a whole-number percentage, clamped to 10-100, for PIL
|
| 131 |
+
pil_quality = max(10, min(100, int(compression_quality * 100)))
|
| 132 |
+
|
| 133 |
+
# 1. Read SRT
|
| 134 |
srt_content = (await srt_file.read()).decode('utf-8', errors='ignore')
|
| 135 |
srt_data = parse_srt(srt_content)
|
| 136 |
srt_data.sort(key=lambda x: x['startTimeMs'])
|
| 137 |
|
| 138 |
+
# 2. Process Media
|
| 139 |
images = []
|
|
|
|
| 140 |
for file in media_files:
|
| 141 |
file_path = os.path.join(temp_dir, file.filename)
|
| 142 |
with open(file_path, "wb") as f:
|
| 143 |
shutil.copyfileobj(file.file, f)
|
| 144 |
|
| 145 |
if file.filename.lower().endswith('.rar'):
|
|
|
|
| 146 |
try:
|
| 147 |
with rarfile.RarFile(file_path) as rf:
|
| 148 |
rf.extractall(temp_dir)
|
|
|
|
| 160 |
if ms is not None:
|
| 161 |
with open(full_path, "rb") as f:
|
| 162 |
raw_bytes = f.read()
|
| 163 |
+
compressed = compress_image(raw_bytes, quality=pil_quality)
|
| 164 |
if compressed:
|
| 165 |
images.append({
|
| 166 |
"filename": filename,
|
|
|
|
| 170 |
|
| 171 |
images.sort(key=lambda x: x['timeMs'])
|
| 172 |
|
| 173 |
+
# 3. Pair
|
| 174 |
pairs = []
|
| 175 |
for i in range(len(images)):
|
| 176 |
img = images[i]
|
| 177 |
srt = srt_data[i] if i < len(srt_data) else None
|
| 178 |
|
| 179 |
if srt:
|
| 180 |
+
# Create Thumbnail (lower quality for UI speed)
|
| 181 |
thumb_bytes = compress_image(img['data'], quality=50, max_width=300)
|
| 182 |
thumb_b64 = base64.b64encode(thumb_bytes).decode('utf-8')
|
| 183 |
|
|
|
|
| 195 |
if not pairs:
|
| 196 |
return {"status": "error", "message": "No valid image/subtitle pairs found."}
|
| 197 |
|
| 198 |
+
# 4. Process Gemini
|
| 199 |
keys = [k.strip() for k in api_keys.split('\n') if k.strip()]
|
| 200 |
if not keys:
|
| 201 |
raise HTTPException(status_code=400, detail="No API Keys provided")
|
|
|
|
| 205 |
|
| 206 |
def worker(batch_idx, batch):
|
| 207 |
key = keys[batch_idx % len(keys)]
|
| 208 |
+
return process_batch_gemini(key, batch, model_name)
|
| 209 |
|
| 210 |
with ThreadPoolExecutor(max_workers=len(keys)) as executor:
|
| 211 |
futures = [executor.submit(worker, i, b) for i, b in enumerate(batches)]
|
|
|
|
| 215 |
for item in res:
|
| 216 |
results_map[item['index']] = item
|
| 217 |
|
| 218 |
+
# 5. Build Output
|
| 219 |
final_output = []
|
| 220 |
for p in pairs:
|
| 221 |
analysis = results_map.get(p['index'])
|
|
|
|
| 247 |
finally:
|
| 248 |
shutil.rmtree(temp_dir)
|
| 249 |
|
|
|
|
| 250 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|