#!/usr/bin/env python3
import os
import json
import logging
import random
from typing import List, Dict, Tuple
from datasets import load_dataset
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AgriQADatasetPreparator:
    def __init__(self, output_dir: str = "data", dataset_name: str = "shchoi83/agriQA"):
        self.output_dir = output_dir
        self.dataset_name = dataset_name
        self.dataset = None
        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

    def download_dataset(self) -> None:
        """Download the agriQA dataset from Hugging Face."""
        logger.info(f"Downloading dataset: {self.dataset_name}")
        self.dataset = load_dataset(self.dataset_name)
        logger.info(f"Dataset downloaded successfully. Train samples: {len(self.dataset['train'])}")

    def preprocess_for_chat(self, question: str, answer: str) -> str:
        """Format a question-answer pair for chat-model training."""
        # Use the ChatML template format expected by Qwen models
        formatted_text = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>"
        return formatted_text

    def clean_text(self, text: str) -> str:
        """Clean and normalize text."""
        if not text:
            return ""
        # Basic cleaning: strip, flatten line breaks, collapse extra whitespace
        text = text.strip()
        text = text.replace('\n', ' ').replace('\r', ' ')
        text = ' '.join(text.split())
        return text

    def filter_qa_pairs(self, question: str, answer: str) -> bool:
        """Filter out low-quality question-answer pairs."""
        # Remove pairs with very short or very long answers
        if len(answer) < 10 or len(answer) > 2000:
            return False
        # Remove pairs with very short questions
        if len(question) < 5:
            return False
        # Lenient filtering for agricultural content: allow some non-ASCII
        # characters that may appear in agricultural terms, but drop pairs
        # that are mostly non-English
        english_chars = sum(1 for c in question + answer if c.isascii())
        total_chars = len(question + answer)
        if total_chars > 0 and english_chars / total_chars < 0.7:
            return False
        return True

    def prepare_training_data(self) -> Tuple[List[str], List[Dict]]:
        """Clean, filter, and format every sample in the train split."""
        logger.info("Preparing training data...")
        formatted_data = []
        raw_data = []
        for item in tqdm(self.dataset['train'], desc="Processing samples"):
            question = self.clean_text(item['questions'])
            answer = self.clean_text(item['answers'])
            # Filter for quality
            if not self.filter_qa_pairs(question, answer):
                continue
            # Format for training
            formatted_text = self.preprocess_for_chat(question, answer)
            formatted_data.append(formatted_text)
            # Keep raw data for analysis
            raw_data.append({
                'question': question,
                'answer': answer,
                'text': item.get('text', '')
            })
        logger.info(f"Prepared {len(formatted_data)} training samples")
        return formatted_data, raw_data

    def save_data(self, formatted_data: List[str], raw_data: List[Dict]) -> None:
        """Write training data, raw QA pairs, and dataset statistics to disk.

        Note: raw_data covers all kept pairs (train and validation), while
        formatted_data is the training split only.
        """
        # Save formatted training data (each sample spans four lines because
        # the ChatML template embeds newlines)
        train_file = os.path.join(self.output_dir, "train_data.txt")
        with open(train_file, 'w', encoding='utf-8') as f:
            for item in formatted_data:
                f.write(item + '\n')

        # Save raw data for analysis
        raw_file = os.path.join(self.output_dir, "raw_data.json")
        with open(raw_file, 'w', encoding='utf-8') as f:
            json.dump(raw_data, f, ensure_ascii=False, indent=2)

        # Save statistics (guard against division by zero on an empty dataset)
        n = max(len(raw_data), 1)
        stats = {
            'total_samples': len(formatted_data),
            'avg_question_length': sum(len(item['question']) for item in raw_data) / n,
            'avg_answer_length': sum(len(item['answer']) for item in raw_data) / n,
            'dataset_info': {
                'source': self.dataset_name,
                'original_size': len(self.dataset['train'])
            }
        }
        stats_file = os.path.join(self.output_dir, "dataset_stats.json")
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)

        logger.info(f"Data saved to {self.output_dir}")
        logger.info(f"Training samples: {len(formatted_data)}")
        logger.info(f"Average question length: {stats['avg_question_length']:.1f} chars")
        logger.info(f"Average answer length: {stats['avg_answer_length']:.1f} chars")

    def create_validation_split(self, train_data: List[str], val_ratio: float = 0.1) -> Tuple[List[str], List[str]]:
        """Create a validation split from the training data."""
        # Shuffle with a fixed seed so the split is random but reproducible,
        # rather than always holding out the first samples in dataset order
        shuffled = list(train_data)
        random.Random(42).shuffle(shuffled)
        val_size = int(len(shuffled) * val_ratio)
        val_data = shuffled[:val_size]
        train_data_final = shuffled[val_size:]

        # Save validation data
        val_file = os.path.join(self.output_dir, "val_data.txt")
        with open(val_file, 'w', encoding='utf-8') as f:
            for item in val_data:
                f.write(item + '\n')

        logger.info(f"Created validation split: {len(val_data)} samples")
        return train_data_final, val_data

    def run(self) -> None:
        """Run the full preparation pipeline, skipping work already done."""
        logger.info("Starting dataset preparation...")

        # Check if preprocessed data already exists
        train_file = os.path.join(self.output_dir, "train_data.txt")
        val_file = os.path.join(self.output_dir, "val_data.txt")
        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("Preprocessed data already exists. Skipping data preparation.")
            # Each formatted sample spans four lines, so count samples by
            # their opening user marker rather than by non-empty lines
            with open(train_file, 'r', encoding='utf-8') as f:
                train_count = sum(1 for line in f if line.startswith('<|im_start|>user'))
            with open(val_file, 'r', encoding='utf-8') as f:
                val_count = sum(1 for line in f if line.startswith('<|im_start|>user'))
            logger.info(f"Training samples: {train_count}")
            logger.info(f"Validation samples: {val_count}")
            return

        # Download and load dataset
        self.download_dataset()

        # Prepare training data
        formatted_data, raw_data = self.prepare_training_data()

        # Create validation split
        logger.info("Creating validation split...")
        train_data_final, val_data = self.create_validation_split(formatted_data)

        # Save all data
        logger.info("Saving preprocessed data...")
        self.save_data(train_data_final, raw_data)

        logger.info("Dataset preparation completed successfully!")
        logger.info("Next step: Run tokenization with 'python src/data/tokenize_dataset.py'")


def main():
    preparator = AgriQADatasetPreparator()
    preparator.run()


if __name__ == "__main__":
    main()
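
# Illustrative example (the question and answer below are hypothetical,
# shown only to document the ChatML layout that preprocess_for_chat
# produces for each sample written to train_data.txt):
#
#   <|im_start|>user
#   How often should maize be irrigated during the vegetative stage?<|im_end|>
#   <|im_start|>assistant
#   Irrigation frequency depends on soil type and rainfall; sandy soils may
#   need water every few days during vegetative growth.<|im_end|>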