# Flashcard2Audio / app.py
# Author: adelevett
# Last commit: "Update app.py" (1604b13, verified)
import html
import json
import os
import random
import re
import shutil
import sqlite3
import tempfile
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import genanki
import gradio as gr
import pandas as pd
import scipy.io.wavfile
import torch
from pocket_tts import TTSModel
from pydub import AudioSegment
# --- Configuration ---
# Number of rows synthesized in parallel; kept low for HF Spaces (CPU/RAM constraint).
MAX_WORKERS: int = 4
# UI safety cap: maximum number of rows rendered in the preview table.
PREVIEW_LIMIT: int = 100
# Minimum number of seconds between two progress-bar updates.
PROGRESS_THROTTLE: float = 1.0
# --- Helpers ---
def clean_text_for_tts(text):
    """Reduce a raw Anki field to plain, speakable text for the TTS model.

    Removes HTML tags, ``[sound:...]`` tags and ``{{...}}`` template
    placeholders, decodes HTML entities (``&nbsp;``, ``&amp;``, ...) and
    collapses whitespace.

    Args:
        text: Raw field content; ``None``/NaN is treated as empty.

    Returns:
        The cleaned string, or ``""`` when nothing speakable remains.
    """
    if pd.isna(text):
        return ""
    text = str(text)
    # Replace tags with a space (not "") so "a<br>b" is not read as "ab";
    # [^>] also matches newlines, so tags spanning several lines are removed too.
    text = re.sub(r'<[^>]+>', ' ', text)
    # Remove Anki sound tags
    text = re.sub(r'\[sound:.*?\]', '', text)
    # Remove mustache templates
    text = re.sub(r'\{\{.*?\}\}', '', text)
    # Decode entities only after tags are gone, so an escaped "&lt;b&gt;"
    # survives as literal text instead of being stripped as a tag.
    text = html.unescape(text)
    # \s also covers the non-breaking spaces that &nbsp; decodes to.
    return re.sub(r'\s+', ' ', text).strip()
def has_existing_audio(text):
    """Return True when *text* already contains an Anki ``[sound:...]`` tag.

    Missing values (``None``/NaN) never count as having audio.
    """
    if pd.isna(text):
        return False
    found = re.search(r'\[sound:.*?\]', str(text))
    return found is not None
# Load the TTS model once at import time; on failure the app keeps running
# with TTS_MODEL = None (generation then falls back to silent clips).
print("Loading TTS Model...")
try:
    TTS_MODEL = TTSModel.load_model()
    print("Model Loaded Successfully.")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    TTS_MODEL = None
# Get default voice state
VOICE_STATE = None
# Compare with None explicitly: the truthiness of a model or state object is not
# a reliable "loaded" signal (multi-element tensors raise on bool(), and
# container-like objects are falsy when empty).
if TTS_MODEL is not None:
    try:
        VOICE_STATE = TTS_MODEL.get_state_for_audio_prompt("alba")  # Default voice
    except Exception as e:
        print(f"Warning: Could not load default voice: {e}")
def wav_to_mp3(src_wav, dst_mp3, bitrate="64k"):
    """Transcode the WAV file *src_wav* into an MP3 written to *dst_mp3*.

    Args:
        src_wav: Path of the input WAV file.
        dst_mp3: Destination path for the MP3.
        bitrate: Bitrate forwarded to pydub's ``export()``; defaults to "64k".
    """
    # export() opens dst_mp3 itself and returns that open file object; close it
    # immediately instead of leaving it for the garbage collector.
    AudioSegment.from_wav(src_wav).export(dst_mp3, format="mp3", bitrate=bitrate).close()
def _render_field_audio(text, wav_path, label, idx):
    """Write speech for *text* to *wav_path*, or 0.5 s of silence as a fallback; return wav_path."""
    try:
        clean = clean_text_for_tts(text)
        if clean and TTS_MODEL is not None and VOICE_STATE is not None:
            # Generate audio using the pocket_tts API
            audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean)
            # detach()/cpu() keep .numpy() working even if the tensor tracks
            # gradients or is not on the CPU; both are no-ops otherwise.
            samples = audio_tensor.detach().cpu().numpy()
            scipy.io.wavfile.write(wav_path, TTS_MODEL.sample_rate, samples)
            return wav_path
    except Exception as e:
        print(f"TTS Error {label} row {idx}: {e}")
    # Empty text, no model/voice, or a TTS failure: write silence so the card
    # still gets a playable clip and the deck stays consistent.
    # export() returns the file object it opened; close it so the WAV is flushed.
    AudioSegment.silent(duration=500).export(wav_path, format="wav").close()
    return wav_path


