Spaces:
Sleeping
Sleeping
File size: 21,562 Bytes
bab49e8 6f37f73 bab49e8 6f37f73 bab49e8 cd3ce32 bab49e8 6f37f73 bab49e8 6f37f73 bab49e8 6f37f73 bab49e8 6f37f73 bab49e8 0eb1322 cd3ce32 0eb1322 bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 4692a3e bab49e8 0eb1322 bab49e8 4692a3e 0eb1322 bab49e8 0eb1322 bab49e8 4692a3e bab49e8 4692a3e bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 958f060 bab49e8 958f060 bab49e8 958f060 bab49e8 0eb1322 bab49e8 958f060 bab49e8 958f060 bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 0eb1322 bab49e8 4692a3e 0eb1322 bab49e8 0eb1322 bab49e8 958f060 bab49e8 1604b13 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 
387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 | import gradio as gr
import html
import json
import os
import random
import re
import shutil
import sqlite3
import tempfile
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import genanki
import pandas as pd
import scipy.io.wavfile
import torch
from pocket_tts import TTSModel
from pydub import AudioSegment
# --- Configuration ---
# Worker threads used for audio generation; kept small because HF Spaces
# hosts are CPU/RAM constrained.
MAX_WORKERS: int = 4
# Maximum number of rows shown in the preview table (UI safety cap).
PREVIEW_LIMIT: int = 100
# Minimum number of seconds between progress-bar refreshes.
PROGRESS_THROTTLE: float = 1.0
# --- Helpers ---
def clean_text_for_tts(text):
    """Return *text* reduced to plain prose suitable for the TTS engine.

    Strips HTML tags, Anki ``[sound:...]`` references and ``{{...}}``
    template placeholders. Decodes HTML entities (``&amp;`` -> ``&``,
    ``&nbsp;`` -> space) and collapses runs of whitespace.

    Tags are replaced by a space rather than deleted, so markup such as
    ``one<br>two`` is spoken as two words instead of ``onetwo``.

    Args:
        text: Raw field content; ``None``/NaN are treated as empty.

    Returns:
        The cleaned string, or ``""`` when nothing speakable remains.
    """
    if pd.isna(text):
        return ""
    text = str(text)
    # Replace tags (including ones that span lines) with a space so adjacent words stay apart.
    text = re.sub(r'<[^>]*>', ' ', text)
    # Drop existing Anki audio references; they are not meant to be spoken.
    text = re.sub(r'\[sound:.*?\]', '', text)
    # Drop mustache template placeholders.
    text = re.sub(r'\{\{.*?\}\}', '', text)
    # Decode entities only after tag removal, so an escaped "&lt;b&gt;" is kept as text.
    text = html.unescape(text)
    # Collapse whitespace, including the non-breaking spaces produced by &nbsp;.
    return re.sub(r'\s+', ' ', text).strip()
def has_existing_audio(text):
    """Tell whether *text* already carries an Anki ``[sound:...]`` reference.

    ``None``/NaN values never count as having audio; anything else is
    converted with ``str()`` before searching.
    """
    if pd.isna(text):
        return False
    match = re.search(r'\[sound:.*?\]', str(text))
    return match is not None
# Load the TTS model once at import time. A failure is logged and leaves
# TTS_MODEL as None, so the app still starts; generation then writes
# placeholder silence instead of speech.
print("Loading TTS Model...")
TTS_MODEL = None
try:
    TTS_MODEL = TTSModel.load_model()
    print("Model Loaded Successfully.")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    TTS_MODEL = None

# State returned by get_state_for_audio_prompt for the built-in "alba" prompt;
# it is passed to every generate_audio call. Stays None if loading fails.
VOICE_STATE = None
# Explicit None check: the truthiness of an arbitrary model object is not a
# reliable "is loaded" signal (it may define __len__/__bool__).
if TTS_MODEL is not None:
    try:
        VOICE_STATE = TTS_MODEL.get_state_for_audio_prompt("alba")  # Default voice
    except Exception as e:
        print(f"Warning: Could not load default voice: {e}")
def wav_to_mp3(src_wav, dst_mp3):
    """Transcode the WAV file at *src_wav* into a 64 kbit/s MP3 at *dst_mp3*.

    pydub's ``AudioSegment.export`` returns the opened output file object.
    It is closed here so the handle does not linger until garbage collection,
    which matters because this runs once per generated field.
    """
    handle = AudioSegment.from_wav(src_wav).export(dst_mp3, format="mp3", bitrate="64k")
    handle.close()
