# DatasetChecker / analysis / standard.py
# (Hugging Face Hub viewer header removed — uploader: archivartaunik,
#  commit a7d909c "Update analysis/standard.py"; the raw lines were not valid Python.)
"""Standard analysis with batch mode support."""
import os
import time
import tempfile
import re
import gradio as gr
import soundfile as sf
from google import genai
import utils
from core.state import get_global_results, set_global_results
from core.cache import get_cached_dataset, cache_dataset
from core.comparison import select_best_model_result
from ui.dashboard import generate_dashboard_outputs
from gemini_api import GeminiIntegrator, BatchTask, DEFAULT_TRANSCRIPTION_PROMPT
from hf_asr import is_hf_asr_model, get_hf_asr_client, HF_BATCH_SIZE
def sanitize_filename(name):
    """Return a filesystem-safe, lowercase identifier derived from *name*.

    Non-word characters become underscores, runs of whitespace/hyphens
    collapse to a single underscore, and empty/None input falls back to
    the literal "results".
    """
    if not name:
        return "results"
    # First pass: neutralize anything that is not word/space/hyphen.
    cleaned = re.sub(r'[^\w\s-]', '_', name)
    cleaned = cleaned.strip().lower()
    # Second pass: unify separators into single underscores.
    return re.sub(r'[-\s]+', '_', cleaned)
def run_analysis(
    api_key: str,
    dataset_name: str,
    model_name: str,
    limit_files: int,
    temperature: float,
    thinking_budget: int,
    similarity_threshold: int,
    batch_mode: bool = False,
    recheck_problematic: bool = False,
    progress=gr.Progress()
):
    """Entry point: dispatch a dataset analysis run to the right backend.

    Routes to HF ASR, Gemini batch, Gemini recheck, or fresh Gemini sync
    analysis depending on the model name and flags.

    Args:
        api_key: Gemini API key (required for Gemini models, unused for HF ASR).
        dataset_name: HF dataset identifier passed to utils.load_hf_dataset.
        model_name: model id; HF ASR models are detected via is_hf_asr_model,
            and "thinking" in the name enables a thinking budget.
        limit_files: max files to process (0 / falsy means "all").
        temperature: sampling temperature for Gemini generation.
        thinking_budget: token budget for thinking models (0 disables).
        similarity_threshold: score cutoff for correct/incorrect status.
        batch_mode: use the Gemini batch API instead of per-file sync calls.
        recheck_problematic: only re-transcribe below-threshold records.
        progress: Gradio progress callback.

    Returns:
        Dashboard outputs tuple from generate_dashboard_outputs().

    Raises:
        gr.Error: on missing API key or any backend failure.
    """
    # Gradio number widgets may deliver strings or floats; normalize defensively.
    limit_files = int(float(limit_files)) if limit_files else 0
    thinking_budget = int(float(thinking_budget)) if thinking_budget else 0
    similarity_threshold = int(float(similarity_threshold)) if similarity_threshold else 90
    temperature = float(temperature)
    # ---------------------------------------------------------
    # HUGGING FACE ASR MODE
    # ---------------------------------------------------------
    if is_hf_asr_model(model_name):
        return _run_hf_asr_analysis(
            model_name, dataset_name, limit_files,
            similarity_threshold, recheck_problematic, progress
        )
    # For Gemini models, API key is required
    if not api_key:
        raise gr.Error("Калі ласка, увядзіце Gemini API ключ.")
    try:
        gemini_tool = GeminiIntegrator(api_key=api_key)
        config_args = {"temperature": temperature}
        use_thinking = "thinking" in model_name
        if use_thinking and thinking_budget > 0:
            # FIX: the google-genai SDK's ThinkingConfig field is named
            # 'thinking_budget', not 'budget_tokens' — the old key made
            # GenerateContentConfig(**config_args) fail validation.
            config_args["thinking_config"] = {
                "include_thoughts": True,
                "thinking_budget": thinking_budget
            }
        gen_config = genai.types.GenerateContentConfig(**config_args)
        # ---------------------------------------------------------
        # BATCH MODE
        # ---------------------------------------------------------
        if batch_mode:
            return _run_batch_analysis(
                gemini_tool, model_name, dataset_name, limit_files,
                similarity_threshold, recheck_problematic, progress
            )
        # ---------------------------------------------------------
        # STANDARD SYNC MODE
        # ---------------------------------------------------------
        if recheck_problematic:
            return _run_recheck_analysis(
                gemini_tool, model_name, dataset_name, limit_files,
                similarity_threshold, gen_config, progress
            )
        return _run_fresh_analysis(
            gemini_tool, model_name, dataset_name, limit_files,
            similarity_threshold, gen_config, progress
        )
    except gr.Error:
        # Already user-facing — don't re-wrap into "Памылка: ...".
        raise
    except Exception as e:
        raise gr.Error(f"Памылка: {e}")
