# archivartaunik's picture
# Upload 35 files
# e82eaee verified
"""Smart analysis - multi-model iterative processing."""
import os
import gradio as gr
from google import genai
import utils
from core.state import get_global_results, set_global_results
from core.cache import get_cached_dataset, cache_dataset
from core.comparison import select_best_model_result
from ui.dashboard import generate_dashboard_outputs
from gemini_api import GeminiIntegrator
def run_smart_analysis(
    api_key: str,
    dataset_name: str,
    limit_files: int,
    temperature: float,
    thinking_budget: int,
    similarity_threshold: int,
    recheck_problematic: bool = False,
    progress=gr.Progress()
):
    """Run iterative multi-model ("smart") transcription analysis.

    A first pass transcribes records with the cheapest model; later stages
    re-run progressively stronger models only on records whose similarity
    score is still below ``similarity_threshold`` and that are not already
    verified correct. The best per-record result across all models is kept.

    Args:
        api_key: Gemini API key (required).
        dataset_name: HuggingFace dataset identifier to analyze.
        limit_files: Maximum number of files to process; 0 means no limit.
        temperature: Generation temperature for the models.
        thinking_budget: Normalized to int but currently not forwarded to the
            generation config — TODO confirm whether this is intentional.
        similarity_threshold: Minimum score (percent) to mark a record correct.
        recheck_problematic: When True, start from stored global results and
            re-check only problematic records instead of a fresh run.
        progress: Gradio progress reporter.

    Returns:
        The dashboard outputs produced by ``generate_dashboard_outputs``.

    Raises:
        gr.Error: If the API key is missing or any processing step fails.
    """
    # Gradio widgets may deliver numbers as floats or numeric strings; normalize.
    limit_files = int(float(limit_files)) if limit_files else 0
    thinking_budget = int(float(thinking_budget)) if thinking_budget else 0
    similarity_threshold = int(float(similarity_threshold)) if similarity_threshold else 90
    temperature = float(temperature)

    if not api_key:
        raise gr.Error("Калі ласка, увядзіце Gemini API ключ.")

    # (model name, progress label) per stage; the cheapest model gets two
    # passes before escalating to stronger models.
    models = [
        ("gemini-2.5-flash-lite", "Этап 1/4: Flash-Lite (першы праход)"),
        ("gemini-2.5-flash-lite", "Этап 2/4: Flash-Lite (другі праход)"),
        ("gemini-2.5-flash", "Этап 3/4: Flash"),
        ("gemini-3-flash-preview", "Этап 4/4: Gemini-3-Flash"),
    ]

    try:
        gemini_tool = GeminiIntegrator(api_key=api_key)
        gen_config = genai.types.GenerateContentConfig(temperature=temperature)

        # STEP 1: first pass — either re-check stored results or process fresh.
        model_name, step_desc = models[0]
        if recheck_problematic:
            results = _smart_recheck_first_pass(
                gemini_tool, model_name, step_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )
        else:
            results = _smart_fresh_first_pass(
                gemini_tool, model_name, step_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )

        # STEPS 2-4: iterative improvement on still-problematic records only.
        base_progress = 0.25
        step_progress_size = 0.25
        for step_idx in range(1, len(models)):
            model_name, step_desc = models[step_idx]

            # Records still below threshold AND not verified correct.
            problematic_indices = [
                i for i, r in enumerate(results)
                if r['score'] < similarity_threshold
                and r.get('verification_status') != 'correct'
            ]
            if not problematic_indices:
                progress(base_progress + step_idx * step_progress_size,
                         desc=f"{step_desc}: няма праблемных запісаў, прапускаем...")
                continue

            progress(base_progress + (step_idx - 1) * step_progress_size,
                     desc=f"{step_desc}: пераапрацоўка {len(problematic_indices)} праблемных запісаў...")

            for j, res_idx in enumerate(problematic_indices):
                progress(base_progress + (step_idx - 1) * step_progress_size
                         + (j + 1) / len(problematic_indices) * step_progress_size,
                         desc=f"{step_desc}: запіс {j+1}/{len(problematic_indices)}")
                result = results[res_idx]
                audio_data = result.get('audio_array')
                sampling_rate = result.get('sampling_rate')
                ref_text = result.get('ref_text', "")
                # Skip records whose audio was never loaded or restored.
                if audio_data is None or len(audio_data) == 0:
                    continue

                hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
                score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)

                # Record this model's attempt alongside earlier attempts.
                results[res_idx].setdefault('model_results', {})[model_name] = {
                    "hyp_text": hyp_text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp
                }

                # Keep the best result seen so far across all models tried.
                best_model, best_result = select_best_model_result(
                    results[res_idx]['model_results'],
                    similarity_threshold
                )
                if best_result and (best_result['score'] > result['score'] or best_result['score'] >= similarity_threshold):
                    new_status = "correct" if best_result['score'] >= similarity_threshold else "incorrect"
                    print(f"✅ UPDATE APPLIED [Idx={res_idx}]: {result.get('path')} | Best model: {best_model} | Score: {result['score']} -> {best_result['score']}")
                    results[res_idx].update({
                        "hyp_text": best_result['hyp_text'],
                        "score": best_result['score'],
                        "norm_ref": best_result['norm_ref'],
                        "norm_hyp": best_result['norm_hyp'],
                        "model_used": best_model,
                        "verification_status": new_status
                    })
                else:
                    print(f"⏭️ SKIP UPDATE [Idx={res_idx}]: Best score {best_result['score'] if best_result else 'N/A'} is not better than {result.get('score')} and not meeting threshold {similarity_threshold}")

        set_global_results(results)
        return generate_dashboard_outputs(similarity_threshold)
    except gr.Error:
        # Surface UI errors unchanged instead of double-wrapping them.
        raise
    except Exception as e:
        raise gr.Error(f"Памылка: {e}") from e