def _write_silence(wav_path):
    """Write 500 ms of silence to *wav_path* and close the exported file handle."""
    handle = AudioSegment.silent(duration=500).export(wav_path, format="wav")
    handle.close()


def _synthesize_field(text, wav_path, label, idx, overwrite):
    """Produce audio for one card field.

    Returns ``"SKIP"`` when *overwrite* is False and the field already has a
    ``[sound:...]`` tag. Otherwise writes *wav_path* and returns it. The file
    holds speech, or 500 ms of silence when there is nothing to speak, no
    model/voice is loaded, or synthesis raises.
    """
    if not overwrite and has_existing_audio(text):
        return "SKIP"
    try:
        clean = clean_text_for_tts(text)
        if clean and TTS_MODEL is not None and VOICE_STATE is not None:
            audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean)
            # Convert the returned tensor to numpy and save it as WAV.
            scipy.io.wavfile.write(wav_path, TTS_MODEL.sample_rate, audio_tensor.numpy())
        else:
            _write_silence(wav_path)
    except Exception as e:
        print(f"TTS Error {label} row {idx}: {e}")
        # Fall back to silence so the card still gets an audio file (keeps deck integrity).
        _write_silence(wav_path)
    return wav_path


def generate_audio_for_row(q_text, a_text, idx, tmpdir, mode):
    """Generate WAV audio for a card's question and answer.

    Args:
        q_text: Question field content.
        a_text: Answer field content.
        idx: Row identifier, used in the file names ``q_{idx}.wav`` / ``a_{idx}.wav``.
        tmpdir: Directory the WAV files are written to.
        mode: Generation mode label. Only ``"Generate all new audio (Overwrite)"``
            regenerates fields that already have audio; every other mode
            preserves them.

    Returns:
        ``(q_out, a_out)``. Each is the written WAV path, or ``"SKIP"`` when
        that field's existing audio is being preserved.
    """
    overwrite = (mode == "Generate all new audio (Overwrite)")
    q_out = _synthesize_field(q_text, os.path.join(tmpdir, f"q_{idx}.wav"), "Q", idx, overwrite)
    a_out = _synthesize_field(a_text, os.path.join(tmpdir, f"a_{idx}.wav"), "A", idx, overwrite)
    return q_out, a_out
def strip_html_for_display(text):
    """Return a short, tag-free rendering of *text* for the preview table.

    HTML tags are replaced by spaces and entities such as ``&amp;``/``&nbsp;``
    are decoded. Whitespace runs are collapsed, and the result is capped at
    50 characters, with ``...`` appended when truncated. ``None``/NaN/``""``
    yield ``""``.
    """
    if pd.isna(text) or text == "":
        return ""
    text = str(text)
    # Replace tags with a space so "a<br>b" displays as "a b", not "ab".
    text = re.sub(r'<[^>]+>', ' ', text)
    # Decode all HTML entities (named and numeric).
    text = html.unescape(text)
    # Normalise whitespace before measuring length so padding doesn't eat the budget.
    text = re.sub(r'\s+', ' ', text).strip()
    if len(text) > 50:
        text = text[:50].rstrip() + '...'
    return text
def extract_unique_tags(df):
    """List the distinct tags in ``df['Tags']`` for the tag-filter dropdown.

    Each non-empty Tags cell holds whitespace-separated tags. The result
    always starts with the sentinel ``"All"``, followed by the unique tags in
    sorted order. A missing frame or a frame without a Tags column yields
    ``["All"]``.
    """
    if df is None or 'Tags' not in df.columns:
        return ["All"]
    unique_tags = {
        tag
        for cell in df['Tags']
        if cell
        for tag in cell.split()
    }
    return ["All", *sorted(unique_tags)]
def _load_csv_cards(path):
    """Read a CSV's first two columns as Question/Answer with empty Tags; None if it has fewer than two columns."""
    df = pd.read_csv(path)
    if len(df.columns) < 2:
        # Retry without treating the first row as a header.
        df = pd.read_csv(path, header=None)
    if len(df.columns) < 2:
        return None
    df = df.iloc[:, :2].copy()
    df.columns = ["Question", "Answer"]
    df['Tags'] = ""  # CSV files don't have tags
    return df


