|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSpeechSeq2Seq, AutoModelForImageClassification, AutoModelForObjectDetection
|
|
|
from PIL import Image
|
|
|
from pydub import AudioSegment
|
|
|
import soundfile as sf
|
|
|
import numpy as np
|
|
|
import io
|
|
|
from torch.utils.data import DataLoader
|
|
|
from datasets import load_dataset, Audio, concatenate_datasets, Dataset, DatasetDict
|
|
|
from torch.optim import AdamW
|
|
|
from accelerate import Accelerator
|
|
|
from tqdm.auto import tqdm
|
|
|
from transformers import (
|
|
|
AutoFeatureExtractor,
|
|
|
AutoModelForSpeechSeq2Seq,
|
|
|
Seq2SeqTrainingArguments,
|
|
|
Seq2SeqTrainer,
|
|
|
DataCollatorForSeq2Seq, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
|
|
|
)
|
|
|
import evaluate
|
|
|
import logging
|
|
|
import re
|
|
|
from langchain_community.vectorstores import FAISS
|
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
from langchain_community.document_loaders import PyPDFLoader
|
|
|
from langchain.chains import RetrievalQA
|
|
|
from gtts import gTTS
|
|
|
import tempfile
|
|
|
import time
|
|
|
|
|
|
|
|
|
# Root logger: INFO level, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Settings for the streaming text-classification fine-tuning run ---
# Checkpoint and dataset identifiers.
MODEL_NAME = "bert-base-uncased"
DATASET_NAME, SUBSET_NAME = "glue", "sst2"

# Optimisation hyper-parameters.
NUM_EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
MAX_SEQ_LENGTH = 128

# When True, a checkpoint is written after every epoch.
SAVE_MODEL_EVERY_EPOCH = False

# Second set of names read by train_model_with_streaming_data; the values
# duplicate MODEL_NAME and NUM_EPOCHS above.
MODEL_CHECKPOINT = "bert-base-uncased"
TRAIN_SPLIT = "train"
NUM_TRAIN_EPOCHS = 3
|
|
|
|
|
|
|
|
|
|
|
|
def load_dataset_streaming(dataset_name, subset_name=None, split="train"):
    """Open one split of a dataset in streaming mode.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``.
        subset_name: Optional subset/config name; omitted from the call when
            falsy.
        split: Name of the split to open.

    Returns:
        Whatever ``load_dataset(..., streaming=True)`` returns, or ``None``
        if loading raised any exception (the error is logged, not re-raised).
    """
    logging.info(f"Loading dataset '{dataset_name}' (subset: {subset_name}) in streaming mode for split '{split}'...")

    # The subset is a positional argument that is only passed when provided.
    positional = (dataset_name, subset_name) if subset_name else (dataset_name,)

    try:
        streamed = load_dataset(*positional, split=split, streaming=True)
        logging.info(f"Dataset '{dataset_name}' loaded successfully in streaming mode.")
        return streamed
    except Exception as e:
        logging.error(f"Error loading dataset '{dataset_name}': {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_function(examples, tokenizer, max_seq_length):
    """Tokenize the ``"sentence"`` field of a batch to a fixed length.

    Args:
        examples: Mapping with a ``"sentence"`` entry holding the text(s).
        tokenizer: Callable invoked as ``tokenizer(text, **options)``.
        max_seq_length: Length to which inputs are truncated and padded.

    Returns:
        The tokenizer's return value, unchanged.
    """
    sentences = examples["sentence"]
    # Truncate long inputs and pad short ones to exactly max_seq_length.
    options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": max_seq_length,
    }
    return tokenizer(sentences, **options)