def _smart_recheck_first_pass(
    gemini_tool, model_name, step_desc, dataset_name,
    limit_files, similarity_threshold, gen_config, progress
):
    """First pass for recheck mode: re-transcribe stored problematic records.

    Starts from the global results, re-checks only records below
    ``similarity_threshold`` that are not verified correct, restoring missing
    audio from the dataset when needed.

    Args:
        gemini_tool: Transcription client exposing ``transcribe_audio``.
        model_name: Model identifier to transcribe with.
        step_desc: Human-readable label for progress messages.
        dataset_name: HuggingFace dataset used to restore missing audio.
        limit_files: Cap on how many problematic records to re-check (0 = all).
        similarity_threshold: Minimum score (percent) considered correct.
        gen_config: Generation config passed through to the client.
        progress: Gradio progress reporter.

    Returns:
        The (mutated) global results list, or [] when there is nothing stored.
    """
    global_results = get_global_results()
    if not global_results:
        gr.Warning("Няма вынікаў для пераправеркі.")
        return []
    results = global_results

    # Only records below threshold and not verified correct are re-checked.
    problematic_indices = [
        i for i, r in enumerate(results)
        if r['score'] < similarity_threshold
        and r.get('verification_status') != 'correct'
    ]
    if limit_files > 0:
        problematic_indices = problematic_indices[:limit_files]
    if not problematic_indices:
        gr.Info("Няма праблемных файлаў для пераправеркі.")
        return results

    # Load the full dataset (no limit) so audio missing from stored results
    # can be restored by path or by record id.
    limit = None
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        ds = cached_ds
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.03, desc=f"Датасет закэшаваны")

    # Index dataset items by both full path and basename for flexible lookup.
    audio_map = {}
    for item in ds:
        path = item['audio']['path']
        if path:
            fname = os.path.basename(path)
            audio_map[fname] = item
            audio_map[path] = item

    progress(0.05, desc=f"{step_desc}: пераправерка {len(problematic_indices)} запісаў...")
    for j, res_idx in enumerate(problematic_indices):
        progress(0.05 + (j + 1) / len(problematic_indices) * 0.20, desc=f"{step_desc}: запіс {j+1}/{len(problematic_indices)}")
        result = results[res_idx]
        audio_data = result.get('audio_array')
        sampling_rate = result.get('sampling_rate')
        ref_text = result.get('ref_text', "")

        # If audio is missing, try to restore it from the dataset: first by
        # path/basename, then by numeric record id.
        if audio_data is None or len(audio_data) == 0:
            path = result.get('path', '')
            item = audio_map.get(path) or audio_map.get(os.path.basename(path))
            if not item:
                rec_id = result.get('id')
                if rec_id is not None:
                    try:
                        rec_id = int(rec_id)
                        if 0 <= rec_id < len(ds):
                            item = ds[rec_id]
                    except (TypeError, ValueError):
                        # Non-numeric id; fall through to the skip below.
                        pass
            if item:
                audio_data = item['audio']['array']
                sampling_rate = item['audio']['sampling_rate']
                # Cache restored audio back onto the result for later stages.
                results[res_idx]['audio_array'] = audio_data
                results[res_idx]['sampling_rate'] = sampling_rate
            else:
                print(f"Smart Analysis Recheck: Skipping index {res_idx}, path '{path}', id {result.get('id')}: Audio not found.")
                continue

        hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)
        print(f"🔄 Smart Updated (Step 1): {result.get('path')} | Score: {result.get('score')} -> {score} | Text: {hyp_text}")

        # Record this model's attempt, then keep the best result overall.
        if 'model_results' not in results[res_idx]:
            results[res_idx]['model_results'] = {}
        results[res_idx]['model_results'][model_name] = {
            "hyp_text": hyp_text,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }
        best_model, best_result = select_best_model_result(
            results[res_idx]['model_results'],
            similarity_threshold
        )
        if best_result:
            results[res_idx].update({
                "hyp_text": best_result['hyp_text'],
                "score": best_result['score'],
                "norm_ref": best_result['norm_ref'],
                "norm_hyp": best_result['norm_hyp'],
                "model_used": best_model,
                "verification_status": "correct" if best_result['score'] >= similarity_threshold else "incorrect"
            })
    return results
def _smart_fresh_first_pass(
    gemini_tool, model_name, step_desc, dataset_name,
    limit_files, similarity_threshold, gen_config, progress
):
    """Transcribe every record of the dataset once (fresh first pass).

    Loads the dataset (preferring the cache), runs ``model_name`` over each
    audio sample, and returns a list of per-record result dicts that later
    passes can refine.
    """
    limit = int(limit_files) if limit_files > 0 else None

    # Prefer the cached copy of the dataset; otherwise download and cache it.
    ds = get_cached_dataset(dataset_name, limit)
    if ds is None:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.05, desc=f"Датасет закэшаваны для паўторнага выкарыстання")
    else:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")

    total = len(ds)
    progress(0.05, desc=f"{step_desc}: апрацоўка ўсіх {total} запісаў...")

    entries = []
    for pos, record in enumerate(ds):
        progress(0.05 + (pos + 1) / total * 0.20, desc=f"{step_desc}: файл {pos+1}/{total}")
        audio = record['audio']
        wave = audio['array']
        rate = audio['sampling_rate']
        # First non-empty transcript field wins (common HF dataset schemas).
        reference = next(
            (record.get(key) for key in ('sentence', 'text', 'transcription', 'transcript')
             if record.get(key)),
            ""
        )
        hypothesis = gemini_tool.transcribe_audio(model_name, wave, rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(reference, hypothesis)
        attempt = {
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }
        entries.append({
            "id": pos,
            "path": audio['path'],
            "ref_text": reference,
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp,
            "audio_array": wave,
            "sampling_rate": rate,
            "model_used": model_name,
            "verification_status": "correct" if score >= similarity_threshold else "incorrect",
            "model_results": {model_name: attempt}
        })
    return entries