def _load_apkg_cards(extract_root):
    """Read notes from an extracted .apkg directory.

    Returns ``(df, audio_count)``. *df* has Question/Answer/Tags columns
    (the first two note fields plus the note's tags). *audio_count* is the
    number of notes whose Question or Answer already contains a
    ``[sound:...]`` tag. Returns ``None`` when no collection database exists.
    """
    # Prefer collection.anki21: newer Anki exports may keep the real notes there
    # and only a legacy stub in collection.anki2 (collection.anki21b is not supported).
    for db_name in ("collection.anki21", "collection.anki2"):
        col_path = os.path.join(extract_root, db_name)
        if os.path.exists(col_path):
            break
    else:
        return None

    conn = sqlite3.connect(col_path)
    try:
        rows = conn.execute("SELECT flds, tags FROM notes").fetchall()
    finally:
        conn.close()  # close even when the query fails (e.g. not a SQLite file)

    data = []
    audio_count = 0
    for flds_raw, tags_raw in rows:
        # Note fields are joined with the 0x1f unit separator.
        flds = flds_raw.split('\x1f')
        q = flds[0]
        a = flds[1] if len(flds) > 1 else ""
        if has_existing_audio(q) or has_existing_audio(a):
            audio_count += 1
        data.append([q, a, (tags_raw or "").strip()])
    return pd.DataFrame(data, columns=["Question", "Answer", "Tags"]), audio_count


def parse_file(file_obj):
    """Load an uploaded CSV or APKG/ZIP into a card DataFrame.

    Args:
        file_obj: The upload. It may be a path string/``os.PathLike`` or an
            object exposing the path as ``.name`` (Gradio file wrappers).
            ``None`` when nothing was uploaded.

    Returns:
        A 6-tuple ``(df, has_media, preview, msg, eta, extract_root)``.
        On failure it is ``(None, None, None, msg, "", None)``.
        *has_media* is True when some APKG notes already contain
        ``[sound:...]`` tags. *extract_root* is the temporary directory
        holding the unzipped APKG (kept for media restoration), or ``None``
        for CSV input.
    """
    if file_obj is None:
        return None, None, None, "No file uploaded", "", None
    path = os.fspath(file_obj) if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
    ext = Path(path).suffix.lower()
    extract_root = None  # Directory where we keep original media
    try:
        if ext == ".csv":
            df = _load_csv_cards(path)
            if df is None:
                return None, None, None, "CSV error: Need 2 columns", "", None
            audio_count = 0
        elif ext in (".apkg", ".zip"):
            # Extract to a temp dir that outlives this call (its path is returned to the caller).
            extract_root = tempfile.mkdtemp()
            with zipfile.ZipFile(path, 'r') as z:
                z.extractall(extract_root)
            loaded = _load_apkg_cards(extract_root)
            if loaded is None:
                shutil.rmtree(extract_root)
                return None, None, None, "Invalid APKG: No collection.anki2", "", None
            df, audio_count = loaded
        else:
            return None, None, None, "Unsupported file type", "", None

        df = df.fillna("")
        # has_media means existing AUDIO, not images.
        has_media = audio_count > 0
        msg = f"✅ Loaded {len(df)} cards."
        if has_media:
            msg += f" 🎵 {audio_count} cards have existing audio."
        return df, has_media, df.head(PREVIEW_LIMIT), msg, estimate_time(len(df), has_media), extract_root
    except Exception as e:
        if extract_root and os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        return None, None, None, f"Error: {str(e)}", "", None
def estimate_time(num_cards, has_existing_media=False, mode="Smart Fill (Preserve Existing)"):
    """Return a human-readable duration estimate for generating *num_cards* cards.

    The estimate uses a benchmark of ~4.7 s per card for full generation.
    The rate is halved when the deck already has audio and a "Smart Fill"
    mode is selected, on the assumption that many fields will be skipped.

    Returns strings such as ``"0s"``, ``"~47s"``, ``"~7 min"``, ``"~1h 18m"``
    or ``"~1h"``.
    """
    if num_cards == 0:
        return "0s"
    smart_fill_discount = has_existing_media and "Smart Fill" in mode
    per_card = 4.7 * 0.5 if smart_fill_discount else 4.7
    total = num_cards * per_card
    if total >= 3600:
        whole_hours = int(total / 3600)
        leftover_minutes = int((total % 3600) / 60)
        if leftover_minutes <= 0:
            return f"~{whole_hours}h"
        return f"~{whole_hours}h {leftover_minutes}m"
    if total >= 60:
        return f"~{int(total / 60)} min"
    return f"~{int(total)}s"
