"""Pocket TTS Anki generator: add offline neural TTS audio to CSV/APKG decks."""

import html
import json
import os
import random
import re
import shutil
import sqlite3
import tempfile
import time
import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import closing
from pathlib import Path

import gradio as gr
import pandas as pd
import genanki
from pocket_tts import TTSModel
import torch  # noqa: F401  (not referenced directly; kept from the original import list)
import scipy.io.wavfile
from pydub import AudioSegment

# --- Configuration ---
MAX_WORKERS = 4  # Keep low for HF Spaces (CPU/RAM constraint)
PREVIEW_LIMIT = 100  # UI safety cap
PROGRESS_THROTTLE = 1.0  # Seconds between UI updates
SILENCE_MS = 500  # Placeholder clip length when TTS is unavailable or fails

# Generation modes shown in the UI. Only MODE_OVERWRITE changes behaviour;
# MODE_MISSING_ONLY is an alias of Smart Fill offered for user clarity.
MODE_SMART_FILL = "Smart Fill (Preserve Existing)"
MODE_OVERWRITE = "Generate all new audio (Overwrite)"
MODE_MISSING_ONLY = "Only generate missing (Same as Smart Fill)"

SOUND_TAG_RE = re.compile(r"\[sound:.*?\]")
# [^>] also matches newlines, so tags whose attributes span lines are removed.
HTML_TAG_RE = re.compile(r"<[^>]+>")
MUSTACHE_RE = re.compile(r"\{\{.*?\}\}")


# --- Helpers ---
def clean_text_for_tts(text):
    """Return *text* reduced to plain speakable text for the TTS model.

    Removes HTML tags, Anki ``[sound:...]`` tags and ``{{...}}`` template
    markers, decodes HTML entities (``&nbsp;``, ``&amp;``, ...) so they are not
    read aloud, and collapses runs of whitespace. NaN/None become ``""``.
    """
    if pd.isna(text):
        return ""
    text = HTML_TAG_RE.sub("", str(text))
    text = SOUND_TAG_RE.sub("", text)
    text = MUSTACHE_RE.sub("", text)
    # Decode entities only after tag removal so "&lt;b&gt;" stays literal text.
    text = html.unescape(text)
    return " ".join(text.split())


def has_existing_audio(text):
    """Return True if *text* already contains an Anki ``[sound:...]`` tag."""
    if pd.isna(text):
        return False
    return bool(SOUND_TAG_RE.search(str(text)))


print("Loading TTS Model...")
try:
    TTS_MODEL = TTSModel.load_model()
    print("Model Loaded Successfully.")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    TTS_MODEL = None

# Default voice state; stays None if the model or the voice failed to load.
VOICE_STATE = None
if TTS_MODEL is not None:
    try:
        VOICE_STATE = TTS_MODEL.get_state_for_audio_prompt("alba")  # Default voice
    except Exception as e:
        print(f"Warning: Could not load default voice: {e}")


def wav_to_mp3(src_wav, dst_mp3):
    """Encode the WAV file *src_wav* as a 64 kbit/s MP3 at *dst_mp3*."""
    AudioSegment.from_wav(src_wav).export(dst_mp3, format="mp3", bitrate="64k")


def _synthesize_to_wav(text, wav_path, label):
    """Write TTS audio for *text* to *wav_path* and return *wav_path*.

    Falls back to SILENCE_MS of silence when the cleaned text is empty, the
    model/voice is unavailable, or synthesis raises, so every processed field
    still gets a playable clip (keeps deck integrity). *label* is only used in
    the error log.
    """
    try:
        clean = clean_text_for_tts(text)
        # `is not None` rather than truthiness: the voice state may be an object
        # (e.g. a tensor) whose truth value is ambiguous.
        if clean and TTS_MODEL is not None and VOICE_STATE is not None:
            audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean)
            scipy.io.wavfile.write(wav_path, TTS_MODEL.sample_rate, audio_tensor.numpy())
            return wav_path
    except Exception as e:
        print(f"TTS Error {label}: {e}")
    AudioSegment.silent(duration=SILENCE_MS).export(wav_path, format="wav")
    return wav_path


