Spaces:
Runtime error
Runtime error
import pandas as pd
import gradio as gr
from datasets import load_dataset
from sklearn.utils import resample
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
# Step 1: Load the customer-support dataset from the Hugging Face Hub.
dataset = load_dataset("bitext/Bitext-customer-support-llm-chatbot-training-dataset")

# Step 2: Keep a reproducible 20% sample of the training split for testing.
train_split = dataset["train"]
sample_size = int(len(train_split) * 0.2)
sampled_data = train_split.shuffle(seed=42).select(range(sample_size))

# Work with just the two columns we need, as a DataFrame.
sampled_data_df = pd.DataFrame(sampled_data)
df_limited = sampled_data_df[['instruction', 'response']]

# Step 3: Handle class imbalance by oversampling the minority responses so
# the single most frequent response no longer dominates the training data.
most_common_response = df_limited['response'].mode()[0]
majority_mask = df_limited['response'] == most_common_response
df_majority = df_limited[majority_mask]
df_minority = df_limited[~majority_mask]
df_minority_upsampled = resample(df_minority, replace=True, n_samples=len(df_majority), random_state=42)
df_balanced = pd.concat([df_majority, df_minority_upsampled])
# Step 4: Load the pre-trained DialoGPT model with its matching tokenizer.
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2-style tokenizers ship without a pad token; reuse EOS so that
# batched padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Step 5: Preprocess the data for training.
def preprocess_data_for_training(df, max_length=512):
    """Tokenize instruction/response pairs into causal-LM training features.

    Each example is serialized as ``instruction <eos> response <eos>`` so the
    decoder-only model learns to continue an instruction with its response
    (tokenizing the two columns separately only makes sense for an
    encoder-decoder model). Labels are a copy of the input ids with padding
    positions set to -100, the ignore index used by Hugging Face's loss, and
    are NOT shifted here: causal-LM models shift labels internally, so the
    previous manual ``roll(1)`` double-shifted them and also computed loss on
    pad tokens.

    Args:
        df: DataFrame with 'instruction' and 'response' string columns.
        max_length: Maximum token length per example; longer ones are truncated.

    Returns:
        Dict with 'input_ids', 'attention_mask' and 'labels' tensors, all of
        identical shape (batch, seq_len).
    """
    texts = [
        f"{instruction}{tokenizer.eos_token}{response}{tokenizer.eos_token}"
        for instruction, response in zip(df['instruction'], df['response'])
    ]
    enc = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    labels = enc['input_ids'].clone()
    labels[enc['attention_mask'] == 0] = -100  # never train on padding
    return {'input_ids': enc['input_ids'], 'attention_mask': enc['attention_mask'], 'labels': labels}
preprocessed_data = preprocess_data_for_training(df_balanced)
| # Step 6: Create a custom dataset class for fine-tuning | |
| class ChatbotDataset(Dataset): | |
| def __init__(self, inputs, targets): | |
| self.inputs = inputs | |
| self.targets = targets | |
| def __len__(self): | |
| return len(self.inputs['input_ids']) | |
| def __getitem__(self, idx): | |
| return { | |
| 'input_ids': self.inputs['input_ids'][idx], | |
| 'attention_mask': self.inputs['attention_mask'][idx], | |
| 'labels': self.targets['input_ids'][idx] | |
| } | |
# Feed the Trainer the computed labels, not a second copy of the inputs:
# passing `preprocessed_data` as both arguments made ChatbotDataset return
# the raw input_ids as 'labels', silently discarding the 'labels' tensor
# produced in Step 5.
train_dataset = ChatbotDataset(
    preprocessed_data,
    {'input_ids': preprocessed_data['labels']},
)

# Step 7: Training configuration.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,          # checkpoint rarely; dataset is small
    save_total_limit=2,         # keep only the two newest checkpoints
    logging_dir='./logs',
    logging_steps=500,
)
# Step 8: Initialize the Trainer. The dataset yields already-padded tensors of
# uniform length, so the plain stacking collator is the right choice:
# DataCollatorForSeq2Seq targets encoder-decoder models and re-pads 'labels'
# as Python lists, which breaks on pre-built tensors.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)

# Step 9: Fine-tune, then persist model and tokenizer together so the
# directory can be reloaded with from_pretrained().
trainer.train()
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")
# Optional: Test the chatbot after training.
def generate_response(input_text):
    """Generate a single chatbot reply for *input_text*.

    Fixes over the original version:
    - appends the EOS token to the prompt, DialoGPT's turn separator;
    - passes the attention mask so generate() doesn't have to guess it;
    - uses max_new_tokens instead of max_length (max_length counts the
      prompt, so prompts near 50 tokens produced empty replies);
    - decodes only the newly generated tokens, so the reply no longer
      echoes the user's prompt.
    """
    inputs = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt")
    output_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt: generate() returns prompt + continuation.
    new_tokens = output_ids[0, inputs['input_ids'].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Gradio front-end for interactive testing.
def chatbot_interface(input_text):
    # Thin wrapper keeping the UI callback decoupled from generation logic.
    return generate_response(input_text)

iface = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    live=True,
)
iface.launch()