def _run_batch_analysis(
    gemini_tool, model_name, dataset_name, limit_files,
    similarity_threshold, recheck_problematic, progress
):
    """Run batch processing mode.

    Steps: (1) load/cache the dataset, (2) build one BatchTask per target
    audio file (dumping in-memory arrays to temp WAVs when no on-disk path
    exists), (3) submit all tasks through gemini_tool.run_batch, (4) map the
    returned transcriptions back onto the global results list and rescore.

    Returns the dashboard outputs tuple from generate_dashboard_outputs().
    """
    global_results = get_global_results()
    ds = None
    # 1. Prepare Data
    if recheck_problematic:
        if not global_results:
            gr.Warning("Няма вынікаў для пераправеркі.")
            return generate_dashboard_outputs(similarity_threshold)
        # Recheck loads the FULL dataset (limit=None) so audio can be
        # recovered for any previously-analyzed record.
        limit = None
        cached_ds = get_cached_dataset(dataset_name, limit)
        if cached_ds:
            ds = cached_ds
        else:
            progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
            ds = utils.load_hf_dataset(dataset_name, limit=limit)
            cache_dataset(dataset_name, limit, ds)
    else:
        # Fresh run: honor the user-provided file limit (0 means "all").
        limit = int(limit_files) if limit_files > 0 else None
        cached_ds = get_cached_dataset(dataset_name, limit)
        if cached_ds:
            ds = cached_ds
        else:
            progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
            ds = utils.load_hf_dataset(dataset_name, limit=limit)
            cache_dataset(dataset_name, limit, ds)
        # Init results if fresh run: placeholder entries with "pending" status,
        # index-aligned with the dataset.
        progress(0.1, desc="Ініцыялізацыя спісу...")
        new_results = []
        for idx, item in enumerate(ds):
            # The reference transcript column name varies across datasets.
            ref_text = item.get('sentence') or item.get('text') or item.get('transcription') or item.get('transcript') or ""
            new_results.append({
                "id": idx,
                "path": item['audio']['path'],
                "ref_text": ref_text,
                "hyp_text": "",
                "score": 0,
                "audio_array": item['audio']['array'],
                "sampling_rate": item['audio']['sampling_rate'],
                "model_used": model_name,
                "verification_status": "pending"
            })
        set_global_results(new_results)
        global_results = get_global_results()
    # 2. Prepare Tasks
    tasks = []
    # Temp dir holds WAV dumps for items whose audio exists only in memory;
    # it is cleaned up (best-effort) on every exit path below.
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    def prepare_task(idx, row_data, audio_ref):
        # Build a BatchTask for one record; returns None when audio is empty.
        key = f"task_{idx}"
        fpath = audio_ref['audio']['path']
        # Verify file existence or dump numpy to WAV
        if not fpath or not os.path.exists(fpath):
            audio_arr = audio_ref['audio']['array']
            sr = audio_ref['audio']['sampling_rate']
            if len(audio_arr) == 0:
                return None
            clean_name = sanitize_filename(f"audio_{idx}")
            dump_path = os.path.join(tmp_dir, f"{clean_name}.wav")
            sf.write(dump_path, audio_arr, int(sr), format='WAV')
            fpath = dump_path
        return BatchTask(key=key, path=fpath, mime_type="audio/wav")

    progress(0.2, desc="Падрыхтоўка задач для пакетнага рэжыму...")
    task_map_idx = {}  # task_key -> result_index
    if recheck_problematic:
        # Problematic = below threshold and not manually confirmed correct.
        target_indices = [
            i for i, r in enumerate(global_results)
            if r['score'] < similarity_threshold
            and r.get('verification_status') != 'correct'
        ]
        if limit_files > 0:
            target_indices = target_indices[:limit_files]
        if not target_indices:
            gr.Info("Няма праблемных файлаў для пераправеркі.")
            try: tmp_dir_obj.cleanup()
            except: pass
            return generate_dashboard_outputs(similarity_threshold)
        # Index dataset items by full path, basename, and position so stored
        # results can be matched back to their source audio.
        ds_map = {}
        for di, d_item in enumerate(ds):
            p = d_item['audio']['path']
            if p:
                ds_map[p] = d_item
                ds_map[os.path.basename(p)] = d_item
            ds_map[di] = d_item
        for global_res_idx in target_indices:
            res = global_results[global_res_idx]
            path = res.get('path', '')
            # Lookup order: full path -> basename -> stored id.
            item = ds_map.get(path) or ds_map.get(os.path.basename(path))
            if not item and res.get('id') is not None:
                try: item = ds[int(res.get('id'))]
                except: pass
            if item:
                t = prepare_task(global_res_idx, res, item)
                if t:
                    tasks.append(t)
                    task_map_idx[t.key] = global_res_idx
    else:
        # Fresh run: one task per dataset item, index-aligned with results.
        for idx, res in enumerate(global_results):
            item = ds[idx]
            t = prepare_task(idx, res, item)
            if t:
                tasks.append(t)
                task_map_idx[t.key] = idx
    if not tasks:
        gr.Warning("Не знойдзена задач для выканання (магчыма, адсутнічае аўдыя).")
        try: tmp_dir_obj.cleanup()
        except: pass
        return generate_dashboard_outputs(similarity_threshold)
    # 3. Execute Batch
    progress(0.3, desc=f"Запуск пакетнай апрацоўкі ({len(tasks)} файлаў). Гэта зойме час...")
    prompt = DEFAULT_TRANSCRIPTION_PROMPT
    try:
        batch_results = gemini_tool.run_batch(tasks, model_name, prompt)
    except Exception as e:
        try: tmp_dir_obj.cleanup()
        except: pass
        raise gr.Error(f"Batch failed: {e}")
    progress(0.9, desc="Апрацоўка вынікаў...")
    # 4. Map Results: rescore each returned transcription and update the
    # shared results list in place.
    for key, text in batch_results.items():
        if key in task_map_idx:
            idx = task_map_idx[key]
            if idx < len(global_results):
                ref_text = global_results[idx]['ref_text']
                score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, text)
                global_results[idx].update({
                    "hyp_text": text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp,
                    "verification_status": "correct" if score >= similarity_threshold else "incorrect",
                    # "batch_" prefix marks this hypothesis as batch-mode output.
                    "model_used": f"batch_{model_name}"
                })
                # Also record under the per-model history used by comparisons.
                if 'model_results' not in global_results[idx]:
                    global_results[idx]['model_results'] = {}
                global_results[idx]['model_results'][model_name] = {
                    "hyp_text": text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp
                }
    try: tmp_dir_obj.cleanup()
    except: pass
    return generate_dashboard_outputs(similarity_threshold)
