Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import os | |
| import json | |
| import logging | |
| from typing import List, Dict, Tuple | |
| from datasets import load_dataset, Dataset | |
| from tqdm import tqdm | |
# Configure process-wide logging once at import time: INFO and above,
# each record prefixed with its timestamp and severity.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
class AgriQADatasetPreparator:
    """Download a question/answer dataset and write chat-formatted text files.

    The pipeline (see ``run``) downloads the dataset, cleans and filters each
    question/answer pair, wraps the pair in ``<|im_start|>``/``<|im_end|>``
    chat markers, splits off a validation set and writes the results into
    ``output_dir``.
    """

    # First line of every formatted sample (see ``preprocess_for_chat``).
    # Cleaned questions and answers contain no newlines, so this exact line
    # appears once per sample in the output files.
    _SAMPLE_START = "<|im_start|>user"

    def __init__(self, output_dir: str = "data", dataset_name: str = "shchoi83/agriQA"):
        """Store the configuration and make sure the output directory exists.

        Args:
            output_dir: Directory that receives every generated file.
            dataset_name: Dataset identifier handed to ``load_dataset``.
        """
        self.output_dir = output_dir
        self.dataset_name = dataset_name
        # Populated by download_dataset(); None until then.
        self.dataset = None

        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

    def download_dataset(self) -> None:
        """Download the agriQA dataset from Hugging Face."""
        logger.info(f"Downloading dataset: {self.dataset_name}")
        self.dataset = load_dataset(self.dataset_name)
        logger.info(f"Dataset downloaded successfully. Train samples: {len(self.dataset['train'])}")

    def preprocess_for_chat(self, question: str, answer: str) -> str:
        """Format question-answer pair for chat model training.

        Returns:
            A four-line string: user marker, question, assistant marker,
            answer, with each turn closed by ``<|im_end|>``.
        """
        # Use the chat template format for Qwen models
        formatted_text = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>"
        return formatted_text

    def clean_text(self, text: str) -> str:
        """Clean and normalize text.

        Falsy input (``None`` or ``""``) yields ``""``. Otherwise line breaks
        become spaces and every run of whitespace collapses to one space.
        """
        if not text:
            return ""

        # Basic cleaning
        text = text.strip()
        text = text.replace('\n', ' ').replace('\r', ' ')
        text = ' '.join(text.split())  # Remove extra whitespace
        return text

    def filter_qa_pairs(self, question: str, answer: str) -> bool:
        """Filter out low-quality question-answer pairs.

        Returns:
            ``True`` when the pair should be kept, ``False`` otherwise.
        """
        # Remove pairs with very short or very long responses
        if len(answer) < 10 or len(answer) > 2000:
            return False

        # Remove pairs with very short questions
        if len(question) < 5:
            return False

        # More lenient filtering for agricultural content
        # Allow some non-ASCII characters that might be in agricultural terms
        # but filter out completely non-English content
        combined = question + answer
        english_chars = sum(1 for c in combined if c.isascii())
        total_chars = len(combined)
        if total_chars > 0 and english_chars / total_chars < 0.7:
            return False

        return True

    def prepare_training_data(self) -> Tuple[List[str], List[Dict]]:
        """Clean, filter and format every record of the ``train`` split.

        Requires ``download_dataset`` to have been called first.

        Returns:
            A pair ``(formatted_data, raw_data)`` of equal length:
            ``formatted_data`` holds the chat-formatted strings and
            ``raw_data`` the matching cleaned question/answer records.
        """
        logger.info("Preparing training data...")

        formatted_data = []
        raw_data = []

        for item in tqdm(self.dataset['train'], desc="Processing samples"):
            question = self.clean_text(item['questions'])
            answer = self.clean_text(item['answers'])

            # Filter quality
            if not self.filter_qa_pairs(question, answer):
                continue

            # Format for training
            formatted_text = self.preprocess_for_chat(question, answer)
            formatted_data.append(formatted_text)

            # Keep raw data for analysis
            raw_data.append({
                'question': question,
                'answer': answer,
                'text': item.get('text', '')
            })

        logger.info(f"Prepared {len(formatted_data)} training samples")
        return formatted_data, raw_data

    @staticmethod
    def _average_length(records: List[Dict], key: str) -> float:
        """Return the mean ``len(record[key])``, or 0.0 for no records."""
        # Guard the division: filtering may legitimately reject every pair.
        if not records:
            return 0.0
        return sum(len(record[key]) for record in records) / len(records)

    @staticmethod
    def _count_samples(path: str) -> int:
        """Count the formatted samples stored in the text file at *path*.

        A sample spans several lines, so only the lines that are exactly the
        sample-start marker are counted.
        """
        marker = AgriQADatasetPreparator._SAMPLE_START
        with open(path, 'r', encoding='utf-8') as f:
            return sum(1 for line in f if line.rstrip('\n') == marker)

    def save_data(self, formatted_data: List[str], raw_data: List[Dict]) -> None:
        """Write training text, raw records and summary statistics to disk.

        Creates ``train_data.txt``, ``raw_data.json`` and
        ``dataset_stats.json`` inside ``output_dir``. Reads
        ``self.dataset['train']`` for the original size, so the dataset must
        already be loaded.

        Note: ``total_samples`` counts ``formatted_data`` while the averages
        are computed over ``raw_data``; ``run`` passes the post-split training
        list together with the pre-split raw records, so the two can differ.

        Args:
            formatted_data: Chat-formatted samples, written one after another.
            raw_data: Records with ``'question'`` and ``'answer'`` keys.
        """
        # Save formatted training data
        train_file = os.path.join(self.output_dir, "train_data.txt")
        with open(train_file, 'w', encoding='utf-8') as f:
            for item in formatted_data:
                f.write(item + '\n')

        # Save raw data for analysis
        raw_file = os.path.join(self.output_dir, "raw_data.json")
        with open(raw_file, 'w', encoding='utf-8') as f:
            json.dump(raw_data, f, ensure_ascii=False, indent=2)

        # Save statistics (averages fall back to 0.0 when nothing survived
        # filtering instead of raising ZeroDivisionError)
        stats = {
            'total_samples': len(formatted_data),
            'avg_question_length': self._average_length(raw_data, 'question'),
            'avg_answer_length': self._average_length(raw_data, 'answer'),
            'dataset_info': {
                'source': self.dataset_name,
                'original_size': len(self.dataset['train'])
            }
        }

        stats_file = os.path.join(self.output_dir, "dataset_stats.json")
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)

        logger.info(f"Data saved to {self.output_dir}")
        logger.info(f"Training samples: {len(formatted_data)}")
        logger.info(f"Average question length: {stats['avg_question_length']:.1f} chars")
        logger.info(f"Average answer length: {stats['avg_answer_length']:.1f} chars")

    def create_validation_split(self, train_data: List[str], val_ratio: float = 0.1) -> Tuple[List[str], List[str]]:
        """Create validation split from training data.

        The first ``int(len(train_data) * val_ratio)`` items become the
        validation set (no shuffling) and are written to ``val_data.txt``.

        Returns:
            A pair ``(remaining_training_items, validation_items)``.
        """
        val_size = int(len(train_data) * val_ratio)
        val_data = train_data[:val_size]
        train_data_final = train_data[val_size:]

        # Save validation data
        val_file = os.path.join(self.output_dir, "val_data.txt")
        with open(val_file, 'w', encoding='utf-8') as f:
            for item in val_data:
                f.write(item + '\n')

        logger.info(f"Created validation split: {len(val_data)} samples")
        return train_data_final, val_data

    def run(self) -> None:
        """Run the whole pipeline, or report on existing output and stop.

        When both ``train_data.txt`` and ``val_data.txt`` already exist in
        ``output_dir`` nothing is downloaded or rewritten; only the number of
        samples found in each file is logged.
        """
        logger.info("Starting dataset preparation...")

        # Check if preprocessed data already exists
        train_file = os.path.join(self.output_dir, "train_data.txt")
        val_file = os.path.join(self.output_dir, "val_data.txt")

        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("Preprocessed data already exists. Skipping data preparation.")

            # Count samples, not lines: every sample occupies several lines.
            train_count = self._count_samples(train_file)
            val_count = self._count_samples(val_file)

            logger.info(f"Training samples: {train_count}")
            logger.info(f"Validation samples: {val_count}")
            return

        # Download and load dataset
        logger.info("Downloading agriQA dataset...")
        self.download_dataset()

        # Prepare training data
        logger.info("Preparing training data...")
        formatted_data, raw_data = self.prepare_training_data()

        # Create validation split
        logger.info("Creating validation split...")
        train_data_final, val_data = self.create_validation_split(formatted_data)

        # Save all data
        logger.info("Saving preprocessed data...")
        self.save_data(train_data_final, raw_data)

        logger.info("Dataset preparation completed successfully!")
        logger.info("Next step: Run tokenization with 'python src/data/tokenize_dataset.py'")
def main():
    """Build a preparator with default settings and run the full pipeline."""
    AgriQADatasetPreparator().run()


# Only start the pipeline when executed as a script, not on import.
if __name__ == "__main__":
    main()