import re
import json
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from logger_config import config_logger

logger = config_logger(__name__)


@dataclass
class TaskmasterDialogue:
    """A single Taskmaster-1 dialogue: raw metadata plus its cleaned turns."""
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))


class RawDataProcessingConfig:
    """
    Simple config for raw dataset processing.
    """
    def __init__(
        self,
        debug: bool = True,
        max_length: int = 512,
        min_turns: int = 4,
        min_user_words: int = 3
    ):
        self.debug = debug
        self.max_length = max_length
        self.min_turns = min_turns
        self.min_user_words = min_user_words


class TaskmasterProcessor:
    """
    Loads Taskmaster-1 dialogues and extracts a domain label for each one,
    then cleans, filters, and converts them to the pipeline format
    (see the usage sketch at the bottom of this module).
    """
    def __init__(self, config: RawDataProcessingConfig):
        self.config = config

    def load_taskmaster_dataset(
        self,
        base_dir: str,
        max_examples: Optional[int] = None
    ) -> List[TaskmasterDialogue]:
        """
        Load and parse the Taskmaster-1 JSON files (self-dialogs and woz-dialogs).
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        # The ontology is loaded for completeness but is not used further yet.
        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded ontology with {len(ontology)} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []

        for file_key in ["self-dialogs", "woz-dialogs"]:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                conversation_id = d.get("conversation_id", "")
                instruction_id = d.get("instruction_id", None)
                scenario_text = d.get("scenario", "")

                utterances = d.get("utterances", [])
                turns = self._process_utterances(utterances, conversation_id)

                domain = self._extract_domain(scenario_text, turns)

                new_dlg = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns,
                    original_metadata={}
                )
                dialogues.append(new_dlg)

                if max_examples and len(dialogues) >= max_examples:
                    break

            # Without this outer check, the inner `break` would only stop the
            # current file and loading would continue with the next one.
            if max_examples and len(dialogues) >= max_examples:
                break

        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues
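
    # The loader above assumes each record in self-dialogs.json / woz-dialogs.json
    # roughly follows this shape. The field names are exactly the ones read by
    # load_taskmaster_dataset(); the sample values are illustrative only:
    #
    #   {
    #       "conversation_id": "dlg-...",
    #       "instruction_id": "...",
    #       "scenario": "...",
    #       "utterances": [
    #           {"speaker": "USER", "text": "..."},
    #           {"speaker": "ASSISTANT", "text": "..."}
    #       ]
    #   }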

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Combine the scenario text with the text of every turn to detect the
        domain more robustly.
        """
        combined_text = scenario.lower()
        for turn in turns:
            txt = turn.get('text', '').lower()
            combined_text += " " + txt

        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat|hungry)\b',
            'movie': r'\b(movie|cinema|film|ticket|showtime|theater|flick|screening)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b',
            'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice)\b',
            'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b'
        }

        # First match wins; dict insertion order defines the priority.
        for domain, pattern in domain_patterns.items():
            if re.search(pattern, combined_text):
                if self.config.debug:
                    logger.info(f"Matched domain: {domain} in scenario/turns")
                return domain

        if self.config.debug:
            logger.info("No domain match, returning 'other'")
        return 'other'
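
    # A quick sketch of the matching behavior (hypothetical inputs):
    #
    #   scenario="book me a table for two", turns=[]  -> 'restaurant'
    #   scenario="i need an oil change",    turns=[]  -> 'auto'
    #   scenario="hello there",             turns=[]  -> 'other'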

    def _clean_text(self, text: str) -> str:
        """
        Simple text normalization.
        """
        # Collapse runs of whitespace, then squash repeated punctuation ("!!!" -> "!").
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'([!?.,])\1+', r'\1', text)
        return text.strip()

    def _is_numeric_line(self, text: str) -> bool:
        """
        Return True if the line is purely digits/punctuation/spaces,
        e.g. "4 3 13" and similar lines found in the Taskmaster-1 dataset.
        """
        pattern = r'^[\s]*[\d]+([\s\d.,]+)*[\s]*$'
        return bool(re.match(pattern, text))
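
    # Sanity sketch for _is_numeric_line (hypothetical inputs):
    #
    #   _is_numeric_line("4 3 13")    -> True   (digits and spaces only)
    #   _is_numeric_line("10.30, 12") -> True   (digits plus . , separators)
    #   _is_numeric_line("4 people")  -> False  (contains letters)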

    def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Filter out dialogues that don't meet the minimum length requirements
        and convert the survivors to the pipeline format:
        {
            "dialogue_id": "...",
            "domain": "...",
            "turns": [ {"speaker": "user", "text": "..."}, ... ]
        }
        """
        total = len(dialogues)
        invalid = 0
        too_few_turns = 0
        short_user_turns = 0
        results = []

        for dlg in dialogues:
            if not dlg.validate():
                invalid += 1
                continue

            if len(dlg.turns) < self.config.min_turns:
                too_few_turns += 1
                continue

            # Drop the whole dialogue if any user turn is below the word minimum.
            keep = True
            for turn in dlg.turns:
                if turn['speaker'] == 'user':
                    words_count = len(turn['text'].split())
                    if words_count < self.config.min_user_words:
                        short_user_turns += 1
                        keep = False
                        break

            if not keep:
                continue

            pipeline_dlg = {
                'dialogue_id': dlg.conversation_id,
                'domain': dlg.domain,
                'turns': dlg.turns
            }
            results.append(pipeline_dlg)

        # The total > 0 guard avoids a ZeroDivisionError on an empty input list.
        if self.config.debug and total > 0:
            logger.info("\nFiltering Statistics:")
            logger.info(f"Total dialogues: {total}")
            logger.info(f"Invalid dialogues: {invalid}")
            logger.info(f"Too few turns: {too_few_turns}")
            logger.info(f"Short user turns: {short_user_turns}")
            logger.info(f"Remaining dialogues: {len(results)}")
            logger.info(f"Filtering rate: {((total - len(results)) / total) * 100:.1f}%\n")

        return results

    def _process_utterances(
        self,
        utterances: List[Dict[str, Any]],
        conversation_id: str = "unknown"
    ) -> List[Dict[str, str]]:
        """
        Clean and filter raw utterances into {speaker, text} turns,
        logging how many utterances each filter removes.
        """
        total = len(utterances)
        empty = 0
        numeric = 0
        too_short = 0
        cleaned_turns = []

        for utt in utterances:
            # Taskmaster-1 marks speakers as "USER" / "ASSISTANT".
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
            raw_text = utt.get('text', '').strip()

            text = self._clean_text(raw_text)

            if not text:
                empty += 1
                continue

            if self._is_numeric_line(text):
                numeric += 1
                continue

            if len(text.split()) < 3:
                too_short += 1
                continue

            cleaned_turns.append({
                'speaker': speaker,
                'text': text
            })

        if self.config.debug and total > 0:
            # The caller supplies conversation_id; individual utterance dicts
            # do not carry one.
            logger.info(f"\nUtterance Cleaning Statistics (Dialogue {conversation_id}):")
            logger.info(f"Total utterances: {total}")
            logger.info(f"Empty/blank: {empty}")
            logger.info(f"Numeric only: {numeric}")
            logger.info(f"Too short (<3 words): {too_short}")
            logger.info(f"Remaining turns: {len(cleaned_turns)}")
            logger.info(f"Filtering rate: {((total - len(cleaned_turns)) / total) * 100:.1f}%\n")

        return cleaned_turns
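

# A minimal end-to-end usage sketch. The directory and output paths below are
# assumptions for illustration; point base_dir at a local copy of the
# Taskmaster-1 JSON files (self-dialogs.json, woz-dialogs.json, ontology.json).
if __name__ == "__main__":
    config = RawDataProcessingConfig(debug=True)
    processor = TaskmasterProcessor(config)

    raw_dialogues = processor.load_taskmaster_dataset(
        base_dir="data/taskmaster-1",  # hypothetical location of the JSON files
        max_examples=100
    )
    pipeline_dialogues = processor.filter_and_convert(raw_dialogues)

    # Persist the pipeline-format records produced by filter_and_convert().
    with open("taskmaster_pipeline.json", "w", encoding="utf-8") as f:
        json.dump(pipeline_dialogues, f, indent=2, ensure_ascii=False)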