| | |
| | """ |
| | Data loader for Quillan multimodal training |
| | Handles JSONL fine-tuning data, song lyrics, and knowledge files |
| | """ |
| |
|
| | import json |
| | import os |
| | import glob |
| | import torch |
| | from typing import List, Dict, Any, Tuple |
| | import random |
| | from PIL import Image |
| | import numpy as np |
| |
|
class QuillanDataset:
    """In-memory text samples plus discovered media paths for Quillan training.

    Text comes from three places: a JSONL file inside ``data_dir`` and two
    markdown folders ("Songs Lyrics", "Quillan Knowledge files") located next
    to ``data_dir``. Image/audio/video paths are discovered in folders located
    one level above the directory containing this module.

    Everything is loaded eagerly by the constructor; progress is reported with
    ``print``.
    """

    def __init__(self, data_dir: str = "."):
        """Record ``data_dir`` and immediately load every available source."""
        self.data_dir = data_dir
        # Each sample is a dict with 'text' and 'source', plus 'line' (JSONL)
        # or 'file' (markdown) describing where it came from.
        self.samples: List[Dict[str, Any]] = []
        self.image_files: List[str] = []
        self.audio_files: List[str] = []
        self.video_files: List[str] = []
        # Optional; when unset, get_training_batch falls back to code points.
        self.tokenizer: Any = None
        self.load_all_data()

    def set_tokenizer(self, tokenizer):
        """Set the tokenizer used by ``get_training_batch``.

        The tokenizer is called as ``tokenizer.encode(text, max_length=seq_len)``
        and is assumed to return a sequence of integer ids.
        """
        self.tokenizer = tokenizer

    def load_all_data(self):
        """Load all text sources, discover media files, and print a summary."""
        print("🔄 Loading Quillan multimodal datasets...")

        # Text sources.
        self.load_jsonl_data()
        self.load_song_lyrics()
        self.load_knowledge_files()

        # Media sources (paths only; nothing is decoded here).
        self.load_image_files()
        self.load_audio_files()
        self.load_video_files()

        print(f"✅ Loaded {len(self.samples)} text samples")
        print(f"✅ Loaded {len(self.image_files)} image files")
        print(f"✅ Loaded {len(self.audio_files)} audio files")
        print(f"✅ Loaded {len(self.video_files)} video files")
        print(f"✅ Total training samples: {len(self.samples) + len(self.image_files) + len(self.audio_files) + len(self.video_files)}")

    @staticmethod
    def _media_dir(folder_name: str) -> str:
        """Return the path of ``folder_name`` one level above this module's directory."""
        # abspath first: a relative __file__ would otherwise collapse to ""
        # and silently resolve against the current working directory.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(os.path.dirname(module_dir), folder_name)

    @staticmethod
    def _glob_patterns(directory: str, patterns: List[str]) -> List[str]:
        """Return the sorted paths in ``directory`` matching any of ``patterns``."""
        matches: List[str] = []
        for pattern in patterns:
            matches.extend(glob.glob(os.path.join(directory, pattern)))
        # glob order is filesystem dependent; sorting keeps seeded runs reproducible.
        return sorted(matches)

    def load_image_files(self):
        """Collect image paths from the "Main images" folder."""
        images_dir = self._media_dir("Main images")
        if not os.path.isdir(images_dir):
            print(f"⚠️ Main images directory not found at {images_dir}")
            return
        print(f"🖼️ Loading images from {images_dir}")
        self.image_files.extend(
            self._glob_patterns(images_dir, ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff'])
        )
        print(f"✅ Found {len(self.image_files)} image files")

    def load_audio_files(self):
        """Collect audio paths from the "Mp3 files" folder."""
        audio_dir = self._media_dir("Mp3 files")
        if not os.path.isdir(audio_dir):
            print(f"⚠️ Mp3 files directory not found at {audio_dir}")
            return
        print(f"🎵 Loading audio from {audio_dir}")
        self.audio_files.extend(
            self._glob_patterns(audio_dir, ['*.mp3', '*.wav', '*.flac', '*.aac', '*.ogg'])
        )
        print(f"✅ Found {len(self.audio_files)} audio files")

    def load_video_files(self):
        """Collect video paths from the "Lyric Videos" folder."""
        video_dir = self._media_dir("Lyric Videos")
        if not os.path.isdir(video_dir):
            print(f"⚠️ Lyric Videos directory not found at {video_dir}")
            return
        print(f"🎬 Loading videos from {video_dir}")
        self.video_files.extend(
            self._glob_patterns(video_dir, ['*.mp4', '*.avi', '*.mov', '*.mkv', '*.webm'])
        )
        print(f"✅ Found {len(self.video_files)} video files")

    @staticmethod
    def _extract_jsonl_text(record: Any) -> Any:
        """Return the training text of one decoded JSONL record, or ``None``.

        Two layouts are understood: ``{"text": ...}`` and
        ``{"Output_Sections": {"Final output": ...}}``. A ``text`` key takes
        precedence even when its value turns out to be unusable. ``None`` is
        returned for anything that is not a string in one of those layouts.
        """
        if not isinstance(record, dict):
            return None
        if 'text' in record:
            text = record['text']
        else:
            sections = record.get('Output_Sections')
            # A plain string would pass an ``in`` test by substring match and
            # then fail on indexing, so require a real mapping.
            if not isinstance(sections, dict):
                return None
            text = sections.get('Final output')
        return text if isinstance(text, str) else None

    def load_jsonl_data(self):
        """Load the fine-tuning JSONL dataset - handles both old and new formats.

        Blank lines are ignored. Lines that are not valid JSON, or whose
        record does not carry a string in a supported layout, are skipped and
        counted instead of aborting the load. Texts of 10 characters or fewer
        (after stripping) are dropped.
        """
        jsonl_path = os.path.join(self.data_dir, "Quillan_finetune_full_dataset.jsonl")
        skipped_count = 0
        if os.path.isfile(jsonl_path):
            print(f"📚 Loading JSONL data from {jsonl_path}")
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        skipped_count += 1
                        continue
                    text = self._extract_jsonl_text(data)
                    if text is None:
                        skipped_count += 1
                        continue
                    if len(text.strip()) > 10:
                        self.samples.append({
                            'text': text,
                            'source': 'jsonl',
                            'line': line_num
                        })
        if skipped_count:
            print(f"⚠️ Skipped {skipped_count} malformed JSONL lines")

        # Reported even when the file is missing, so the fallback to the
        # other text sources is visible in the log.
        loaded_count = sum(1 for s in self.samples if s['source'] == 'jsonl')
        if loaded_count > 0:
            print(f"✅ Loaded {loaded_count} JSONL samples")
        else:
            print("⚠️ No valid JSONL samples loaded, trying other sources...")

    def _load_markdown_samples(self, directory: str, source: str, min_chars: int) -> None:
        """Append one sample per ``*.md`` file in ``directory`` longer than ``min_chars``.

        Files that cannot be read or decoded as UTF-8 are reported and skipped.
        """
        for file_path in sorted(glob.glob(os.path.join(directory, "*.md"))):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"⚠️ Error loading {file_path}: {e}")
                continue
            if len(content.strip()) > min_chars:
                self.samples.append({
                    'text': content,
                    'source': source,
                    'file': os.path.basename(file_path)
                })

    def load_song_lyrics(self):
        """Load song lyrics from the "Songs Lyrics" directory beside ``data_dir``."""
        lyrics_dir = os.path.join(os.path.dirname(os.path.abspath(self.data_dir)), "Songs Lyrics")
        if not os.path.isdir(lyrics_dir):
            return
        print(f"🎵 Loading song lyrics from {lyrics_dir}")
        self._load_markdown_samples(lyrics_dir, 'lyrics', 50)
        print(f"✅ Loaded {sum(1 for s in self.samples if s['source'] == 'lyrics')} lyric files")

    def load_knowledge_files(self):
        """Load markdown from the "Quillan Knowledge files" directory beside ``data_dir``."""
        knowledge_dir = os.path.join(os.path.dirname(os.path.abspath(self.data_dir)), "Quillan Knowledge files")
        if not os.path.isdir(knowledge_dir):
            return
        print(f"🧠 Loading knowledge files from {knowledge_dir}")
        self._load_markdown_samples(knowledge_dir, 'knowledge', 100)
        print(f"✅ Loaded {sum(1 for s in self.samples if s['source'] == 'knowledge')} knowledge files")

    @staticmethod
    def _fit_to_length(token_ids: List[int], seq_len: int, pad_id: int = 0) -> List[int]:
        """Truncate or right-pad ``token_ids`` so it has exactly ``seq_len`` entries."""
        clipped = token_ids[:seq_len]
        return clipped + [pad_id] * (seq_len - len(clipped))

    def get_training_batch(self, batch_size: int = 4, seq_len: int = 512) -> Dict[str, Any]:
        """Draw a random batch of text with matching image/audio/video tensors.

        Args:
            batch_size: Requested number of samples. When fewer text samples
                are loaded, the batch shrinks to the number available.
            seq_len: Token length of every row of ``text_tokens``.

        Returns:
            A dict with ``text_tokens`` (long tensor, ``(n, seq_len)``),
            ``image``, ``audio`` and ``video`` tensors whose first dimension
            is the same ``n``, and ``raw_texts`` (the ``n`` source strings,
            cut to ``seq_len * 4`` characters).

        Raises:
            ValueError: If no text samples are loaded, or if ``batch_size`` or
                ``seq_len`` is smaller than 1.
        """
        if not self.samples:
            raise ValueError("No data loaded!")
        if batch_size < 1 or seq_len < 1:
            raise ValueError("batch_size and seq_len must be at least 1")

        # Sampling is without replacement, so a small dataset yields a
        # smaller batch than requested.
        batch_samples = random.sample(self.samples, min(batch_size, len(self.samples)))
        actual_batch = len(batch_samples)

        # Cheap character-level pre-truncation before tokenizing.
        texts = [sample['text'][:seq_len * 4] for sample in batch_samples]

        if self.tokenizer is not None:
            # The tokenizer is not guaranteed to pad or truncate, so every row
            # is normalised here; ragged rows cannot be stacked into a tensor.
            text_ids = [
                self._fit_to_length(list(self.tokenizer.encode(text, max_length=seq_len)), seq_len)
                for text in texts
            ]
        else:
            # Fallback: code points clamped to 999, zero padded.
            text_ids = [
                self._fit_to_length([min(ord(c), 999) for c in text], seq_len)
                for text in texts
            ]

        text_tensor = torch.tensor(text_ids, dtype=torch.long)

        # Use the actual batch count so all modalities share dimension 0.
        image_tensor = self._load_batch_images(actual_batch)
        audio_tensor = self._load_batch_audio(actual_batch)
        video_tensor = self._load_batch_video(actual_batch)

        return {
            'text_tokens': text_tensor,
            'image': image_tensor,
            'audio': audio_tensor,
            'video': video_tensor,
            'raw_texts': texts
        }

    def _load_batch_images(self, batch_size: int) -> torch.Tensor:
        """Return a ``(batch_size, 3, 256, 256)`` float tensor of images.

        Each entry is a randomly chosen file from ``image_files``, converted
        to RGB, resized to 128x128 and pasted at offset (64, 64) on a black
        256x256 canvas, scaled to ``[0, 1]``. Files that are missing, smaller
        than 100 bytes, or fail to decode are replaced by Gaussian noise, as
        is the whole batch when no image files were discovered.
        """
        if not self.image_files:
            return torch.randn(batch_size, 3, 256, 256)

        # Pillow >= 9.1 groups filters under Image.Resampling; older releases
        # only expose them on the module. Resolving it here avoids an
        # AttributeError being swallowed below and turning every image into noise.
        bilinear = getattr(Image, "Resampling", Image).BILINEAR

        batch_images = []
        for _ in range(batch_size):
            img_path = random.choice(self.image_files)
            try:
                if not os.path.exists(img_path) or os.path.getsize(img_path) < 100:
                    batch_images.append(torch.randn(3, 256, 256))
                    continue

                # Context manager releases the file handle promptly.
                with Image.open(img_path) as opened:
                    img = opened.convert('RGB').resize((128, 128), bilinear)

                img_full = Image.new('RGB', (256, 256), (0, 0, 0))
                img_full.paste(img, (64, 64))

                img_tensor = torch.tensor(np.array(img_full), dtype=torch.float32).permute(2, 0, 1) / 255.0
                batch_images.append(img_tensor)
            except Exception:
                # Deliberate best effort: one undecodable file must not abort
                # a training step, so it is replaced by noise.
                batch_images.append(torch.randn(3, 256, 256))

        return torch.stack(batch_images)

    def _load_batch_audio(self, batch_size: int) -> torch.Tensor:
        """Return placeholder audio of shape ``(batch_size, 1, 2048)``.

        Audio decoding is not implemented: the result is Gaussian noise
        whether or not any audio files were discovered.
        """
        # TODO: decode clips from self.audio_files.
        return torch.randn(batch_size, 1, 2048)

    def _load_batch_video(self, batch_size: int) -> torch.Tensor:
        """Return placeholder video of shape ``(batch_size, 3, 8, 32, 32)``.

        Video decoding is not implemented: the result is Gaussian noise
        whether or not any video files were discovered.
        """
        # TODO: decode frames from self.video_files.
        return torch.randn(batch_size, 3, 8, 32, 32)

    def get_dataset_stats(self):
        """Return ``{'total_samples': int, 'sources': {source: count}}`` for the text samples."""
        sources: Dict[str, int] = {}
        for sample in self.samples:
            source = sample['source']
            sources[source] = sources.get(source, 0) + 1

        return {
            'total_samples': len(self.samples),
            'sources': sources
        }
| |
|
if __name__ == "__main__":
    # Smoke test: load from the current directory, then report what was found.
    demo_dataset = QuillanDataset()
    summary = demo_dataset.get_dataset_stats()

    print("\n📊 Dataset Statistics:")
    print(f"Total samples: {summary['total_samples']}")
    for source_name, sample_count in summary['sources'].items():
        print(f" {source_name}: {sample_count}")

    # Draw one small batch and show the shape (or type) of every entry.
    demo_batch = demo_dataset.get_training_batch(batch_size=2)
    print("\n🎯 Batch shapes:")
    for entry_name, entry in demo_batch.items():
        described = entry.shape if isinstance(entry, torch.Tensor) else type(entry)
        print(f" {entry_name}: {described}")
| |
|