def generate_audio_for_row(q_text, a_text, idx, tmpdir, mode):
    """Generate question/answer audio for one card.

    Returns ``(path_q, path_a)`` where each item is a WAV path inside *tmpdir*
    (``q_{idx}.wav`` / ``a_{idx}.wav``), or ``"SKIP"`` when the field already
    has a sound tag and *mode* is not Overwrite (existing audio is preserved).
    """
    overwrite = mode == MODE_OVERWRITE
    results = []
    for prefix, text in (("q", q_text), ("a", a_text)):
        if not overwrite and has_existing_audio(text):
            results.append("SKIP")
            continue
        wav_path = os.path.join(tmpdir, f"{prefix}_{idx}.wav")
        results.append(_synthesize_to_wav(text, wav_path, f"{prefix.upper()} row {idx}"))
    return results[0], results[1]


def strip_html_for_display(text, max_len=50):
    """Return a short plain-text version of *text* for the preview table.

    Removes HTML tags, decodes entities (non-breaking spaces become ordinary
    spaces) and truncates to *max_len* characters plus ``"..."``.
    """
    if pd.isna(text) or text == "":
        return ""
    text = HTML_TAG_RE.sub("", str(text))
    text = html.unescape(text).replace("\xa0", " ")
    if len(text) > max_len:
        text = text[:max_len] + "..."
    return text.strip()


def extract_unique_tags(df):
    """Return ``["All"]`` followed by the sorted unique tags of ``df['Tags']``.

    Tags are space-separated tokens, e.g. ``" MK_MathematicsKnowledge "``.
    """
    if df is None or "Tags" not in df.columns:
        return ["All"]
    all_tags = set()
    for tag_str in df["Tags"]:
        if tag_str:
            all_tags.update(str(tag_str).split())
    return ["All"] + sorted(all_tags)


def _filter_cards(df, search_term, search_in, selected_tag):
    """Return the rows of *df* matching the tag filter and the text search.

    *selected_tag* is matched as a whole, case-insensitive token of the
    space-separated ``Tags`` column ("All"/empty disables the filter).
    *search_term* is a case-insensitive literal substring (not a regex, so
    characters like ``(`` are safe) searched in Question, Answer, or either,
    depending on *search_in*.
    """
    if selected_tag and selected_tag != "All":
        wanted = str(selected_tag).lower()
        tag_mask = df["Tags"].astype(str).map(lambda tags: wanted in tags.lower().split())
        df = df.loc[tag_mask.astype(bool)]

    if search_term:
        columns = {
            "Question Only": ["Question"],
            "Answer Only": ["Answer"],
        }.get(search_in, ["Question", "Answer"])
        mask = pd.Series(False, index=df.index)
        for column in columns:
            mask |= df[column].astype(str).str.contains(
                search_term, case=False, regex=False, na=False
            )
        df = df.loc[mask]
    return df


def _preview_frame(df):
    """Return the first PREVIEW_LIMIT rows of *df* with HTML stripped for display."""
    display_df = df.head(PREVIEW_LIMIT).copy()
    display_df["Question"] = display_df["Question"].apply(strip_html_for_display)
    display_df["Answer"] = display_df["Answer"].apply(strip_html_for_display)
    return display_df


def _discard_extract_root(path):
    """Best-effort removal of an APKG extraction directory (no-op if missing)."""
    if path and os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)


def _read_csv_cards(path):
    """Load the first two CSV columns as Question/Answer; None if fewer than 2."""
    df = pd.read_csv(path)
    if len(df.columns) < 2:
        # A 1-column header over 2-column data makes pandas use the first data
        # column as the index; re-read without treating row 0 as a header.
        df = pd.read_csv(path, header=None)
    if len(df.columns) < 2:
        return None
    df = df.iloc[:, :2].copy()
    df.columns = ["Question", "Answer"]
    df["Tags"] = ""  # CSV files don't have tags
    return df


def _find_collection(extract_root):
    """Return ``(sqlite_path, error_message)`` for an extracted APKG.

    Prefers ``collection.anki21``: Anki 2.1 packages can carry a placeholder
    ``collection.anki2`` next to the real data. The compressed
    ``collection.anki21b`` format cannot be read with sqlite3 directly.
    """
    anki21 = os.path.join(extract_root, "collection.anki21")
    if os.path.exists(anki21):
        return anki21, None
    if os.path.exists(os.path.join(extract_root, "collection.anki21b")):
        return None, ("Unsupported APKG: newer compressed format. "
                      "Re-export with 'Support older Anki versions' enabled.")
    anki2 = os.path.join(extract_root, "collection.anki2")
    if os.path.exists(anki2):
        return anki2, None
    return None, "Invalid APKG: No collection.anki2"