def _run_recheck_analysis(
    gemini_tool, model_name, dataset_name, limit_files,
    similarity_threshold, gen_config, progress
):
    """Run recheck of problematic files.

    Re-transcribes (synchronously, one file at a time) every stored result
    that scored below the threshold and was not manually marked correct,
    records the new hypothesis under model_results, and keeps the best
    result across all models tried so far.

    Returns the dashboard outputs tuple from generate_dashboard_outputs().
    """
    global_results = get_global_results()
    if not global_results:
        gr.Warning("Няма вынікаў для пераправеркі.")
        return generate_dashboard_outputs(similarity_threshold)
    # Identify problematic records: below threshold and not confirmed correct.
    target_indices = [
        i for i, r in enumerate(global_results)
        if r['score'] < similarity_threshold
        and r.get('verification_status') != 'correct'
    ]
    if limit_files > 0:
        target_indices = target_indices[:limit_files]
    if not target_indices:
        gr.Info("Няма праблемных файлаў для пераправеркі.")
        return generate_dashboard_outputs(similarity_threshold)
    # Load the FULL dataset (limit=None) so audio can be recovered for
    # records whose in-memory audio was dropped.
    limit = None
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        ds = cached_ds
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.05, desc=f"Датасет закэшаваны")
    # Build audio map keyed by both basename and full path for flexible lookup.
    audio_map = {}
    for item in ds:
        path = item['audio']['path']
        if path:
            fname = os.path.basename(path)
            audio_map[fname] = item
            audio_map[path] = item
    progress(0.1, desc=f"Пераправерка {len(target_indices)} файлаў...")
    for j, idx in enumerate(target_indices):
        progress(0.1 + (j + 1) / len(target_indices) * 0.9, desc=f"Праверка {j+1}/{len(target_indices)}")
        result = global_results[idx]
        audio_data = result.get('audio_array')
        sampling_rate = result.get('sampling_rate')
        ref_text = result.get('ref_text', "")
        # If audio is missing, try to fetch it from the dataset by path,
        # then basename, then stored id.
        if audio_data is None or len(audio_data) == 0:
            path = result.get('path', '')
            item = audio_map.get(path) or audio_map.get(os.path.basename(path))
            # Fallback: try to find by ID if path lookup failed
            if not item:
                rec_id = result.get('id')
                if rec_id is not None:
                    try:
                        rec_id = int(rec_id)
                        if 0 <= rec_id < len(ds):
                            item = ds[rec_id]
                    except:
                        pass
            if item:
                audio_data = item['audio']['array']
                sampling_rate = item['audio']['sampling_rate']
                # Cache the recovered audio back onto the stored result.
                global_results[idx]['audio_array'] = audio_data
                global_results[idx]['sampling_rate'] = sampling_rate
            else:
                print(f"Problematic Recheck: Skipping index {idx}, path '{path}', id {result.get('id')}: Audio not found in dataset.")
                continue
        hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)
        print(f"🔄 Updated: {result.get('path')} | Score: {result.get('score')} -> {score} | Text: {hyp_text}")
        # Save this model's result into the per-model history.
        if 'model_results' not in global_results[idx]:
            global_results[idx]['model_results'] = {}
        global_results[idx]['model_results'][model_name] = {
            "hyp_text": hyp_text,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }
        # Keep whichever model's attempt scores best overall.
        best_model, best_result = select_best_model_result(
            global_results[idx]['model_results'],
            similarity_threshold
        )
        if best_result:
            global_results[idx].update({
                "hyp_text": best_result['hyp_text'],
                "score": best_result['score'],
                "norm_ref": best_result['norm_ref'],
                "norm_hyp": best_result['norm_hyp'],
                "model_used": best_model,
                "verification_status": "correct" if best_result['score'] >= similarity_threshold else "incorrect"
            })
    return generate_dashboard_outputs(similarity_threshold)