|
|
|
|
|
|
|
|
|
|
|
|
def train_model_with_streaming_data():
    """Fine-tune a sequence classifier on the streamed GLUE/SST-2 train split.

    The earlier version of this function ran the whole load/tokenize/train
    pipeline twice, creating a second ``Accelerator`` and preparing the same
    model twice. Its second pass removed the ``label`` column from every
    batch without providing ``labels``, so the model produced no loss to
    back-propagate. This version runs a single pass:

    1. stream the dataset through ``load_dataset_streaming``;
    2. tokenize to ``MAX_SEQ_LENGTH`` and store the targets under ``labels``,
       the keyword the model's forward pass uses to compute a loss;
    3. train for ``NUM_EPOCHS`` epochs, reshuffling the stream each epoch;
    4. save the final model and tokenizer to ``./final_model`` (and one
       checkpoint per epoch when ``SAVE_MODEL_EVERY_EPOCH`` is true).

    Returns:
        None. Returns early, without training, if the dataset cannot be
        loaded.
    """
    accelerator = Accelerator()
    logging.info(f"Accelerator device: {accelerator.device}")

    train_dataset = load_dataset_streaming(DATASET_NAME, SUBSET_NAME, split=TRAIN_SPLIT)
    if train_dataset is None:
        return

    # MODEL_NAME is used instead of MODEL_CHECKPOINT: both name the same BERT
    # checkpoint where they are first defined, but MODEL_CHECKPOINT is rebound
    # to a speech checkpoint further down this file.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

    def tokenize_with_labels(examples):
        """Tokenize a batch and carry its targets over as ``labels``."""
        encoded = preprocess_function(examples, tokenizer, MAX_SEQ_LENGTH)
        encoded["labels"] = examples["label"]
        return encoded

    tokenized_train_dataset = train_dataset.map(
        tokenize_with_labels,
        batched=True,
        remove_columns=["sentence", "idx", "label"],
    )
    tokenized_train_dataset = tokenized_train_dataset.with_format("torch")

    # Every example is padded to MAX_SEQ_LENGTH, so the default collate
    # function can stack them without a padding collator.
    train_dataloader = DataLoader(
        tokenized_train_dataset.shuffle(seed=42, buffer_size=10_000),
        batch_size=BATCH_SIZE,
    )

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    model.train()
    total_steps = 0
    for epoch in range(NUM_EPOCHS):
        logging.info(f"Starting Epoch {epoch + 1}/{NUM_EPOCHS}")

        # Lets a shuffled streaming dataset reshuffle differently per epoch.
        if hasattr(train_dataloader.dataset, 'set_epoch'):
            train_dataloader.dataset.set_epoch(epoch)

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", disable=not accelerator.is_local_main_process)

        loss_sum = 0.0
        batches_seen = 0
        for batch in progress_bar:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            step_loss = loss.item()
            loss_sum += step_loss
            batches_seen += 1
            total_steps += 1
            progress_bar.set_postfix({'loss': step_loss})

            if total_steps % 100 == 0:
                accelerator.log({"train_loss": step_loss}, step=total_steps)

        if batches_seen:
            # A real mean over the epoch, not the loss of the last batch.
            logging.info(f"Epoch {epoch + 1} finished. Avg Loss: {loss_sum / batches_seen}")
        else:
            logging.warning(f"Epoch {epoch + 1} produced no batches; nothing was trained.")

        if SAVE_MODEL_EVERY_EPOCH:
            output_dir = f"./model_epoch_{epoch + 1}"
            _save_model_and_tokenizer(accelerator, model, tokenizer, output_dir)
            logging.info(f"Model saved to {output_dir}")

    logging.info("Training complete.")

    final_output_dir = "./final_model"
    _save_model_and_tokenizer(accelerator, model, tokenizer, final_output_dir)
    logging.info(f"Final model saved to {final_output_dir}")


def _save_model_and_tokenizer(accelerator, model, tokenizer, output_dir):
    """Write the unwrapped model and its tokenizer to ``output_dir``."""
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    tokenizer.save_pretrained(output_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Entry point of the classification script: start the streaming fine-tuning
# run only when this file is executed directly.
if __name__ == "__main__":
    train_model_with_streaming_data()

# Logging setup for the dispatcher section that follows; same level and
# record format as the configuration near the top of the file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Pipelines already built by get_pipeline, keyed by (task, model, options).
_pipeline_cache = {}


def get_pipeline(task_name, model_name=None, **kwargs):
    """Return a Hugging Face pipeline, building it at most once per setup.

    Args:
        task_name: Pipeline task identifier passed to ``pipeline``.
        model_name: Optional model identifier; when falsy the ``model``
            argument is not passed to ``pipeline`` at all.
        **kwargs: Extra options forwarded to ``pipeline`` and included in the
            cache key.

    Returns:
        The cached or newly created pipeline object.
    """
    # The options are keyed by (name, repr(value)) pairs rather than by
    # hash(frozenset(kwargs.items())): hashing raised TypeError for dict or
    # list values, and two different option sets with colliding hashes would
    # have shared one cache entry.
    option_key = tuple(sorted((name, repr(value)) for name, value in kwargs.items()))
    cache_key = (task_name, model_name, option_key)

    if cache_key not in _pipeline_cache:
        logging.info(f"Loading pipeline for task '{task_name}' with model '{model_name}'...")
        if model_name:
            _pipeline_cache[cache_key] = pipeline(task_name, model=model_name, **kwargs)
        else:
            _pipeline_cache[cache_key] = pipeline(task_name, **kwargs)
        logging.info(f"Pipeline '{task_name}' loaded.")

    return _pipeline_cache[cache_key]
|
|
|
|
|
|
|
|
|
class AIToolDispatcher:
    """Route raw text or a media file path to a Hugging Face pipeline.

    Pipelines come from the module-level ``get_pipeline`` helper;
    ``default_models`` maps each supported task name to the model used for it.
    """

    # File extensions (lower-case, without the dot) recognised per media type.
    _IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff')
    _AUDIO_EXTENSIONS = ('mp3', 'wav', 'ogg', 'flac', 'm4a')
    _VIDEO_EXTENSIONS = ('mp4', 'avi', 'mov', 'mkv')

    # Tasks accepted when the input is raw text.
    _TEXT_TASKS = ("text-classification", "sentiment-analysis", "summarization", "text-generation",
                   "translation_en_to_fr")

    def __init__(self):
        """Set up the task -> default model table."""
        logging.info("Initializing AI Tool Dispatcher...")
        self.default_models = {
            "text-classification": "distilbert-base-uncased-finetuned-sst-2-english",
            "sentiment-analysis": "distilbert-base-uncased-finetuned-sst-2-english",
            "summarization": "sshleifer/distilbart-cnn-12-6",
            "text-generation": "gpt2",
            "translation_en_to_fr": "Helsinki-NLP/opus-mt-en-fr",
            "image-classification": "google/vit-base-patch16-224",
            "object-detection": "facebook/detr-resnet-50",
            "automatic-speech-recognition": "openai/whisper-tiny.en",
        }
        logging.info("Dispatcher initialized.")

    def _get_task_pipeline(self, task: str, model_name: str = None):
        """Return the cached pipeline for ``task``.

        Args:
            task: Pipeline task name.
            model_name: Model to use; falls back to ``default_models[task]``.

        Raises:
            ValueError: No model was given and the task has no default.
        """
        final_model_name = model_name if model_name else self.default_models.get(task)
        if not final_model_name:
            raise ValueError(f"No default model specified for task '{task}'. Please provide `model_name`.")
        return get_pipeline(task, model_name=final_model_name)

    def process_text(self, text: str, task: str = "text-classification", **kwargs):
        """Run an NLP pipeline on a string.

        Args:
            text: The input text.
            task: Pipeline task name.
            **kwargs: Forwarded to the pipeline call.

        Raises:
            TypeError: ``text`` is not a ``str``.
        """
        if not isinstance(text, str):
            raise TypeError("Text input must be a string.")

        logging.info(f"Processing text for task: {task}")
        # Named task_pipeline so the imported `pipeline` factory is not shadowed.
        task_pipeline = self._get_task_pipeline(task)
        return task_pipeline(text, **kwargs)

    def process_image(self, image_path: str, task: str = "image-classification", **kwargs):
        """Run a vision pipeline on an image file.

        Args:
            image_path: Path of the image to load.
            task: Pipeline task name.
            **kwargs: Forwarded to the pipeline call.

        Raises:
            FileNotFoundError: ``image_path`` does not exist.
            ValueError: The file could not be opened as an image.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        logging.info(f"Processing image for task: {task}")

        try:
            # Copy the pixels into memory inside the context manager so the
            # underlying file handle is closed before inference starts.
            with Image.open(image_path) as image_file:
                image = image_file.copy()
        except Exception as e:
            raise ValueError(f"Could not open image file: {e}") from e

        task_pipeline = self._get_task_pipeline(task)
        return task_pipeline(image, **kwargs)

    def process_audio(self, audio_path: str, task: str = "automatic-speech-recognition", **kwargs):
        """Run an audio pipeline on an audio file.

        The file is decoded, down-mixed to one channel, resampled to 16000 Hz,
        and handed to the pipeline as a raw float32 waveform.

        Args:
            audio_path: Path of the audio file.
            task: Pipeline task name.
            **kwargs: Forwarded to the pipeline call.

        Raises:
            FileNotFoundError: ``audio_path`` does not exist.
            ValueError: The file could not be decoded or converted.
        """
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        logging.info(f"Processing audio for task: {task}")

        try:
            audio = AudioSegment.from_file(audio_path)
            audio = audio.set_channels(1).set_frame_rate(16000)

            # Round-trip through an in-memory WAV to obtain a sample array.
            buffer = io.BytesIO()
            audio.export(buffer, format="wav")
            buffer.seek(0)
            array, sampling_rate = sf.read(buffer)

            if array.dtype != np.float32:
                array = array.astype(np.float32)
        except Exception as e:
            logging.error(f"Error processing audio file for conversion: {e}")
            raise ValueError(f"Could not prepare audio file: {e}") from e

        task_pipeline = self._get_task_pipeline(task)
        # Raw audio goes in as a dict carrying the waveform and its sampling
        # rate. The previous call passed a Python list plus a `sampling_rate`
        # keyword, which is not a call-time option of the pipeline.
        return task_pipeline({"raw": array, "sampling_rate": sampling_rate}, **kwargs)

    def process_input(self, input_data: str, task: str = None, **kwargs):
        """
        Main entry point for the AI tool. Determines the input type and
        dispatches to the appropriate processing function.

        Args:
            input_data (str): Raw text, or the path of an existing
                image/audio/video file.
            task (str, optional): The AI task to perform (e.g. "summarization",
                "object-detection", "automatic-speech-recognition").
                Defaults to "image-classification" for images,
                "automatic-speech-recognition" for audio and
                "text-classification" for text.
            **kwargs: Additional arguments passed to the specific pipeline.

        Returns:
            dict or list: The result from the AI model.

        Raises:
            TypeError: ``input_data`` is not a string.
            ValueError: Unsupported file extension or unknown text task.
            NotImplementedError: The path points at a video file.
        """
        if not isinstance(input_data, str):
            raise TypeError("Input data must be a string (raw text or file path).")

        # Anything that is not an existing path is treated as raw text.
        if not os.path.exists(input_data):
            if not task:
                task = "text-classification"
            # The old code first logged that it would fall back to the
            # text-classification model and then raised anyway; only the
            # error is kept because no such fallback ever happened.
            if task not in self._TEXT_TASKS:
                raise ValueError(
                    f"Unknown text task: '{task}'. Please choose from supported text tasks or provide a model_name.")
            return self.process_text(input_data, task=task, **kwargs)

        # splitext looks at the final path component only, so a dot in a
        # directory name is not mistaken for an extension.
        file_extension = os.path.splitext(input_data)[1].lstrip('.').lower()

        if file_extension in self._IMAGE_EXTENSIONS:
            if not task:
                task = "image-classification"
            # object-detection is one of this class's own default image
            # tasks, so it must not trigger the suitability warning.
            if not task.startswith("image-") and task != "object-detection":
                logging.warning(
                    f"Task '{task}' may not be suitable for image input. Defaulting to '{task}' anyway.")
            return self.process_image(input_data, task=task, **kwargs)

        if file_extension in self._AUDIO_EXTENSIONS:
            if not task:
                task = "automatic-speech-recognition"
            if not task.startswith(("audio-", "automatic-speech-recognition")):
                logging.warning(
                    f"Task '{task}' may not be suitable for audio input. Defaulting to '{task}' anyway.")
            return self.process_audio(input_data, task=task, **kwargs)

        if file_extension in self._VIDEO_EXTENSIONS:
            logging.warning("Video processing is highly complex and not fully implemented in this example. "
                            "You'd typically extract frames/audio and process them separately.")
            raise NotImplementedError("Full video processing is beyond this generalized example.")

        raise ValueError(f"Unsupported file type: .{file_extension}. Or specify task for this file.")
|
|
|
|
|
|
|
|
|
|
|
|
# Demonstration of AIToolDispatcher on text, a generated image, and a
# generated audio clip. Runs only when the file is executed directly.
if __name__ == "__main__":
    dispatcher = AIToolDispatcher()

    print("\n--- Text Examples ---")
    text_input = "The new movie was absolutely fantastic! I highly recommend it."
    print(f"Input: '{text_input}'")
    print(f"Sentiment: {dispatcher.process_input(text_input, task='sentiment-analysis')}")

    text_to_summarize = (
        "Artificial intelligence (AI) is intelligence—perceiving, synthesizing, and inferring information—demonstrated by machines, as opposed to intelligence displayed by animals or humans. Example tasks in which AI is used include speech recognition, computer vision, translation, and others. AI applications include advanced web search engines, recommendation systems, understanding human speech, self-driving cars, and competing at the highest level in strategic game systems.")
    print(f"\nInput: '{text_to_summarize}'")
    print(f"Summary: {dispatcher.process_input(text_to_summarize, task='summarization', max_length=50, min_length=10)}")

    text_to_generate = "In a galaxy far, far away, a brave knight"
    print(f"\nInput: '{text_to_generate}'")
    generated_text = dispatcher.process_input(text_to_generate, task='text-generation', max_new_tokens=50,
                                              num_return_sequences=1)
    print(f"Generated Text: {generated_text[0]['generated_text']}")

    print("\n--- Image Examples ---")
    dummy_image_path_1 = "dummy_cat_image.jpg"
    dummy_image_path_2 = "dummy_building_image.png"

    if not os.path.exists(dummy_image_path_1):
        print(f"Creating a dummy image file at {dummy_image_path_1} for demonstration.")
        try:
            img = Image.new('RGB', (60, 30), color='red')
            img.save(dummy_image_path_1)
        except ImportError:
            print(
                "Pillow not installed. Skipping dummy image creation. Please install Pillow and provide real image paths.")
            dummy_image_path_1 = None
        except Exception as e:
            print(f"Could not create dummy image: {e}. Skipping image examples.")
            dummy_image_path_1 = None

    if dummy_image_path_1 and os.path.exists(dummy_image_path_1):
        print(f"\nImage Input: {dummy_image_path_1}")
        try:
            print(f"Image Classification: {dispatcher.process_input(dummy_image_path_1, task='image-classification')}")
        except Exception as e:
            print(f"Error during image classification: {e}")

    print("\n--- Audio Examples ---")
    dummy_audio_path = "dummy_audio.wav"

    if not os.path.exists(dummy_audio_path):
        print(
            f"Creating a dummy audio file at {dummy_audio_path} for demonstration (requires pydub, soundfile, ffmpeg).")
        try:
            from pydub.generators import Sine

            # One second of a 440 Hz tone.
            sine_wave = Sine(440).to_audio_segment(duration=1000)
            sine_wave.export(dummy_audio_path, format="wav")
        except ImportError:
            print(
                "pydub not installed. Skipping dummy audio creation. Please install pydub, soundfile, and ffmpeg (system-wide) and provide real audio paths.")
            dummy_audio_path = None
        except Exception as e:
            print(f"Could not create dummy audio: {e}. Skipping audio examples.")
            dummy_audio_path = None

    if dummy_audio_path and os.path.exists(dummy_audio_path):
        print(f"\nAudio Input: {dummy_audio_path}")
        try:
            transcription = dispatcher.process_input(dummy_audio_path, task='automatic-speech-recognition')
            print(f"Audio Transcription: {transcription['text']}")
        except NotImplementedError:
            print("Skipping audio due to unimplemented feature or missing dependencies.")
        except Exception as e:
            print(f"Error during audio transcription: {e}")

    # Clean up the demo files. Either path is None when its creation failed
    # above, and os.path.exists(None) raises TypeError, so check for None
    # before touching the filesystem.
    for leftover_path in (dummy_image_path_1, dummy_audio_path):
        if leftover_path and os.path.exists(leftover_path):
            os.remove(leftover_path)
|
|
|
|
|
|
# Logging setup for the speech-recognition fine-tuning section.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Settings read by fine_tune_asr_model and its helpers ---
# Several names below (MODEL_CHECKPOINT, DATASET_NAME, BATCH_SIZE,
# LEARNING_RATE, NUM_TRAIN_EPOCHS) rebind constants defined earlier in this
# file for the text-classification script.
MODEL_CHECKPOINT = "openai/whisper-tiny.en"
LANGUAGE_ABBREVIATION = "en"

# NOTE(review): the checkpoint name and LANGUAGE_ABBREVIATION both say "en"
# while DATASET_CONFIG is "ja" — confirm this combination is intended.
DATASET_NAME, DATASET_CONFIG = "common_voice", "ja"
SPLIT_TRAIN, SPLIT_VALID = "train", "validation"

# Optimisation hyper-parameters.
BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-5
NUM_TRAIN_EPOCHS = 3

# Where the fine-tuned model, tokenizer and feature extractor are saved.
OUTPUT_DIR = "./whisper-tiny-finetuned"

# Clips at or above this duration (seconds) are filtered out in prepare_dataset.
MAX_INPUT_LENGTH_IN_S = 30.0
|
|
|
|
|
|
|
|
|
|
|
|
def load_asr_components(model_checkpoint: str):
    """Load the feature extractor, tokenizer and model for an ASR checkpoint.

    Args:
        model_checkpoint: Identifier or path passed to each ``from_pretrained``
            call.

    Returns:
        tuple: ``(feature_extractor, tokenizer, model)``.
    """
    logging.info(f"Loading feature extractor, tokenizer, and model for {model_checkpoint}...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, language=LANGUAGE_ABBREVIATION, task="transcribe")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_checkpoint)

    # Clear decoding constraints stored in the checkpoint's config.
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Only fill in special-token ids that the checkpoint leaves unset.
    # Overwriting them unconditionally replaced the decoder start token the
    # checkpoint was trained with by the tokenizer's BOS id.
    # NOTE(review): for Whisper checkpoints these two ids are expected to
    # differ — confirm against the checkpoint's config.json.
    if getattr(model.config, "decoder_start_token_id", None) is None:
        model.config.decoder_start_token_id = tokenizer.bos_token_id
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    logging.info("Components loaded.")
    return feature_extractor, tokenizer, model
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_dataset(
        dataset_name: str,
        dataset_config: str,
        split: str,
        feature_extractor,
        tokenizer,
        max_input_length_in_s: float
):
    """Load a speech dataset split and convert it to model inputs.

    Steps: load the split, cast the ``audio`` column to the feature
    extractor's sampling rate, drop clips that are not shorter than
    ``max_input_length_in_s``, then map every example to ``input_features``
    and ``labels``.

    Args:
        dataset_name: Dataset identifier for ``load_dataset``.
        dataset_config: Dataset configuration name.
        split: Split to load.
        feature_extractor: Provides ``sampling_rate`` and is called on the
            raw waveform.
        tokenizer: Called on the normalised transcript to produce label ids.
        max_input_length_in_s: Exclusive upper bound on clip length, seconds.

    Returns:
        The processed dataset, containing only ``input_features`` and
        ``labels`` columns.
    """
    logging.info(f"Loading dataset '{dataset_name}' with config '{dataset_config}' split '{split}'...")

    # NOTE(review): trust_remote_code=True is passed to load_dataset here —
    # confirm the dataset source is trusted.
    raw_split = load_dataset(dataset_name, dataset_config, split=split, trust_remote_code=True)
    target_rate = Audio(sampling_rate=feature_extractor.sampling_rate)
    raw_split = raw_split.cast_column("audio", target_rate)

    def is_short_enough(example):
        """Keep clips with fewer samples than the configured maximum."""
        sample_count = example["audio"]["array"].shape[0]
        return sample_count < max_input_length_in_s * feature_extractor.sampling_rate

    raw_split = raw_split.filter(is_short_enough, num_proc=os.cpu_count())

    logging.info(f"Dataset loaded and casted. Number of examples: {len(raw_split)}")

    def prepare_example(example):
        """Turn one raw example into features and tokenized labels."""
        clip = example["audio"]
        extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])

        # Strip everything except word characters and whitespace, then
        # lower-case, before tokenizing the transcript.
        transcript = re.sub(r"[^\w\s]", "", example["sentence"]).lower()

        return {
            "input_features": extracted.input_features[0],
            "labels": tokenizer(transcript).input_ids,
        }

    logging.info("Applying preprocessing to dataset...")

    worker_count = os.cpu_count() if os.cpu_count() else 1
    processed_dataset = raw_split.map(
        prepare_example,
        remove_columns=raw_split.column_names,
        num_proc=worker_count,
    )
    logging.info("Preprocessing complete.")
    return processed_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
from typing import Any


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding(DataCollatorForSeq2Seq):
    """Collate speech examples into padded ``input_features`` and ``labels``.

    The previous version always read ``self.tokenizer.feature_extractor``,
    which fails when the collator is built with a plain tokenizer (as
    ``fine_tune_asr_model`` does). The feature extractor is now resolved in
    this order:

    1. the optional ``feature_extractor`` field;
    2. a ``feature_extractor`` attribute on ``self.tokenizer`` (processor-like
       objects);
    3. neither: the per-example features are stacked directly, which requires
       them to share one shape.

    Attributes:
        feature_extractor: Optional object whose ``pad`` method batches the
            audio features.
    """

    feature_extractor: Any = None

    def __call__(self, features):
        """Build one training batch.

        Args:
            features: Sequence of dicts with ``input_features`` and
                ``labels`` entries.

        Returns:
            A mapping with tensor-valued ``input_features`` and ``labels``;
            padded label positions are set to -100.
        """
        input_features = [{"input_features": feature["input_features"]} for feature in features]

        extractor = self.feature_extractor
        if extractor is None:
            extractor = getattr(self.tokenizer, "feature_extractor", None)

        if extractor is not None:
            batch = extractor.pad(input_features, return_tensors="pt")
        else:
            # No extractor available: stack the already extracted features.
            # numpy raises here if the examples do not share one shape.
            stacked = np.asarray(
                [item["input_features"] for item in input_features], dtype=np.float32
            )
            batch = {"input_features": torch.as_tensor(stacked)}

        # A processor-like object keeps its text tokenizer under `.tokenizer`.
        text_tokenizer = getattr(self.tokenizer, "tokenizer", self.tokenizer)
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = text_tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # When a model was supplied and every label row already begins with
        # its decoder start token, drop that column: the model prepends the
        # start token itself when it builds decoder inputs from the labels.
        model_config = getattr(getattr(self, "model", None), "config", None)
        decoder_start = getattr(model_config, "decoder_start_token_id", None)
        if decoder_start is not None and labels.shape[1] > 0:
            if bool((labels[:, 0] == decoder_start).all()):
                labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
|
|
|
|
|
|
|
|
|
|
|
|
def compute_metrics(pred, tokenizer):
    """Compute the word error rate, as a percentage, for predictions.

    Args:
        pred: Object exposing ``predictions`` and ``label_ids`` arrays.
            ``label_ids`` is modified in place: every -100 entry is replaced
            by the tokenizer's pad token id so it can be decoded.
        tokenizer: Provides ``pad_token_id`` and ``batch_decode``.

    Returns:
        dict: ``{"wer": <100 * word error rate>}``.
    """
    wer_metric = evaluate.load("wer")

    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # -100 marks ignored label positions; swap it for the pad id before decoding.
    ignored_positions = reference_ids == -100
    reference_ids[ignored_positions] = tokenizer.pad_token_id

    decode_options = {"skip_special_tokens": True}
    predicted_text = tokenizer.batch_decode(predicted_ids, **decode_options)
    reference_text = tokenizer.batch_decode(reference_ids, **decode_options)

    error_rate = wer_metric.compute(predictions=predicted_text, references=reference_text)
    return {"wer": 100 * error_rate}
|
|
|
|
|
|
|
|
|
|
|
|
def fine_tune_asr_model():
    """Fine-tune the configured speech-to-text checkpoint and save it.

    Loads the model components, prepares up to 500 training and 100
    validation examples, trains with ``Seq2SeqTrainer`` and writes the model,
    tokenizer and feature extractor to ``OUTPUT_DIR``.

    Fixes relative to the previous version:
      * ``Seq2SeqTrainer`` is no longer given a ``feature_extractor``
        argument, which it does not accept;
      * an evaluation strategy of ``"steps"`` is set, as
        ``load_best_model_at_end=True`` needs evaluation to run on the same
        schedule as checkpoint saving;
      * ``generation_max_length`` is a token count instead of the 30.0-second
        audio limit;
      * the subset sizes are capped at the dataset length, so a small split
        no longer makes ``select`` fail.

    Returns:
        None. Returns early if components or datasets cannot be prepared.
    """
    # Maximum number of tokens generated per example during evaluation.
    # NOTE(review): 225 is a commonly used value for this setup — confirm.
    generation_max_tokens = 225

    feature_extractor, tokenizer, model = load_asr_components(MODEL_CHECKPOINT)
    if any(component is None for component in (feature_extractor, tokenizer, model)):
        return

    try:
        train_dataset = _first_examples(
            prepare_dataset(
                DATASET_NAME, DATASET_CONFIG, SPLIT_TRAIN, feature_extractor, tokenizer, MAX_INPUT_LENGTH_IN_S
            ),
            500,
        )
        eval_dataset = _first_examples(
            prepare_dataset(
                DATASET_NAME, DATASET_CONFIG, SPLIT_VALID, feature_extractor, tokenizer, MAX_INPUT_LENGTH_IN_S
            ),
            100,
        )
    except Exception as e:
        logging.error(f"Error preparing dataset: {e}")
        logging.info("Make sure the dataset name/config is correct and you have internet access.")
        logging.info("Also ensure the 'audio' and 'sentence' columns exist in your chosen dataset.")
        return

    # `model` is an existing field of the parent collator class; passing it
    # gives the collator access to the model's config.
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(tokenizer=tokenizer, model=model)

    # The evaluation-strategy argument is named `eval_strategy` in newer
    # transformers releases and `evaluation_strategy` in older ones; use
    # whichever field this installation defines.
    known_fields = getattr(Seq2SeqTrainingArguments, "__dataclass_fields__", {})
    strategy_field = "eval_strategy" if "eval_strategy" in known_fields else "evaluation_strategy"

    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        predict_with_generate=True,
        generation_max_length=generation_max_tokens,
        fp16=torch.cuda.is_available(),
        save_steps=500,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=False,
        **{strategy_field: "steps"},
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda p: compute_metrics(p, tokenizer),
    )

    logging.info("Starting model training...")
    trainer.train()
    logging.info("Training complete. Saving final model.")

    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    feature_extractor.save_pretrained(OUTPUT_DIR)
    logging.info(f"Fine-tuned ASR model saved to {OUTPUT_DIR}")

    logging.info("\n--- How to use the fine-tuned model ---")
    logging.info("from transformers import pipeline")
    logging.info(
        f"transcriber = pipeline('automatic-speech-recognition', model='{OUTPUT_DIR}', tokenizer='{OUTPUT_DIR}', feature_extractor='{OUTPUT_DIR}')")
    logging.info("print(transcriber('path/to/your/audio.wav'))")


def _first_examples(dataset, limit):
    """Return at most ``limit`` leading examples of ``dataset``."""
    return dataset.select(range(min(limit, len(dataset))))
|
|
|
|
|
|
|
|
|
# Entry point of the speech fine-tuning script: train only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    fine_tune_asr_model()
|
|
|
|
|
|
|
|
|
# Pipelines already built by get_pipeline, keyed by (task, model, options).
_pipeline_cache = {}


def get_pipeline(task, model_name=None, **kwargs):
    """Return a Hugging Face pipeline, creating it only on first request.

    Args:
        task: Pipeline task identifier passed to ``pipeline``.
        model_name: Optional model identifier; when falsy the ``model``
            argument is not passed to ``pipeline``.
        **kwargs: Extra options forwarded to ``pipeline`` and included in the
            cache key.

    Returns:
        The cached or newly created pipeline object.
    """
    # Key the options by (name, repr(value)) pairs instead of
    # hash(frozenset(kwargs.items())): hashing raised TypeError for dict or
    # list values, and colliding hashes would have shared one cache entry.
    option_key = tuple(sorted((name, repr(value)) for name, value in kwargs.items()))
    key = (task, model_name, option_key)

    if key not in _pipeline_cache:
        if model_name:
            _pipeline_cache[key] = pipeline(task, model=model_name, **kwargs)
        else:
            _pipeline_cache[key] = pipeline(task, **kwargs)
    return _pipeline_cache[key]
|
|
|
|
|
|
|
|
|
|
|
|
# Default model per pipeline task, grouped by modality and merged in the
# same key order as before: text, then vision, then audio.
_text_defaults = {
    "sentiment-analysis": "distilbert-base-uncased-finetuned-sst-2-english",
    "summarization": "sshleifer/distilbart-cnn-12-6",
    "text-generation": "gpt2",
    "translation_en_to_fr": "Helsinki-NLP/opus-mt-en-fr",
}
_vision_defaults = {
    "image-classification": "google/vit-base-patch16-224",
    "object-detection": "facebook/detr-resnet-50",
}
_audio_defaults = {
    "automatic-speech-recognition": "openai/whisper-tiny.en",
}

def_models = {**_text_defaults, **_vision_defaults, **_audio_defaults}
|
|
|
|
|
|
|
|
|
|
|
|
def build_rag_qa(pdf_path):
    """Build a retrieval question-answering chain over one PDF.

    The PDF is loaded, split into 500-character chunks with 50 characters of
    overlap, embedded into a FAISS index, and combined with a text-generation
    model in a ``RetrievalQA`` chain.

    Args:
        pdf_path: Path of the PDF to index.

    Returns:
        The ``RetrievalQA`` chain.

    Raises:
        ValueError: The PDF yielded no text chunks to index.
    """
    # Imported here so the wrapper is in scope without editing the file's
    # import block; langchain_community is already a dependency of this file.
    from langchain_community.llms import HuggingFacePipeline

    pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(pages)
    if not chunks:
        # Fail with a clear message instead of handing FAISS an empty list.
        raise ValueError(f"No text could be extracted from PDF: {pdf_path}")

    vectorstore = FAISS.from_documents(chunks, HuggingFaceEmbeddings())

    # RetrievalQA expects a LangChain language model. The raw transformers
    # pipeline that was passed before is wrapped in the adapter class.
    generator = pipeline("text-generation")
    llm = HuggingFacePipeline(pipeline=generator)
    return RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
|
|
|
|
|
|
|
|
|
|
|
|
class AllInOneDispatcher:
    """Dispatch text, image, audio, video and PDF inputs to model pipelines.

    Every text, image and audio call is recorded in ``self.memory`` as a
    ``(task, input, result)`` tuple.
    """

    def __init__(self):
        """Start with an empty call history."""
        self.memory = []

    def _pipe(self, task, model=None):
        """Return the cached pipeline for ``task`` (default model if none given)."""
        return get_pipeline(task, model_name=model or def_models.get(task))

    def _is_file(self, path):
        """Return True when ``path`` exists on disk."""
        return os.path.exists(path)

    def handle_text(self, text, task="sentiment-analysis", **kwargs):
        """Run a text pipeline on ``text`` and record the call."""
        result = self._pipe(task)(text, **kwargs)
        self.memory.append((task, text, result))
        return result

    def handle_image(self, path, task="image-classification", **kwargs):
        """Run a vision pipeline on an image and record the call.

        Args:
            path: Image file path, or an already loaded PIL image (this is
                what ``handle_video`` passes).
            task: Pipeline task name.
            **kwargs: Forwarded to the pipeline call.
        """
        if isinstance(path, Image.Image):
            image = path
        else:
            # Copy the pixels into memory so the file handle is closed.
            with Image.open(path) as image_file:
                image = image_file.copy()
        result = self._pipe(task)(image, **kwargs)
        self.memory.append((task, path, result))
        return result

    def handle_audio(self, path, task="automatic-speech-recognition", **kwargs):
        """Run an audio pipeline on an audio file and record the call.

        The audio is converted to one channel at 16000 Hz and passed to the
        pipeline as a raw float32 waveform.
        """
        audio = AudioSegment.from_file(path).set_channels(1).set_frame_rate(16000)
        buf = io.BytesIO()
        audio.export(buf, format="wav")
        buf.seek(0)
        data, sr = sf.read(buf)
        if data.dtype != np.float32:
            data = data.astype(np.float32)
        # Raw audio goes in as a dict with the waveform and its sampling
        # rate; the old list + `sampling_rate=` call form is not accepted.
        # kwargs are forwarded now (they used to be dropped).
        result = self._pipe(task)({"raw": data, "sampling_rate": sr}, **kwargs)
        self.memory.append((task, path, result))
        return result

    def handle_video(self, path):
        """Classify the first frame of a video and transcribe its audio.

        Args:
            path: Path of the video file.

        Returns:
            dict: ``{"image": <classification>, "audio": <transcription>}``.

        Raises:
            ValueError: No frame could be read from the video.
            subprocess.CalledProcessError: ffmpeg failed to extract audio.
        """
        import subprocess

        import cv2

        cap = cv2.VideoCapture(path)
        try:
            ret, frame = cap.read() if cap.isOpened() else (False, None)
        finally:
            cap.release()
        if not ret:
            raise ValueError(f"Could not read any frames from video: {path}")

        # OpenCV returns BGR channel order; PIL images are expected in RGB.
        first_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # A unique temp file replaces the fixed "extracted_audio.wav" name in
        # the working directory.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as audio_handle:
            audio_file = audio_handle.name
        try:
            # Argument list without a shell, so spaces or shell
            # metacharacters in `path` cannot alter the command.
            subprocess.run(
                ["ffmpeg", "-i", path, "-q:a", "0", "-map", "a", audio_file, "-y"],
                check=True,
            )
            image_result = self.handle_image(first_frame, task="image-classification")
            audio_result = self.handle_audio(audio_file)
        finally:
            if os.path.exists(audio_file):
                os.remove(audio_file)

        return {"image": image_result, "audio": audio_result}

    def handle_pdf(self, path):
        """Build a retrieval-QA chain over the PDF and ask it for a summary."""
        qa = build_rag_qa(path)
        return qa.run("Summarize this document")

    def handle_tts(self, text, lang='en'):
        """Synthesize ``text`` to an MP3 file and return the file's path.

        The file is created with ``delete=False``, so the caller is
        responsible for removing it.
        """
        tts = gTTS(text=text, lang=lang)
        # Close the temp-file handle before saving so it is not leaked.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_handle:
            temp_path = mp3_handle.name
        tts.save(temp_path)
        return temp_path

    def process(self, input_data, task=None, **kwargs):
        """Detect the input type and call the matching handler.

        Args:
            input_data: Raw text, or the path of an existing file.
            task: Optional task name. When omitted, each handler uses its own
                default task. ``"tts"`` selects text-to-speech for text input.
            **kwargs: Forwarded to the text, image or audio handler.

        Raises:
            ValueError: The path exists but its extension is not supported.
        """
        # Forward `task` only when one was given. Passing None through used
        # to override the handlers' default task names.
        task_args = {} if task is None else {"task": task}

        if self._is_file(input_data):
            ext = os.path.splitext(input_data)[1].lstrip('.').lower()
            if ext in ('jpg', 'jpeg', 'png'):
                return self.handle_image(input_data, **task_args, **kwargs)
            if ext in ('mp3', 'wav'):
                return self.handle_audio(input_data, **task_args, **kwargs)
            if ext in ('mp4', 'mov'):
                return self.handle_video(input_data)
            if ext == 'pdf':
                return self.handle_pdf(input_data)
            raise ValueError(f"Unsupported file type: {ext}")

        if task == "tts":
            return self.handle_tts(input_data)
        return self.handle_text(input_data, **task_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
# Demonstration of AllInOneDispatcher; runs only when the file is executed
# directly.
if __name__ == "__main__":
    ai = AllInOneDispatcher()

    # (printed label, input text, task) for each example, run in order.
    demo_calls = (
        ("Text:", "The weather is great today!", "sentiment-analysis"),
        ("Summarize:", "Artificial intelligence is a broad field...", "summarization"),
        ("TTS path:", "This is a test speech.", "tts"),
    )
    for label, demo_text, demo_task in demo_calls:
        print(label, ai.process(demo_text, task=demo_task))
|
|
|
|