def _filter_cards(df, search_term, search_in, selected_tag):
    """Apply the tag filter, then the text search, as literal case-insensitive substrings.

    ``regex=False`` keeps user input such as "C++" or "(" from being parsed
    as a regular expression. ``astype(str)`` keeps numeric CSV columns from
    breaking the ``.str`` accessor.
    """
    if selected_tag and selected_tag != "All":
        df = df[df['Tags'].astype(str).str.contains(selected_tag, case=False, na=False, regex=False)]
    if search_term:
        if search_in == "Question Only":
            mask = df['Question'].astype(str).str.contains(search_term, case=False, na=False, regex=False)
        elif search_in == "Answer Only":
            mask = df['Answer'].astype(str).str.contains(search_term, case=False, na=False, regex=False)
        else:  # Both
            mask = df.astype(str).apply(
                lambda col: col.str.contains(search_term, case=False, na=False, regex=False)
            ).any(axis=1)
        df = df[mask]
    return df


def _restore_source_media(extract_root):
    """Rename numbered media files in an extracted .apkg back to their original names.

    Reads the JSON ``media`` map (``{"0": "my_audio.mp3", ...}``) and returns
    the list of restored file paths. Both keys and names come from the
    uploaded file, so they are reduced to a basename, which keeps an entry
    from moving files outside *extract_root*. Best effort: problems are
    logged and whatever was restored is returned.
    """
    restored = []
    media_map_path = os.path.join(extract_root, "media")
    if not os.path.exists(media_map_path) or os.path.getsize(media_map_path) == 0:
        return restored
    try:
        with open(media_map_path, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        if not content:
            print("Warning: Media map file is empty.")
            return restored
        media_map = json.loads(content)
        for key, name in media_map.items():
            src_name = os.path.basename(str(key))
            dst_name = os.path.basename(name) if isinstance(name, str) else ""
            if src_name in ("", ".", "..") or dst_name in ("", ".", ".."):
                print(f"Warning: Skipping media entry {key!r} -> {name!r}.")
                continue
            src = os.path.join(extract_root, src_name)
            dst = os.path.join(extract_root, dst_name)
            if not os.path.exists(src):
                continue
            try:
                # genanki packages media by basename, so restore the original file name.
                os.rename(src, dst)
            except OSError as e:
                print(f"Warning: Could not restore media {name!r}: {e}")
                continue
            restored.append(dst)
    except Exception as e:
        print(f"Warning: Could not restore existing media: {e}")
    return restored


def _attach_audio(field_text, audio_result, media_files):
    """Return *field_text* with its audio reference updated.

    When *audio_result* is a generated WAV path, the WAV is converted to MP3
    and then deleted, and the MP3 path is appended to *media_files*. Any old
    ``[sound:...]`` tags are replaced by a single tag pointing at the new
    file. ``"SKIP"``/``None`` leave the text unchanged.
    """
    if not audio_result or audio_result == "SKIP":
        return field_text
    mp3_path = str(Path(audio_result).with_suffix('.mp3'))
    wav_to_mp3(audio_result, mp3_path)
    os.remove(audio_result)  # the WAV is only an intermediate
    media_files.append(mp3_path)
    # Remove old sound tags first to avoid duplicate audio on the card.
    stripped = re.sub(r'\[sound:.*?\]', '', field_text)
    return stripped.strip() + f"<br>[sound:{os.path.basename(mp3_path)}]"


def process_dataframe(df_full, search_term, extract_root, mode, search_in, selected_tag, progress=gr.Progress()):
    """Filter the cards, generate audio for them and package an .apkg deck.

    Args:
        df_full: Card DataFrame with Question/Answer/Tags columns.
        search_term: Text filter (literal, case-insensitive); empty disables it.
        extract_root: Directory of the unzipped source APKG whose media is
            re-packaged, or ``None`` for CSV input. It is deleted when this
            call finishes.
        mode: Generation mode label forwarded to ``generate_audio_for_row``.
        search_in: ``"Question Only"``, ``"Answer Only"`` or anything else for both.
        selected_tag: Tag substring filter; ``"All"``/empty disables it.
        progress: Gradio progress tracker.

    Returns:
        ``(path_to_apkg or None, status_message)``.
    """
    if df_full is None or len(df_full) == 0:
        return None, "No data"
    df = _filter_cards(df_full.copy(), search_term, search_in, selected_tag)
    if len(df) == 0:
        return None, "No matching cards"
    if extract_root and not os.path.isdir(extract_root):
        # The extracted upload is removed after every run (see the finally below).
        # Building again would silently emit [sound:] references to media that no longer exists.
        return None, "Source media is no longer available. Please re-upload the file to generate again."

    work_dir = tempfile.mkdtemp()
    try:
        # --- Media Preservation ---
        media_files = _restore_source_media(extract_root) if extract_root else []

        # --- Genanki Setup ---
        model_id = random.randrange(1 << 30, 1 << 31)
        my_model = genanki.Model(
            model_id, 'PocketTTS Model',
            fields=[{'name': 'Question'}, {'name': 'Answer'}],
            templates=[{
                'name': 'Card 1',
                'qfmt': '{{Question}}',
                'afmt': '{{FrontSide}}<hr id="answer">{{Answer}}',
            }])
        my_deck = genanki.Deck(random.randrange(1 << 30, 1 << 31), 'Pocket TTS Deck')

        # --- Execution ---
        total = len(df)
        completed = 0
        failed = 0
        last_update_time = 0
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
            futures = {
                exe.submit(generate_audio_for_row, str(row['Question']), str(row['Answer']), idx, work_dir, mode): idx
                for idx, row in df.iterrows()
            }
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    q_res, a_res = future.result()
                    q_field = _attach_audio(str(df.loc[idx, 'Question']), q_res, media_files)
                    a_field = _attach_audio(str(df.loc[idx, 'Answer']), a_res, media_files)
                    my_deck.add_note(genanki.Note(model=my_model, fields=[q_field, a_field]))
                except Exception as e:
                    failed += 1
                    print(f"Row {idx} failed: {e}")
                # --- Throttled Progress ---
                completed += 1
                now = time.time()
                if completed == total or (now - last_update_time) > PROGRESS_THROTTLE:
                    progress(completed / total, desc=f"Processed {completed}/{total}")
                    last_update_time = now

        # --- Package ---
        package = genanki.Package(my_deck)
        package.media_files = list(set(media_files))  # deduplicate
        raw_out = os.path.join(work_dir, "output.apkg")
        package.write_to_file(raw_out)
        # mkstemp yields a unique name; a random 4-digit suffix could collide between runs.
        fd, final_out = tempfile.mkstemp(prefix="pocket_deck_", suffix=".apkg")
        os.close(fd)
        shutil.copy(raw_out, final_out)
        msg = f"✅ Done! Packaged {len(package.media_files)} audio files."
        if failed:
            msg += f" ⚠️ {failed} card(s) failed and were left out (see logs)."
        return final_out, msg
    except Exception as e:
        return None, f"Critical Error: {str(e)}"
    finally:
        # --- Guaranteed Cleanup ---
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)
        # Also clean up the input extraction root if it exists.
        if extract_root and os.path.exists(extract_root):
            shutil.rmtree(extract_root)