def _run_fresh_analysis(
    gemini_tool, model_name, dataset_name, limit_files,
    similarity_threshold, gen_config, progress
):
    """Transcribe every dataset item synchronously and store fresh results.

    Loads (or reuses a cached copy of) the dataset, calls the Gemini
    transcriber per file, scores each hypothesis against the reference
    text, and replaces the global results list with the new entries.
    """
    limit = int(limit_files) if limit_files > 0 else None
    ds = get_cached_dataset(dataset_name, limit)
    if ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.1, desc=f"Датасет закэшаваны для паўторнага выкарыстання")
    collected = []
    total = len(ds)
    for i, sample in enumerate(ds):
        progress((i + 1) / total, desc=f"Апрацоўка файла {i+1}/{total}")
        wave = sample['audio']['array']
        rate = sample['audio']['sampling_rate']
        # The reference transcript column name varies across dataset schemas.
        reference = (
            sample.get('sentence')
            or sample.get('text')
            or sample.get('transcription')
            or sample.get('transcript')
            or ""
        )
        hypothesis = gemini_tool.transcribe_audio(model_name, wave, rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(reference, hypothesis)
        per_model = {
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }
        collected.append({
            "id": i,
            "path": sample['audio']['path'],
            "ref_text": reference,
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp,
            "audio_array": wave,
            "sampling_rate": rate,
            "model_used": model_name,
            "verification_status": "correct" if score >= similarity_threshold else "incorrect",
            "model_results": {model_name: per_model}
        })
    set_global_results(collected)
    return generate_dashboard_outputs(similarity_threshold)
def _run_hf_asr_analysis(
    model_name: str,
    dataset_name: str,
    limit_files: int,
    similarity_threshold: int,
    recheck_problematic: bool,
    progress
):
    """Dispatch analysis through a Hugging Face ASR Space client.

    Connects to the HF Space for *model_name*, then delegates to the
    recheck or fresh HF analysis path. Raises gr.Error if the connection
    cannot be established.
    """
    get_global_results()
    try:
        hf_client = get_hf_asr_client(model_name)
        progress(0.05, desc=f"Падключэнне да HF Space: {model_name}...")
    except Exception as e:
        raise gr.Error(f"Памылка падключэння да HF: {e}")
    runner = _run_hf_recheck_analysis if recheck_problematic else _run_hf_fresh_analysis
    return runner(
        hf_client, model_name, dataset_name, limit_files,
        similarity_threshold, progress
    )
def _run_hf_fresh_analysis(
    hf_client, model_name, dataset_name, limit_files,
    similarity_threshold, progress
):
    """Run fresh analysis using HF ASR with batch processing.

    Loads (or reuses a cached copy of) the dataset, transcribes it in
    batches of HF_BATCH_SIZE via hf_client.transcribe_batch, scores each
    successful transcription, and stores the results globally.

    Returns the dashboard outputs tuple from generate_dashboard_outputs().
    """
    limit = int(limit_files) if limit_files > 0 else None
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        ds = cached_ds
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.1, desc=f"Датасет закэшаваны для паўторнага выкарыстання")
    # Pre-collect all items with their data so batching is a simple slice.
    all_items = []
    for idx, item in enumerate(ds):
        audio_data = item['audio']['array']
        sampling_rate = item['audio']['sampling_rate']
        # The reference transcript column name varies across dataset schemas.
        ref_text = item.get('sentence') or item.get('text') or item.get('transcription') or item.get('transcript') or ""
        all_items.append({
            "idx": idx,
            "path": item['audio']['path'],
            "audio_data": audio_data,
            "sampling_rate": sampling_rate,
            "ref_text": ref_text
        })
    total_items = len(all_items)
    results = [None] * total_items  # Pre-allocate so out-of-order fills keep dataset order
    # Process in batches of HF_BATCH_SIZE
    batch_size = HF_BATCH_SIZE
    num_batches = (total_items + batch_size - 1) // batch_size
    for batch_num in range(num_batches):
        # Delay between batches to avoid rate limiting (skip for first batch)
        if batch_num > 0:
            print(f"⏳ Чакаем 5с перад наступным пакетам...")
            time.sleep(5)
        start_idx = batch_num * batch_size
        end_idx = min(start_idx + batch_size, total_items)
        batch_items = all_items[start_idx:end_idx]
        progress_val = 0.1 + (batch_num / num_batches) * 0.9
        progress(progress_val, desc=f"Пакет {batch_num + 1}/{num_batches}: апрацоўка {len(batch_items)} файлаў (HF ASR)...")
        # Prepare batch for transcription: (key, audio_array, sampling_rate)
        batch_audio = [
            (item["idx"], item["audio_data"], item["sampling_rate"])
            for item in batch_items
        ]
        # Send batch to HF ASR (retry logic is inside transcribe_batch)
        transcriptions = hf_client.transcribe_batch(batch_audio)
        # Process results - only save successful transcriptions
        transcribed_count = 0
        for item in batch_items:
            idx = item["idx"]
            ref_text = item["ref_text"]
            hyp_text = transcriptions.get(idx, "")
            # Only record result if transcription was successful
            if hyp_text:
                score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)
                transcribed_count += 1
                results[idx] = {
                    "id": idx,
                    "path": item["path"],
                    "ref_text": ref_text,
                    "hyp_text": hyp_text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp,
                    "audio_array": item["audio_data"],
                    "sampling_rate": item["sampling_rate"],
                    "model_used": model_name,
                    "verification_status": "correct" if score >= similarity_threshold else "incorrect",
                    "model_results": {
                        model_name: {
                            "hyp_text": hyp_text,
                            "score": score,
                            "norm_ref": norm_ref,
                            "norm_hyp": norm_hyp
                        }
                    }
                }
        # Items with no transcription remain None and are dropped below.
        print(f"✅ Пакет {batch_num + 1}/{num_batches} завершаны: {transcribed_count}/{len(batch_items)} транскрыбавана")
    # FIX: drop None placeholders left by failed transcriptions. Every
    # consumer of the global results (dashboard, recheck paths) indexes
    # r['score'] on each entry and would crash on a None element.
    results = [r for r in results if r is not None]
    set_global_results(results)
    return generate_dashboard_outputs(similarity_threshold)
