| from datasets import load_dataset
|
| from tokenizers import Tokenizer
|
| from torch.utils.data import DataLoader, Dataset
|
| import torch
|
| from transformer_chat import TransformerChatbot
|
| import pandas as pd
|
| import random
|
|
|
|
|
# Load the ATIS (Airline Travel Information System) training split from the
# Hugging Face Hub; each item is expected to carry 'text' and 'intent' fields
# (both are read below).
raw_dataset = load_dataset("tuetschek/atis", split="train")


# Pre-trained tokenizer serialized by the `tokenizers` library.
# NOTE(review): assumes tokenizer.json defines [CLS], [SEP] and [PAD]
# special tokens — token_to_id() returns None otherwise; confirm.
tokenizer = Tokenizer.from_file('tokenizer.json')
|
|
|
|
|
# Response table hoisted to module level so the large, constant dict is built
# once at import time instead of being re-created on every call.
_INTENT_RESPONSES = {
    'atis_flight': [
        "I can help you with flight information. What specific details do you need?",
        "I'll search for flights matching your criteria. Please provide departure and arrival cities.",
        "Let me find available flights for you. When would you like to travel?"
    ],
    'atis_flight_no': [
        "I can help you with flight number information. Please provide the flight number.",
        "Let me search for details about that flight number.",
        "I'll look up information for that specific flight."
    ],
    'atis_airfare': [
        "I can help you find airfare information. What's your travel route?",
        "Let me search for the best airfare options for your trip.",
        "I'll check current airfare prices for your destination."
    ],
    'atis_airline': [
        "I can help you with airline information. Which airline are you looking for?",
        "Let me provide information about that airline.",
        "I'll search for details about the airline you mentioned."
    ],
    'atis_abbreviation': [
        "I can help you with airport abbreviations. Which abbreviation do you need?",
        "Let me explain that airport abbreviation for you.",
        "I'll provide the full name for that airport code."
    ],
    'atis_airport': [
        "I can help you with airport information. Which airport are you looking for?",
        "Let me provide details about that airport.",
        "I'll search for information about the airport you mentioned."
    ],
    'atis_distance': [
        "I can help you calculate distances between airports. Which airports are you interested in?",
        "Let me calculate the distance for you.",
        "I'll provide distance information between those locations."
    ],
    'atis_ground_service': [
        "I can help you with ground transportation services. What type of service do you need?",
        "Let me find ground transportation options for you.",
        "I'll search for available ground services at your destination."
    ],
    'atis_aircraft': [
        "I can help you with aircraft information. What type of aircraft are you looking for?",
        "Let me provide details about that aircraft type.",
        "I'll search for information about the aircraft you mentioned."
    ],
    'atis_capacity': [
        "I can help you with capacity information. What specific capacity details do you need?",
        "Let me check the capacity for that flight or aircraft.",
        "I'll provide capacity information for your query."
    ],
    'atis_quantity': [
        "I can help you with quantity information. What specific quantity are you looking for?",
        "Let me check the quantity for that item or service.",
        "I'll provide quantity information for your request."
    ],
    'atis_meal': [
        "I can help you with meal information. What type of meal service are you looking for?",
        "Let me check meal options for your flight.",
        "I'll provide information about meal services available."
    ],
    'atis_cheapest': [
        "I can help you find the cheapest options. What's your travel route?",
        "Let me search for the most affordable options for your trip.",
        "I'll find the cheapest flights or services for you."
    ],
    'atis_restriction': [
        "I can help you with travel restrictions. What type of restrictions are you asking about?",
        "Let me check the restrictions for your travel plans.",
        "I'll provide information about travel restrictions."
    ],
    'atis_day_name': [
        "I can help you with day information. What specific day are you looking for?",
        "Let me check the schedule for that day.",
        "I'll provide information about flights or services on that day."
    ]
}

# Fallback replies for intents without a dedicated entry above.
_GENERIC_RESPONSES = [
    "I can help you with that. Please provide more details.",
    "Let me assist you with your request.",
    "I'll help you find the information you need."
]

# Extra replies appended whenever the user text mentions a flight.
_FLIGHT_EXTRA_RESPONSES = [
    "I can help you book a flight. What are your travel dates?",
    "Let me search for available flights for you.",
    "I'll help you find the best flight options."
]


def create_response_for_intent(intent, text):
    """Create a synthetic assistant response for an ATIS intent.

    Args:
        intent: ATIS intent label, e.g. 'atis_flight'. Unknown intents fall
            back to a small set of generic replies.
        text: the user utterance; if it mentions "flight" (case-insensitive),
            extra flight-booking replies are added to the candidate pool.

    Returns:
        One candidate response string chosen uniformly at random.
    """
    # Copy the pool so extending it never mutates the shared module-level lists.
    candidates = list(_INTENT_RESPONSES.get(intent, _GENERIC_RESPONSES))

    if "flight" in text.lower():
        candidates.extend(_FLIGHT_EXTRA_RESPONSES)

    return random.choice(candidates)