def generate_audio_for_row(q_text, a_text, idx, tmpdir, mode):
    """Create WAV audio for one card's question and answer fields.

    Args:
        q_text: Raw question field (may contain HTML or existing sound tags).
        a_text: Raw answer field.
        idx: Row index, used to name the files ``q_{idx}.wav`` / ``a_{idx}.wav``.
        tmpdir: Directory the WAV files are written to.
        mode: Generation-mode label from the UI. Only
            "Generate all new audio (Overwrite)" regenerates fields that already
            contain a ``[sound:...]`` tag; every other mode preserves them.

    Returns:
        ``(q_out, a_out)``: for each field, the WAV path that was written, or the
        string ``"SKIP"`` when the field already has audio and is preserved.
    """
    overwrite = (mode == "Generate all new audio (Overwrite)")
    # NOTE(review): this runs in several worker threads that share TTS_MODEL and
    # VOICE_STATE — confirm pocket_tts generation is safe to call concurrently.

    def field_audio(label, text):
        if not overwrite and has_existing_audio(text):
            return "SKIP"
        wav_path = os.path.join(tmpdir, f"{label.lower()}_{idx}.wav")
        return _render_field_audio(text, wav_path, label, idx)

    return field_audio("Q", q_text), field_audio("A", a_text)
def strip_html_for_display(text, max_len=50):
    """Return a short plain-text rendering of a field for the preview table.

    HTML tags are replaced by spaces, every HTML entity is decoded, whitespace
    is collapsed, and the result is cut to *max_len* characters with a "..."
    suffix when it is longer.

    Args:
        text: Raw field content; None/NaN and "" give "".
        max_len: Number of characters kept before truncating (default 50).

    Returns:
        The cleaned, possibly truncated string.
    """
    if pd.isna(text):
        return ""
    text = str(text)
    # Replace tags with a space so "a<br>b" reads "a b", not "ab".
    text = re.sub(r'<[^>]+>', ' ', text)
    # html.unescape decodes all entities (&quot;, &#39;, &nbsp;, ...); \s+ then
    # also folds the non-breaking spaces that &nbsp; turns into.
    text = re.sub(r'\s+', ' ', html.unescape(text)).strip()
    # Trimmed before measuring, so surrounding whitespace doesn't use up the budget.
    if len(text) > max_len:
        text = text[:max_len].rstrip() + '...'
    return text
def extract_unique_tags(df):
    """List the distinct tags in the DataFrame's "Tags" column for the tag dropdown.

    Each cell holds whitespace-separated tags (e.g. " MK_MathematicsKnowledge ").
    The result always starts with the "All" sentinel, followed by the tags in
    sorted order. A missing DataFrame or "Tags" column yields just ["All"].
    """
    if df is None or 'Tags' not in df.columns:
        return ["All"]
    seen = {
        tag
        for cell in df['Tags']
        if cell
        for tag in cell.split()
    }
    return ["All", *sorted(seen)]
def _pick_collection_file(extract_root):
    """Return (path, None) for the readable note database in *extract_root*, else (None, error message)."""
    # Anki 2.1 packages keep the real notes in collection.anki21 and add a
    # placeholder collection.anki2 for old clients, so .anki21 must win.
    anki21 = os.path.join(extract_root, "collection.anki21")
    if os.path.exists(anki21):
        return anki21, None
    # Current-format exports store a compressed collection.anki21b (next to a
    # placeholder collection.anki2) that sqlite3 cannot open directly.
    if os.path.exists(os.path.join(extract_root, "collection.anki21b")):
        return None, ("Unsupported APKG format: re-export from Anki with "
                      "'Support older Anki versions' enabled.")
    anki2 = os.path.join(extract_root, "collection.anki2")
    if os.path.exists(anki2):
        return anki2, None
    return None, "Invalid APKG: No collection.anki2"


