#!/usr/bin/env python3
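"""Prepare the agriQA dataset for chat-model fine-tuning.

Part of AgriQA-assistant (src/data/prepare_dataset.py). Downloads
shchoi83/agriQA from the Hugging Face Hub, cleans and filters the QA pairs,
formats them with Qwen-style ChatML tags, and writes train_data.txt,
val_data.txt, raw_data.json, and dataset_stats.json to the output directory.
"""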
import os
import json
import logging
import random
from typing import List, Dict, Tuple
from datasets import load_dataset
from tqdm import tqdm
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class AgriQADatasetPreparator:
def __init__(self, output_dir: str = "data", dataset_name: str = "shchoi83/agriQA"):
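        """Store configuration and ensure the output directory exists.

        Args:
            output_dir: Directory for the processed files.
            dataset_name: Dataset identifier on the Hugging Face Hub.
        """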
self.output_dir = output_dir
self.dataset_name = dataset_name
self.dataset = None
# Create output directory
os.makedirs(self.output_dir, exist_ok=True)
def download_dataset(self) -> None:
"""Download the agriQA dataset from Hugging Face."""
logger.info(f"Downloading dataset: {self.dataset_name}")
self.dataset = load_dataset(self.dataset_name)
logger.info(f"Dataset downloaded successfully. Train samples: {len(self.dataset['train'])}")
def preprocess_for_chat(self, question: str, answer: str) -> str:
"""Format question-answer pair for chat model training."""
# Use the chat template format for Qwen models
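        # Rendered sample (ChatML-style tags):
        #   <|im_start|>user
        #   {question}<|im_end|>
        #   <|im_start|>assistant
        #   {answer}<|im_end|>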
formatted_text = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>"
return formatted_text
def clean_text(self, text: str) -> str:
"""Clean and normalize text."""
if not text:
return ""
# Basic cleaning
text = text.strip()
text = text.replace('\n', ' ').replace('\r', ' ')
text = ' '.join(text.split()) # Remove extra whitespace
return text
def filter_qa_pairs(self, question: str, answer: str) -> bool:
"""Filter out low-quality question-answer pairs."""
# Remove pairs with very short or very long responses
if len(answer) < 10 or len(answer) > 2000:
return False
# Remove pairs with very short questions
if len(question) < 5:
return False
        # Lenient language filter: allow some non-ASCII characters (e.g. in
        # agricultural terms) but drop pairs that are mostly non-ASCII,
        # using the ASCII ratio as a rough proxy for English content.
        ascii_chars = sum(1 for c in question + answer if c.isascii())
        total_chars = len(question + answer)
        if total_chars > 0 and ascii_chars / total_chars < 0.7:
            return False
return True
def prepare_training_data(self) -> Tuple[List[str], List[Dict]]:
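        """Clean, filter, and chat-format every QA pair in the train split.

        Returns a tuple of (chat-formatted training strings, cleaned raw
        pairs kept for statistics and analysis).
        """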
logger.info("Preparing training data...")
formatted_data = []
raw_data = []
for item in tqdm(self.dataset['train'], desc="Processing samples"):
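            # 'questions' and 'answers' are the column names used by the
            # shchoi83/agriQA dataset.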
question = self.clean_text(item['questions'])
answer = self.clean_text(item['answers'])
# Filter quality
if not self.filter_qa_pairs(question, answer):
continue
# Format for training
formatted_text = self.preprocess_for_chat(question, answer)
formatted_data.append(formatted_text)
# Keep raw data for analysis
raw_data.append({
'question': question,
'answer': answer,
'text': item.get('text', '')
})
logger.info(f"Prepared {len(formatted_data)} training samples")
return formatted_data, raw_data
def save_data(self, formatted_data: List[str], raw_data: List[Dict]) -> None:
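        """Write the training text, raw QA pairs, and summary statistics to disk."""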
        # Save formatted training data. Note that each sample spans several
        # physical lines, since the chat template itself contains newlines.
train_file = os.path.join(self.output_dir, "train_data.txt")
with open(train_file, 'w', encoding='utf-8') as f:
for item in formatted_data:
f.write(item + '\n')
# Save raw data for analysis
raw_file = os.path.join(self.output_dir, "raw_data.json")
with open(raw_file, 'w', encoding='utf-8') as f:
json.dump(raw_data, f, ensure_ascii=False, indent=2)
        # Save summary statistics. raw_data covers every filtered pair
        # (including any later held out for validation), while total_samples
        # counts only the samples written to train_data.txt.
        n_raw = max(len(raw_data), 1)  # guard against an empty dataset
        stats = {
            'total_samples': len(formatted_data),
            'avg_question_length': sum(len(item['question']) for item in raw_data) / n_raw,
            'avg_answer_length': sum(len(item['answer']) for item in raw_data) / n_raw,
'dataset_info': {
'source': self.dataset_name,
'original_size': len(self.dataset['train'])
}
}
stats_file = os.path.join(self.output_dir, "dataset_stats.json")
with open(stats_file, 'w', encoding='utf-8') as f:
json.dump(stats, f, indent=2)
logger.info(f"Data saved to {self.output_dir}")
logger.info(f"Training samples: {len(formatted_data)}")
logger.info(f"Average question length: {stats['avg_question_length']:.1f} chars")
logger.info(f"Average answer length: {stats['avg_answer_length']:.1f} chars")
    def create_validation_split(self, train_data: List[str], val_ratio: float = 0.1) -> Tuple[List[str], List[str]]:
        """Create a validation split from the training data."""
        # Shuffle a copy with a fixed seed so the split is reproducible and
        # not biased by any ordering in the source data.
        data = list(train_data)
        random.Random(42).shuffle(data)
        val_size = int(len(data) * val_ratio)
        val_data = data[:val_size]
        train_data_final = data[val_size:]
# Save validation data
val_file = os.path.join(self.output_dir, "val_data.txt")
with open(val_file, 'w', encoding='utf-8') as f:
for item in val_data:
f.write(item + '\n')
logger.info(f"Created validation split: {len(val_data)} samples")
return train_data_final, val_data
def run(self) -> None:
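        """Run the full pipeline: download, clean, split, and save."""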
logger.info("Starting dataset preparation...")
# Check if preprocessed data already exists
train_file = os.path.join(self.output_dir, "train_data.txt")
val_file = os.path.join(self.output_dir, "val_data.txt")
if os.path.exists(train_file) and os.path.exists(val_file):
logger.info("Preprocessed data already exists. Skipping data preparation.")
            # Each formatted sample spans multiple lines (the chat template
            # contains newlines), so count samples by their opening tag
            # instead of by non-empty lines.
            with open(train_file, 'r', encoding='utf-8') as f:
                train_count = sum(1 for line in f if line.startswith('<|im_start|>user'))
            with open(val_file, 'r', encoding='utf-8') as f:
                val_count = sum(1 for line in f if line.startswith('<|im_start|>user'))
logger.info(f"Training samples: {train_count}")
logger.info(f"Validation samples: {val_count}")
return
# Download and load dataset
logger.info("Downloading agriQA dataset...")
self.download_dataset()
# Prepare training data
logger.info("Preparing training data...")
formatted_data, raw_data = self.prepare_training_data()
# Create validation split
logger.info("Creating validation split...")
train_data_final, val_data = self.create_validation_split(formatted_data)
# Save all data
logger.info("Saving preprocessed data...")
self.save_data(train_data_final, raw_data)
logger.info("Dataset preparation completed successfully!")
logger.info("Next step: Run tokenization with 'python src/data/tokenize_dataset.py'")
def main():
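    """Prepare the agriQA dataset with default settings."""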
preparator = AgriQADatasetPreparator()
preparator.run()
if __name__ == "__main__":
main()