|
|
|
|
|
def create_training_pairs():
    """Build (question, response) token-id pairs from the ATIS training split.

    For every dataset item, a synthetic response is generated from its intent
    label, both sides are tokenized, and each sequence is wrapped in
    [CLS] ... [SEP] special tokens.

    Returns:
        list[dict]: one dict per example with keys 'question_ids',
        'response_ids' (lists of token ids including the special tokens)
        and 'question_len' / 'response_len' (their lengths).

    Uses module-level `raw_dataset`, `tokenizer` and
    `create_response_for_intent`.
    """
    # Special-token ids are loop-invariant: look them up once, not per item.
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")

    training_data = []
    for item in raw_dataset:
        question = item['text']
        intent = item['intent']
        response = create_response_for_intent(intent, question)

        question_ids = [cls_id] + tokenizer.encode(question).ids + [sep_id]
        response_ids = [cls_id] + tokenizer.encode(response).ids + [sep_id]

        training_data.append({
            'question_ids': question_ids,
            'response_ids': response_ids,
            'question_len': len(question_ids),
            'response_len': len(response_ids)
        })

    return training_data
|
|
|
|
|
class AtisGenerationDataset(Dataset):
    """Torch Dataset yielding fixed-length, padded question/response pairs.

    Each item is returned as a 4-tuple of tensors:
    (question_ids, response_ids, question_len, response_len), where both
    id sequences are truncated/padded to max_length // 2 tokens.
    """

    def __init__(self, training_data, tokenizer, max_length=128):
        """
        Args:
            training_data: list of dicts as produced by create_training_pairs().
            tokenizer: tokenizer providing token_to_id(); must define "[PAD]".
            max_length: combined budget; each side gets max_length // 2 tokens.
        """
        self.training_data = training_data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Bug fix: use the tokenizer passed in (the original __getitem__
        # reached for the module-level `tokenizer` global). The pad id is
        # item-invariant, so cache it once here.
        self.pad_id = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.training_data)

    def __getitem__(self, idx):
        item = self.training_data[idx]
        half = self.max_length // 2

        # Truncate each side to its budget, then right-pad to exactly `half`.
        question_ids = item['question_ids'][:half]
        response_ids = item['response_ids'][:half]
        question_ids = question_ids + [self.pad_id] * (half - len(question_ids))
        response_ids = response_ids + [self.pad_id] * (half - len(response_ids))

        # NOTE: the *_len values are the original (pre-truncation) lengths.
        return (
            torch.tensor(question_ids),
            torch.tensor(response_ids),
            torch.tensor(item['question_len']),
            torch.tensor(item['response_len'])
        )
|
|
|
|
|
# Materialize all (question, response) token-id pairs up front.
print("Creating training data...")
training_data = create_training_pairs()
print(f"Created {len(training_data)} training pairs")


# Wrap the pairs in a Dataset and batch/shuffle them for training.
atis_dataset = AtisGenerationDataset(training_data, tokenizer)
dataloader = DataLoader(atis_dataset, batch_size=16, shuffle=True)
|
|
|
|
|
# Vocabulary size drives both the token embedding table and the generator head.
vocab_size = tokenizer.get_vocab_size()
# NOTE(review): TransformerChatbot is project-local; the meanings of
# num_roles/max_turns/num_slots are assumed from the argument names — confirm
# against transformer_chat.py. num_slots is set to the number of distinct
# intent labels in the ATIS training split.
model = TransformerChatbot(
    vocab_size=vocab_size,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    num_roles=2,
    max_turns=16,
    num_slots=len(set(item['intent'] for item in raw_dataset)),
    dropout=0.1
)
|
|
|
|
|
# Prefer GPU when available; all batches are moved to the same device below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Padding positions must not contribute to the loss, so [PAD] is ignored.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[PAD]"))
|
|
|
print("Starting training...")
# Fixed 10-epoch run; only the generation head is trained (slot_logits unused).
for epoch in range(10):
    model.train()
    total_loss = 0
    # question_lens/response_lens are yielded by the dataset but not used here.
    for batch_idx, (question_ids, response_ids, question_lens, response_lens) in enumerate(dataloader):
        question_ids = question_ids.to(device)
        response_ids = response_ids.to(device)

        batch_size, seq_len = question_ids.shape

        # Single-turn data: role and turn indices are all-zero placeholders.
        # NOTE(review): assumes role/turn id 0 is a valid "user"/first-turn
        # embedding index in TransformerChatbot — confirm.
        roles = torch.zeros_like(question_ids)
        turns = torch.zeros_like(question_ids)

        # Forward pass is assumed to return (generation logits, slot logits);
        # the same role/turn tensors are passed for encoder and decoder sides.
        gen_logits, slot_logits = model(
            question_ids, response_ids,
            roles, roles,
            turns, turns
        )

        # Teacher forcing with a one-token shift: position t predicts token t+1.
        target_ids = response_ids[:, 1:]
        gen_logits = gen_logits[:, :-1, :]

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CrossEntropyLoss;
        # [PAD] targets are ignored via the loss's ignore_index.
        gen_logits_flat = gen_logits.reshape(-1, vocab_size)
        target_ids_flat = target_ids.reshape(-1)
        loss = loss_fn(gen_logits_flat, target_ids_flat)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")
|
|
|
|
|
print("Saving model...")
# Persist weights only (state_dict); reloading requires reconstructing
# TransformerChatbot with the same hyperparameters as above.
torch.save(model.state_dict(), 'atis_transformer.pt')
print("Training completed!")