def _read_anki_notes(col_path):
    """Load Question/Answer/Tags rows from a collection DB; return (DataFrame, notes_with_audio)."""
    conn = sqlite3.connect(col_path)
    try:
        rows = conn.execute("SELECT flds, tags FROM notes").fetchall()
    finally:
        # Close even if the query fails (e.g. the file is not a collection DB).
        conn.close()
    data = []
    audio_count = 0  # Count cards with existing audio
    for flds_raw, tags_raw in rows:
        # Fields are joined with the 0x1f separator; only the first two are used.
        flds = (flds_raw or "").split('\x1f')
        q = flds[0]
        a = flds[1] if len(flds) > 1 else ""
        # Check if either field has audio tags
        if re.search(r'\[sound:.*?\]', q) or re.search(r'\[sound:.*?\]', a):
            audio_count += 1
        data.append([q, a, (tags_raw or "").strip()])
    return pd.DataFrame(data, columns=["Question", "Answer", "Tags"]), audio_count


def parse_file(file_obj):
    """Load an uploaded CSV or Anki package (.apkg/.zip) into a card DataFrame.

    CSV files must have at least two columns (question, answer); their Tags are
    empty. Packages are unzipped into a temporary directory that is kept (and
    returned) so their original media can be repackaged later.

    Args:
        file_obj: Uploaded file object exposing ``.name`` (a filesystem path), or None.

    Returns:
        ``(df, has_media, preview, message, eta, extract_root)`` where ``df`` has
        Question/Answer/Tags columns, ``has_media`` is True when some card already
        has a ``[sound:...]`` tag, and ``extract_root`` is the unzip directory
        (None for CSV). On failure only ``message`` is meaningful; the other
        slots are None or "" and any extraction directory is removed.
    """
    if file_obj is None:
        return None, None, None, "No file uploaded", "", None
    ext = Path(file_obj.name).suffix.lower()
    extract_root = None  # Directory where we keep original media
    audio_count = 0  # Cards whose Q or A already has a [sound:...] tag
    try:
        if ext == ".csv":
            df = pd.read_csv(file_obj.name)
            if len(df.columns) < 2:
                df = pd.read_csv(file_obj.name, header=None)
            if len(df.columns) < 2:
                return None, None, None, "CSV error: Need 2 columns", "", None
            df = df.iloc[:, :2].copy()
            df.columns = ["Question", "Answer"]
            df['Tags'] = ""  # CSV files don't have tags
        elif ext in (".apkg", ".zip"):
            # Extract to a PERSISTENT temp dir (passed to state)
            extract_root = tempfile.mkdtemp()
            with zipfile.ZipFile(file_obj.name, 'r') as z:
                z.extractall(extract_root)
            col_path, error = _pick_collection_file(extract_root)
            if error:
                shutil.rmtree(extract_root, ignore_errors=True)
                return None, None, None, error, "", None
            df, audio_count = _read_anki_notes(col_path)
        else:
            return None, None, None, "Unsupported file type", "", None
        df = df.fillna("")
        # has_media means existing AUDIO, not images
        has_media = audio_count > 0
        msg = f"✅ Loaded {len(df)} cards."
        if has_media:
            msg += f" 🎵 {audio_count} cards have existing audio."
        return df, has_media, df.head(PREVIEW_LIMIT), msg, estimate_time(len(df), has_media), extract_root
    except Exception as e:
        if extract_root and os.path.exists(extract_root):
            shutil.rmtree(extract_root, ignore_errors=True)
        return None, None, None, f"Error: {str(e)}", "", None
def estimate_time(num_cards, has_existing_media=False, mode="Smart Fill (Preserve Existing)"):
    """Return a rough, human-readable duration for generating *num_cards* cards.

    Uses a benchmark of ~4.7 s per card. When the deck already has audio and
    the mode label contains "Smart Fill", the rate is halved because many
    fields will be skipped.

    Returns:
        "0s" for no cards, otherwise "~Ns", "~N min", "~Hh" or "~Hh Mm".
    """
    if num_cards == 0:
        return "0s"
    rate = 4.7  # benchmark seconds per fully generated card
    if has_existing_media and "Smart Fill" in mode:
        rate *= 0.5  # assume ~half the fields are preserved
    total_seconds = num_cards * rate
    if total_seconds >= 3600:
        whole_hours = int(total_seconds / 3600)
        leftover_minutes = int((total_seconds % 3600) / 60)
        if leftover_minutes > 0:
            return f"~{whole_hours}h {leftover_minutes}m"
        return f"~{whole_hours}h"
    if total_seconds >= 60:
        return f"~{int(total_seconds/60)} min"
    return f"~{int(total_seconds)}s"
def _filter_cards(df, search_term, search_in, selected_tag):
    """Return the rows of *df* that pass the tag filter and the literal text search.

    Used by both the generator and the preview, so they always select the same cards.

    Args:
        df: Cards with "Question", "Answer" and "Tags" columns.
        search_term: Text to look for (case-insensitive, literal); empty/None disables it.
        search_in: "Question Only", "Answer Only", or anything else to search every
            column (including Tags).
        selected_tag: Tag to keep, or "All"/empty for no tag filter.
    """
    if selected_tag and selected_tag != "All":
        wanted = str(selected_tag).lower()
        child_prefix = wanted + "::"

        def has_tag(tag_str):
            # Whole-tag, case-insensitive match (plus "parent::child" sub-tags), so
            # "Math" no longer matches "Mathematics" and "C++" isn't parsed as regex.
            return any(t == wanted or t.startswith(child_prefix) for t in tag_str.lower().split())

        df = df[df['Tags'].astype(str).map(has_tag).astype(bool)]
    if not search_term or df.empty:
        return df

    def contains(col):
        # regex=False: user text such as "(" or "?" would otherwise raise re.error;
        # astype(str) lets numeric CSV columns use the .str accessor.
        return col.astype(str).str.contains(search_term, case=False, na=False, regex=False).astype(bool)

    if search_in == "Question Only":
        mask = contains(df['Question'])
    elif search_in == "Answer Only":
        mask = contains(df['Answer'])
    else:  # Both: every column, including Tags
        mask = df.apply(contains).any(axis=1)
    return df[mask]


def _stage_existing_media(extract_root, dest_dir):
    """Copy an extracted deck's media into *dest_dir* under their original names; return the new paths."""
    staged = []
    media_map_path = os.path.join(extract_root, "media")
    if not os.path.exists(media_map_path) or os.path.getsize(media_map_path) == 0:
        return staged
    try:
        with open(media_map_path, 'r', encoding='utf-8') as f:
            # Handle potentially malformed JSON gracefully
            content = f.read().strip()
        if not content:
            print("Warning: Media map file is empty.")
            return staged
        media_map = json.loads(content)  # {"0": "my_audio.mp3", ...}
        os.makedirs(dest_dir, exist_ok=True)
        for k, v in media_map.items():
            # The map comes from an untrusted upload: basename() keeps entries like
            # "../../x" from reading or writing outside the expected directories.
            src = os.path.join(extract_root, os.path.basename(str(k)))
            name = os.path.basename(str(v))
            if not name or not os.path.exists(src):
                continue
            dst = os.path.join(dest_dir, name)
            # Copy instead of renaming so extract_root stays intact for later runs;
            # genanki names packaged media after each file's basename.
            shutil.copyfile(src, dst)
            staged.append(dst)
    except Exception as e:
        print(f"Warning: Could not restore existing media: {e}")
    return staged


def _attach_generated_audio(field_text, wav_path, media_files):
    """Convert a generated WAV to MP3 and reference it from *field_text*.

    Returns *field_text* unchanged when *wav_path* is empty or "SKIP" (existing
    audio is preserved). Otherwise converts the WAV to MP3, deletes the WAV,
    appends the MP3 path to *media_files*, replaces any old ``[sound:...]`` tags
    and returns the updated field text.
    """
    if not wav_path or wav_path == "SKIP":
        return field_text
    mp3_path = str(Path(wav_path).with_suffix('.mp3'))
    wav_to_mp3(wav_path, mp3_path)
    os.remove(wav_path)  # clean wav
    media_files.append(mp3_path)
    # Remove OLD sound tags first to avoid duplicates
    body = re.sub(r'\[sound:.*?\]', '', field_text).strip()
    sound_tag = f"[sound:{os.path.basename(mp3_path)}]"
    # Only insert a line break when there is visible text before the tag.
    return f"{body}<br>{sound_tag}" if body else sound_tag


def _discard_extract_root(path):
    """Delete a previously extracted upload directory; missing paths are ignored."""
    if path and os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)


def process_dataframe(df_full, search_term, extract_root, mode, search_in, selected_tag, progress=gr.Progress()):
    """Generate TTS audio for the selected cards and package them into an .apkg.

    Args:
        df_full: All loaded cards (Question/Answer/Tags columns), or None.
        search_term: Literal, case-insensitive text filter (see ``_filter_cards``).
        extract_root: Directory of the extracted uploaded APKG (None for CSV). Its
            media are copied into the new deck. The directory is only read, never
            modified or deleted, so the same upload can be generated again.
        mode: Generation-mode label forwarded to ``generate_audio_for_row``.
        search_in: Which column(s) the search applies to.
        selected_tag: Tag filter, or "All".
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(apkg_path, message)`` on success, ``(None, message)`` otherwise.
    """
    if df_full is None or len(df_full) == 0:
        return None, "No data"
    df = _filter_cards(df_full, search_term, search_in, selected_tag)
    if len(df) == 0:
        return None, "No matching cards"
    # Setup
    work_dir = tempfile.mkdtemp()
    try:
        # --- Media Preservation Logic ---
        media_files = []
        if extract_root and os.path.isdir(extract_root):
            media_files += _stage_existing_media(extract_root, os.path.join(work_dir, "existing_media"))
        # --- Genanki Setup ---
        model_id = random.randrange(1 << 30, 1 << 31)
        my_model = genanki.Model(
            model_id, 'PocketTTS Model',
            fields=[{'name': 'Question'}, {'name': 'Answer'}],
            templates=[{
                'name': 'Card 1',
                'qfmt': '{{Question}}',
                'afmt': '{{FrontSide}}<hr id="answer">{{Answer}}',
            }])
        my_deck = genanki.Deck(random.randrange(1 << 30, 1 << 31), 'Pocket TTS Deck')
        # --- Execution ---
        total = len(df)
        completed = 0
        last_update_time = 0
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
            futures = {}
            for idx, row in df.iterrows():
                q_text, a_text = str(row['Question']), str(row['Answer'])
                f = exe.submit(generate_audio_for_row, q_text, a_text, idx, work_dir, mode)
                futures[f] = (idx, q_text, a_text)
            for future in as_completed(futures):
                idx, q_text, a_text = futures[future]
                try:
                    q_res, a_res = future.result()
                    q_field = _attach_generated_audio(q_text, q_res, media_files)
                    a_field = _attach_generated_audio(a_text, a_res, media_files)
                    my_deck.add_note(genanki.Note(model=my_model, fields=[q_field, a_field]))
                except Exception as e:
                    print(f"Row {idx} failed: {e}")
                # --- Throttled Progress ---
                completed += 1
                now = time.time()
                if completed == total or (now - last_update_time) > PROGRESS_THROTTLE:
                    progress(completed / total, desc=f"Processed {completed}/{total}")
                    last_update_time = now
        # --- Package ---
        package = genanki.Package(my_deck)
        # Deduplicate media files while keeping a stable order
        package.media_files = list(dict.fromkeys(media_files))
        # mkstemp guarantees a unique output name (a 4-digit random suffix can collide).
        fd, final_out = tempfile.mkstemp(prefix="pocket_deck_", suffix=".apkg")
        os.close(fd)
        package.write_to_file(final_out)
        return final_out, f"✅ Done! Packaged {len(package.media_files)} media files."
    except Exception as e:
        return None, f"Critical Error: {str(e)}"
    finally:
        # --- Guaranteed Cleanup ---
        # Only the scratch dir; extract_root is removed by the upload/clear handlers.
        shutil.rmtree(work_dir, ignore_errors=True)