# --- UI ---
with gr.Blocks(title="Pocket TTS Anki") as app:
    gr.Markdown("## 🎴 Pocket TTS Anki Generator")
    gr.Markdown("Offline Neural Audio. Supports CSV and APKG (smart media preservation).")

    # State variables
    full_df_state = gr.State()
    extract_root_state = gr.State()  # Holds path to unzipped APKG

    with gr.Row():
        file_input = gr.File(label="Upload (CSV/APKG)", file_types=[".csv", ".apkg", ".zip"])
        status = gr.Textbox(label="Status", interactive=False)
        eta_box = gr.Textbox(label="Est. Time", interactive=False)

    with gr.Row():
        search_box = gr.Textbox(label="Search Text", placeholder="Enter text to search...")
        search_field = gr.Radio(
            choices=["Both", "Question Only", "Answer Only"],
            value="Both",
            label="Search In"
        )
        tag_dropdown = gr.Dropdown(
            label="Filter by Tag",
            choices=["All"],
            value="All",
            interactive=True
        )

    with gr.Row():
        # 3-way generation mode toggle
        mode_radio = gr.Radio(
            choices=[
                "Smart Fill (Preserve Existing)",
                "Generate all new audio (Overwrite)",
                "Only generate missing (Same as Smart Fill)"
            ],
            value="Smart Fill (Preserve Existing)",
            label="Generation Mode"
        )

    preview_table = gr.Dataframe(
        label="Preview (First 100)",
        interactive=False,
        column_widths=["30%", "45%", "25%"]
    )

    with gr.Row():
        btn = gr.Button("🚀 Generate Deck", variant="primary")
        dl = gr.File(label="Download")
    result_lbl = gr.Textbox(label="Result", interactive=False)

    has_media_state = gr.State(False)

    def _discard_extract_root(path):
        """Remove a previously extracted APKG directory if it is still on disk."""
        if path and os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)

    def _display_frame(df):
        """Return the first PREVIEW_LIMIT rows with Question/Answer stripped of HTML for display."""
        # Take head() first so only the displayed rows are cleaned, not the whole deck.
        shown = df.head(PREVIEW_LIMIT).copy()
        shown['Question'] = shown['Question'].apply(strip_html_for_display)
        shown['Answer'] = shown['Answer'].apply(strip_html_for_display)
        return shown

    def _filter_cards_for_preview(df, term, search_in, selected_tag):
        """Apply the tag filter then the text search as literal, case-insensitive substrings.

        ``regex=False`` keeps input such as "C++" or "(" from raising while
        the user types. ``astype(str)`` tolerates numeric CSV columns.
        """
        filtered = df.copy()
        if selected_tag and selected_tag != "All":
            filtered = filtered[
                filtered['Tags'].astype(str).str.contains(selected_tag, case=False, na=False, regex=False)
            ]
        if term:
            if search_in == "Question Only":
                mask = filtered['Question'].astype(str).str.contains(term, case=False, na=False, regex=False)
            elif search_in == "Answer Only":
                mask = filtered['Answer'].astype(str).str.contains(term, case=False, na=False, regex=False)
            else:  # Both
                mask = filtered.astype(str).apply(
                    lambda col: col.str.contains(term, case=False, na=False, regex=False)
                ).any(axis=1)
            filtered = filtered[mask]
        return filtered

    def on_upload(file, previous_ext_path):
        """Parse a new upload and refresh state, preview, status, ETA and tag choices."""
        # A new upload supersedes the previous one; drop its extracted files so temp dirs don't pile up.
        _discard_extract_root(previous_ext_path)
        # parse_file returns: df, has_media, preview, msg, eta, extract_path
        df, has_media, preview, msg, eta, ext_path = parse_file(file)
        tag_choices = extract_unique_tags(df)
        clean_preview = _display_frame(df) if df is not None else preview
        return (
            df,  # full_df_state
            has_media,  # has_media_state
            clean_preview,  # preview_table
            msg,  # status
            eta,  # eta_box
            ext_path,  # extract_root_state
            gr.Dropdown(choices=tag_choices, value="All")  # tag_dropdown
        )

    file_input.upload(on_upload, inputs=[file_input, extract_root_state],
                      outputs=[full_df_state, has_media_state, preview_table, status, eta_box,
                               extract_root_state, tag_dropdown])

    def on_clear(ext_path):
        """Reset all fields when the file is cleared and delete the extracted APKG, if any."""
        _discard_extract_root(ext_path)
        return (
            None,  # full_df_state
            False,  # has_media_state
            None,  # preview_table
            "",  # status
            "",  # eta_box
            None,  # extract_root_state
            gr.Dropdown(choices=["All"], value="All"),  # tag_dropdown
            "",  # search_box
            None,  # dl (download file)
            ""  # result_lbl
        )

    file_input.clear(on_clear, inputs=[extract_root_state],
                     outputs=[full_df_state, has_media_state, preview_table, status, eta_box,
                              extract_root_state, tag_dropdown, search_box, dl, result_lbl])

    def on_search(term, df, has_media, mode, search_in, selected_tag):
        """Refresh the preview and ETA for the current filters."""
        if df is None:
            return None, "No data"
        filtered_df = _filter_cards_for_preview(df, term, search_in, selected_tag)
        return _display_frame(filtered_df), estimate_time(len(filtered_df), has_media, mode)

    # Any filter or mode change re-runs the preview/ETA computation.
    search_inputs = [search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown]
    for control in (search_box, search_field, tag_dropdown, mode_radio):
        control.change(on_search, inputs=search_inputs, outputs=[preview_table, eta_box])

    btn.click(process_dataframe,
              inputs=[full_df_state, search_box, extract_root_state, mode_radio, search_field, tag_dropdown],
              outputs=[dl, result_lbl])
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 with SSR disabled; the request
    # queue is capped at two pending events.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
    app.queue(max_size=2).launch(**launch_options)
|