def parse_file(file_obj):
    """Parse an uploaded CSV/APKG into a Question/Answer/Tags DataFrame.

    Returns ``(df, has_media, preview, message, eta, extract_root)``.
    ``has_media`` means at least one card already has a ``[sound:...]`` tag.
    ``extract_root`` is the directory the APKG was unzipped into (kept so its
    media can be re-packaged; the caller owns and must eventually delete it),
    or None for CSV input. On failure the data items are None and *message*
    describes the error.
    """
    if file_obj is None:
        return None, None, None, "No file uploaded", "", None

    ext = Path(file_obj.name).suffix.lower()
    extract_root = None  # Directory where we keep original media
    has_media = False
    audio_count = 0  # Cards with existing audio (APKG only)

    try:
        if ext == ".csv":
            df = _read_csv_cards(file_obj.name)
            if df is None:
                return None, None, None, "CSV error: Need 2 columns", "", None
        elif ext in (".apkg", ".zip"):
            extract_root = tempfile.mkdtemp()
            with zipfile.ZipFile(file_obj.name, "r") as z:
                z.extractall(extract_root)

            col_path, error = _find_collection(extract_root)
            if col_path is None:
                _discard_extract_root(extract_root)
                return None, None, None, error, "", None

            # closing() guarantees the connection is released even if the query fails.
            with closing(sqlite3.connect(col_path)) as conn:
                rows = conn.execute("SELECT flds, tags FROM notes").fetchall()

            data = []
            for flds_raw, tags_raw in rows:
                # Anki stores all fields of a note in one string separated by 0x1f.
                flds = (flds_raw or "").split("\x1f")
                q = flds[0] if flds else ""
                a = flds[1] if len(flds) > 1 else ""
                if SOUND_TAG_RE.search(q) or SOUND_TAG_RE.search(a):
                    audio_count += 1
                data.append([q, a, (tags_raw or "").strip()])
            df = pd.DataFrame(data, columns=["Question", "Answer", "Tags"])
            # has_media means existing AUDIO, not images
            has_media = audio_count > 0
        else:
            return None, None, None, "Unsupported file type", "", None

        df = df.fillna("")
        msg = f"✅ Loaded {len(df)} cards."
        if has_media:
            msg += f" 🎵 {audio_count} cards have existing audio."
        return (df, has_media, df.head(PREVIEW_LIMIT), msg,
                estimate_time(len(df), has_media), extract_root)
    except Exception as e:
        _discard_extract_root(extract_root)
        return None, None, None, f"Error: {str(e)}", "", None


def estimate_time(num_cards, has_existing_media=False, mode=MODE_SMART_FILL):
    """Return a human-readable ETA for generating audio for *num_cards* cards.

    Based on a benchmark of ~4.7 s per card for full generation. When the deck
    already has audio and the mode preserves it (anything but Overwrite), the
    per-card cost is halved as a rough allowance for skipped fields.
    """
    if not num_cards:
        return "0s"
    seconds_per_card = 4.7
    if has_existing_media and mode != MODE_OVERWRITE:
        seconds_per_card *= 0.5
    seconds = num_cards * seconds_per_card
    if seconds < 60:
        return f"~{int(seconds)}s"
    if seconds < 3600:
        return f"~{int(seconds / 60)} min"
    hours, remainder = divmod(int(seconds), 3600)
    mins = remainder // 60
    return f"~{hours}h {mins}m" if mins > 0 else f"~{hours}h"


def _collect_existing_media(extract_root, dest_dir):
    """Copy an extracted APKG's media into *dest_dir* under their real names.

    APKG media are stored as numbered files plus a JSON ``media`` map
    (``{"0": "my_audio.mp3", ...}``). Files are copied rather than renamed so
    *extract_root* stays intact for later generations. Map names come from an
    untrusted upload, so only their basename is used (prevents ``../`` escapes).
    Best effort: problems are logged and whatever was collected is returned.
    """
    media_files = []
    if not extract_root:
        return media_files
    media_map_path = os.path.join(extract_root, "media")
    if not os.path.exists(media_map_path) or os.path.getsize(media_map_path) == 0:
        return media_files
    try:
        with open(media_map_path, "r", encoding="utf-8") as f:
            content = f.read().strip()
        if not content:
            print("Warning: Media map file is empty.")
            return media_files
        media_map = json.loads(content)
        os.makedirs(dest_dir, exist_ok=True)
        for key, name in media_map.items():
            src = os.path.join(extract_root, str(key))
            safe_name = os.path.basename(str(name))
            if not safe_name or not os.path.exists(src):
                continue
            dst = os.path.join(dest_dir, safe_name)
            shutil.copyfile(src, dst)
            # genanki names packaged media by basename, so dst keeps the real name.
            media_files.append(dst)
    except Exception as e:
        print(f"Warning: Could not restore existing media: {e}")
    return media_files