# --- UI ---
with gr.Blocks(title="Pocket TTS Anki") as app:
    gr.Markdown("## 🎴 Pocket TTS Anki Generator")
    gr.Markdown("Offline Neural Audio. Supports CSV and APKG (smart media preservation).")
    # State variables
    full_df_state = gr.State()
    extract_root_state = gr.State()  # Holds path to unzipped APKG
    with gr.Row():
        file_input = gr.File(label="Upload (CSV/APKG)", file_types=[".csv", ".apkg", ".zip"])
        status = gr.Textbox(label="Status", interactive=False)
        eta_box = gr.Textbox(label="Est. Time", interactive=False)
    with gr.Row():
        search_box = gr.Textbox(label="Search Text", placeholder="Enter text to search...")
        search_field = gr.Radio(
            choices=["Both", "Question Only", "Answer Only"],
            value="Both",
            label="Search In",
        )
        tag_dropdown = gr.Dropdown(
            label="Filter by Tag",
            choices=["All"],
            value="All",
            interactive=True,
        )
    with gr.Row():
        # New 3-Way Toggle
        mode_radio = gr.Radio(
            choices=[
                "Smart Fill (Preserve Existing)",
                "Generate all new audio (Overwrite)",
                "Only generate missing (Same as Smart Fill)",
            ],
            value="Smart Fill (Preserve Existing)",
            label="Generation Mode",
        )
    preview_table = gr.Dataframe(
        label="Preview (First 100)",
        interactive=False,
        column_widths=["30%", "45%", "25%"],
    )
    with gr.Row():
        btn = gr.Button("🚀 Generate Deck", variant="primary")
        dl = gr.File(label="Download")
    result_lbl = gr.Textbox(label="Result", interactive=False)
    has_media_state = gr.State(False)

    def _preview(df):
        """Return a display copy of *df* with HTML stripped, capped at PREVIEW_LIMIT rows."""
        display_df = df.copy()
        display_df['Question'] = display_df['Question'].apply(strip_html_for_display)
        display_df['Answer'] = display_df['Answer'].apply(strip_html_for_display)
        return display_df.head(PREVIEW_LIMIT)

    def on_upload(file, prev_root, term, search_in, mode):
        """Parse a new upload, deleting the previous upload's extracted files."""
        _discard_extract_root(prev_root)
        # Returns: df, has_media, preview, msg, eta, extract_path
        df, has_media, preview, msg, eta, ext_path = parse_file(file)
        tag_choices = extract_unique_tags(df)
        if df is not None:
            # The tag filter resets to "All" below, but the search text and mode stay,
            # so compute the preview and ETA for exactly what Generate would process.
            filtered = _filter_cards(df, term, search_in, "All")
            clean_preview = _preview(filtered)
            eta = estimate_time(len(filtered), has_media, mode)
        else:
            clean_preview = preview
        return (
            df,  # full_df_state
            has_media,  # has_media_state
            clean_preview,  # preview_table
            msg,  # status
            eta,  # eta_box
            ext_path,  # extract_root_state
            gr.Dropdown(choices=tag_choices, value="All"),  # tag_dropdown
        )

    file_input.upload(
        on_upload,
        inputs=[file_input, extract_root_state, search_box, search_field, mode_radio],
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box, extract_root_state, tag_dropdown],
    )

    def on_clear(prev_root):
        """Reset all fields when the file is cleared and delete its extracted files."""
        _discard_extract_root(prev_root)
        return (
            None,  # full_df_state
            False,  # has_media_state
            None,  # preview_table
            "",  # status
            "",  # eta_box
            None,  # extract_root_state
            gr.Dropdown(choices=["All"], value="All"),  # tag_dropdown
            "",  # search_box
            None,  # dl (download file)
            "",  # result_lbl
        )

    file_input.clear(
        on_clear,
        inputs=[extract_root_state],
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box, extract_root_state, tag_dropdown, search_box, dl, result_lbl],
    )

    def on_search(term, df, has_media, mode, search_in, selected_tag):
        """Refresh the preview table and ETA for the current filters and mode."""
        if df is None:
            return None, "No data"
        filtered_df = _filter_cards(df, term, search_in, selected_tag)
        return _preview(filtered_df), estimate_time(len(filtered_df), has_media, mode)

    search_inputs = [search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown]
    for trigger in (search_box, search_field, tag_dropdown, mode_radio):
        trigger.change(on_search, inputs=search_inputs, outputs=[preview_table, eta_box])

    btn.click(process_dataframe,
              inputs=[full_df_state, search_box, extract_root_state, mode_radio, search_field, tag_dropdown],
              outputs=[dl, result_lbl])
if __name__ == "__main__":
    # Allow at most 2 pending requests in the queue; each generation is long-running.
    app.queue(max_size=2)
    app.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        ssr_mode=False,
    )