def _run_hf_recheck_analysis(
    hf_client, model_name, dataset_name, limit_files,
    similarity_threshold, progress
):
    """Run recheck of problematic files using HF ASR with batch processing.

    Re-transcribes below-threshold records in batches of HF_BATCH_SIZE,
    recovering missing audio from the dataset where needed, records each
    new hypothesis under model_results, and keeps the best result across
    all models tried so far.

    Returns the dashboard outputs tuple from generate_dashboard_outputs().
    """
    global_results = get_global_results()
    if not global_results:
        gr.Warning("Няма вынікаў для пераправеркі.")
        return generate_dashboard_outputs(similarity_threshold)
    # Identify problematic records: below threshold and not confirmed correct.
    target_indices = [
        i for i, r in enumerate(global_results)
        if r['score'] < similarity_threshold
        and r.get('verification_status') != 'correct'
    ]
    if limit_files > 0:
        target_indices = target_indices[:limit_files]
    if not target_indices:
        gr.Info("Няма праблемных файлаў для пераправеркі.")
        return generate_dashboard_outputs(similarity_threshold)
    # Load the FULL dataset (limit=None) so audio can be recovered for
    # records whose in-memory audio was dropped.
    limit = None
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        ds = cached_ds
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.05, desc=f"Датасет закэшаваны")
    # Build audio map keyed by both basename and full path for flexible lookup.
    audio_map = {}
    for item in ds:
        path = item['audio']['path']
        if path:
            fname = os.path.basename(path)
            audio_map[fname] = item
            audio_map[path] = item
    # Collect all items to process, resolving missing audio up front so the
    # batching loop below only sees ready-to-send entries.
    items_to_process = []
    for idx in target_indices:
        result = global_results[idx]
        audio_data = result.get('audio_array')
        sampling_rate = result.get('sampling_rate')
        ref_text = result.get('ref_text', "")
        # If audio is missing (None or empty sequence), try the dataset by
        # path, then basename, then stored id.
        if audio_data is None or (hasattr(audio_data, '__len__') and len(audio_data) == 0):
            path = result.get('path', '')
            item = audio_map.get(path) or audio_map.get(os.path.basename(path))
            # Fallback: try to find by ID if path lookup failed
            if not item:
                rec_id = result.get('id')
                if rec_id is not None:
                    try:
                        rec_id = int(rec_id)
                        if 0 <= rec_id < len(ds):
                            item = ds[rec_id]
                    except:
                        pass
            if item:
                audio_data = item['audio']['array']
                sampling_rate = item['audio']['sampling_rate']
                # Cache the recovered audio back onto the stored result.
                global_results[idx]['audio_array'] = audio_data
                global_results[idx]['sampling_rate'] = sampling_rate
            else:
                print(f"HF Recheck: Skipping index {idx}, path '{path}': Audio not found.")
                continue
        items_to_process.append({
            "idx": idx,
            "audio_data": audio_data,
            "sampling_rate": sampling_rate,
            "ref_text": ref_text
        })
    if not items_to_process:
        gr.Info("Няма файлаў з аўдыя для пераправеркі.")
        return generate_dashboard_outputs(similarity_threshold)
    # Process in batches of HF_BATCH_SIZE
    batch_size = HF_BATCH_SIZE
    total_items = len(items_to_process)
    num_batches = (total_items + batch_size - 1) // batch_size
    progress(0.1, desc=f"Пераправерка {total_items} файлаў у {num_batches} пакетах (HF ASR)...")
    for batch_num in range(num_batches):
        # Delay between batches to avoid rate limiting (skip for first batch)
        if batch_num > 0:
            print(f"⏳ Чакаем 5с перад наступным пакетам...")
            time.sleep(5)
        start_idx = batch_num * batch_size
        end_idx = min(start_idx + batch_size, total_items)
        batch_items = items_to_process[start_idx:end_idx]
        progress_val = 0.1 + (batch_num / num_batches) * 0.9
        progress(progress_val, desc=f"Пакет {batch_num + 1}/{num_batches}: апрацоўка {len(batch_items)} файлаў...")
        # Prepare batch for transcription: (key, audio_array, sampling_rate)
        batch_audio = [
            (item["idx"], item["audio_data"], item["sampling_rate"])
            for item in batch_items
        ]
        # Send batch to HF ASR (retry logic is inside transcribe_batch)
        transcriptions = hf_client.transcribe_batch(batch_audio)
        # Process results - only save successful transcriptions
        transcribed_count = 0
        for item in batch_items:
            idx = item["idx"]
            ref_text = item["ref_text"]
            hyp_text = transcriptions.get(idx, "")
            if not hyp_text:
                continue
            transcribed_count += 1
            score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)
            print(f"🔄 HF Updated: {global_results[idx].get('path')} | Score: {global_results[idx].get('score')} -> {score}")
            # Save this model's result into the per-model history.
            if 'model_results' not in global_results[idx]:
                global_results[idx]['model_results'] = {}
            global_results[idx]['model_results'][model_name] = {
                "hyp_text": hyp_text,
                "score": score,
                "norm_ref": norm_ref,
                "norm_hyp": norm_hyp
            }
            # Keep whichever model's attempt scores best overall.
            best_model, best_result = select_best_model_result(
                global_results[idx]['model_results'],
                similarity_threshold
            )
            if best_result:
                global_results[idx].update({
                    "hyp_text": best_result['hyp_text'],
                    "score": best_result['score'],
                    "norm_ref": best_result['norm_ref'],
                    "norm_hyp": best_result['norm_hyp'],
                    "model_used": best_model,
                    "verification_status": "correct" if best_result['score'] >= similarity_threshold else "incorrect"
                })
        print(f"✅ Пакет {batch_num + 1}/{num_batches} завершаны: {transcribed_count}/{len(batch_items)} транскрыбавана")
    return generate_dashboard_outputs(similarity_threshold)