def _apply_generated_audio(field_text, audio_res, run_id, media_files):
    """Return *field_text* with its sound tags replaced by the generated clip.

    *audio_res* is a WAV path from generate_audio_for_row, or ``"SKIP"``/None,
    in which case the field is returned unchanged. The WAV is converted to MP3
    (named with *run_id* so it cannot clash with media already in the deck),
    the WAV is deleted and the MP3 path is appended to *media_files*.
    """
    if not audio_res or audio_res == "SKIP":
        return field_text
    wav_path = Path(audio_res)
    mp3_path = wav_path.with_name(f"pockettts_{run_id}_{wav_path.stem}.mp3")
    wav_to_mp3(str(wav_path), str(mp3_path))
    os.remove(wav_path)  # clean wav
    media_files.append(str(mp3_path))
    # Remove OLD sound tags first to avoid duplicates.
    stripped = SOUND_TAG_RE.sub("", field_text).strip()
    sound_tag = f"[sound:{mp3_path.name}]"
    return f"{stripped}<br>{sound_tag}" if stripped else sound_tag


def process_dataframe(df_full, search_term, extract_root, mode, search_in, selected_tag,
                      progress=gr.Progress()):
    """Generate audio for the filtered cards and package them as an .apkg.

    Returns ``(apkg_path, message)``, or ``(None, message)`` when nothing
    matches or packaging fails. Existing media from *extract_root* is re-packaged;
    *extract_root* itself is left untouched so the deck can be generated again.
    """
    if df_full is None or len(df_full) == 0:
        return None, "No data"

    df = _filter_cards(df_full, search_term, search_in, selected_tag)
    if len(df) == 0:
        return None, "No matching cards"

    work_dir = tempfile.mkdtemp()
    run_id = uuid.uuid4().hex[:8]
    try:
        # --- Media Preservation ---
        media_files = _collect_existing_media(extract_root, os.path.join(work_dir, "existing_media"))

        # --- Genanki Setup ---
        model_id = random.randrange(1 << 30, 1 << 31)
        my_model = genanki.Model(
            model_id,
            "PocketTTS Model",
            fields=[{"name": "Question"}, {"name": "Answer"}],
            templates=[{
                "name": "Card 1",
                "qfmt": "{{Question}}",
                "afmt": '{{FrontSide}}<hr id="answer">{{Answer}}',
            }],
        )
        my_deck = genanki.Deck(random.randrange(1 << 30, 1 << 31), "Pocket TTS Deck")

        # --- Execution ---
        questions = df["Question"].astype(str).tolist()
        answers = df["Answer"].astype(str).tolist()
        total = len(df)
        completed = 0
        last_update_time = 0.0
        fields_by_pos = {}  # source position -> [question_field, answer_field]

        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
            futures = {
                exe.submit(generate_audio_for_row, q, a, idx, work_dir, mode): (pos, idx, q, a)
                for pos, (idx, q, a) in enumerate(zip(df.index, questions, answers))
            }
            for future in as_completed(futures):
                pos, idx, q_text, a_text = futures[future]
                try:
                    q_res, a_res = future.result()
                    fields_by_pos[pos] = [
                        _apply_generated_audio(q_text, q_res, run_id, media_files),
                        _apply_generated_audio(a_text, a_res, run_id, media_files),
                    ]
                except Exception as e:
                    print(f"Row {idx} failed: {e}")

                # --- Throttled Progress ---
                completed += 1
                current_time = time.time()
                if completed == total or (current_time - last_update_time) > PROGRESS_THROTTLE:
                    progress(completed / total, desc=f"Processed {completed}/{total}")
                    last_update_time = current_time

        # Add notes in source order (completion order is nondeterministic).
        for pos in sorted(fields_by_pos):
            my_deck.add_note(genanki.Note(model=my_model, fields=fields_by_pos[pos]))

        # --- Package ---
        package = genanki.Package(my_deck)
        package.media_files = sorted(set(media_files))  # Deduplicate media paths
        fd, final_out = tempfile.mkstemp(prefix="pocket_deck_", suffix=".apkg")
        os.close(fd)
        package.write_to_file(final_out)
        return final_out, f"✅ Done! Packaged {len(package.media_files)} audio files."
    except Exception as e:
        return None, f"Critical Error: {str(e)}"
    finally:
        # --- Guaranteed Cleanup (extraction dir is owned by the UI state) ---
        shutil.rmtree(work_dir, ignore_errors=True)


# --- UI ---
with gr.Blocks(title="Pocket TTS Anki") as app:
    gr.Markdown("## 🎴 Pocket TTS Anki Generator")
    gr.Markdown("Offline Neural Audio. Supports CSV and APKG (smart media preservation).")

    # State variables
    full_df_state = gr.State()
    extract_root_state = gr.State()  # Holds path to unzipped APKG

    with gr.Row():
        file_input = gr.File(label="Upload (CSV/APKG)", file_types=[".csv", ".apkg", ".zip"])
        status = gr.Textbox(label="Status", interactive=False)
        eta_box = gr.Textbox(label="Est. Time", interactive=False)

    with gr.Row():
        search_box = gr.Textbox(label="Search Text", placeholder="Enter text to search...")
        search_field = gr.Radio(
            choices=["Both", "Question Only", "Answer Only"],
            value="Both",
            label="Search In",
        )
        tag_dropdown = gr.Dropdown(
            label="Filter by Tag",
            choices=["All"],
            value="All",
            interactive=True,
        )

    with gr.Row():
        # 3-way generation mode toggle
        mode_radio = gr.Radio(
            choices=[MODE_SMART_FILL, MODE_OVERWRITE, MODE_MISSING_ONLY],
            value=MODE_SMART_FILL,
            label="Generation Mode",
        )

    preview_table = gr.Dataframe(
        label="Preview (First 100)",
        interactive=False,
        column_widths=["30%", "45%", "25%"],
    )

    with gr.Row():
        btn = gr.Button("🚀 Generate Deck", variant="primary")
        dl = gr.File(label="Download")
        result_lbl = gr.Textbox(label="Result", interactive=False)

    has_media_state = gr.State(False)

    def on_upload(file, previous_root):
        """Parse a new upload, replacing (and deleting) any previous extraction dir."""
        _discard_extract_root(previous_root)
        df, has_media, preview, msg, eta, ext_path = parse_file(file)
        tag_choices = extract_unique_tags(df)
        clean_preview = _preview_frame(df) if df is not None else preview
        return (
            df,  # full_df_state
            has_media,  # has_media_state
            clean_preview,  # preview_table
            msg,  # status
            eta,  # eta_box
            ext_path,  # extract_root_state
            gr.Dropdown(choices=tag_choices, value="All"),  # tag_dropdown
        )

    file_input.upload(
        on_upload,
        inputs=[file_input, extract_root_state],
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box,
                 extract_root_state, tag_dropdown],
    )

    def on_clear(previous_root):
        """Reset all fields and delete the extraction dir when the file is cleared."""
        _discard_extract_root(previous_root)
        return (
            None,  # full_df_state
            False,  # has_media_state
            None,  # preview_table
            "",  # status
            "",  # eta_box
            None,  # extract_root_state
            gr.Dropdown(choices=["All"], value="All"),  # tag_dropdown
            "",  # search_box
            None,  # dl (download file)
            "",  # result_lbl
        )

    file_input.clear(
        on_clear,
        inputs=[extract_root_state],
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box,
                 extract_root_state, tag_dropdown, search_box, dl, result_lbl],
    )

    def on_search(term, df, has_media, mode, search_in, selected_tag):
        """Refresh the preview and ETA for the current filters."""
        if df is None:
            return None, "No data"
        filtered_df = _filter_cards(df, term, search_in, selected_tag)
        return _preview_frame(filtered_df), estimate_time(len(filtered_df), has_media, mode)

    search_inputs = [search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown]
    for _component in (search_box, search_field, tag_dropdown, mode_radio):
        _component.change(on_search, inputs=search_inputs, outputs=[preview_table, eta_box])

    btn.click(
        process_dataframe,
        inputs=[full_df_state, search_box, extract_root_state, mode_radio, search_field, tag_dropdown],
        outputs=[dl, result_lbl],
    )

if __name__ == "__main__":
    app.queue(max_size=2).launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )