import os
import pickle

import pandas as pd
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent versions
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
class TextDataset(Dataset):
    """Wraps raw texts and integer labels, tokenizing each example on access."""

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Tokenize to a fixed length so examples can be batched directly.
        encoding = self.tokenizer(text, truncation=True, padding='max_length',
                                  max_length=self.max_len, return_tensors='pt')
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }
# Sample data
data = pd.DataFrame({
    "text": ["I love this", "I hate this", "This is amazing", "This is terrible"],
    "label": ["positive", "negative", "positive", "negative"]
})
# Preprocess: encode string labels as integers and hold out a validation split
le = LabelEncoder()
data["label_enc"] = le.fit_transform(data["label"])
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data["text"], data["label_enc"], test_size=0.2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
train_dataset = TextDataset(train_texts.tolist(), train_labels.tolist(), tokenizer, max_len=32)
train_loader = DataLoader(train_dataset, batch_size=2)
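
# Optional sanity check (illustrative, not required for training): pull one
# batch and confirm the tensor shapes. With this toy split, the first batch
# should be (batch_size=2, max_len=32) for the token tensors.
sample_batch = next(iter(train_loader))
print(sample_batch["input_ids"].shape)       # expected: torch.Size([2, 32])
print(sample_batch["attention_mask"].shape)  # expected: torch.Size([2, 32])
print(sample_batch["labels"].shape)          # expected: torch.Size([2])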
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
for epoch in range(1):
    for batch in train_loader:
        outputs = model(input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'],
                        labels=batch['labels'])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
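
# Evaluate on the held-out split from train_test_split above. A minimal
# sketch: with only four samples total the accuracy is illustrative, not a
# meaningful estimate.
val_dataset = TextDataset(val_texts.tolist(), val_labels.tolist(), tokenizer, max_len=32)
val_loader = DataLoader(val_dataset, batch_size=2)
model.eval()
correct = total = 0
with torch.no_grad():
    for batch in val_loader:
        logits = model(input_ids=batch['input_ids'],
                       attention_mask=batch['attention_mask']).logits
        preds = logits.argmax(dim=-1)
        correct += (preds == batch['labels']).sum().item()
        total += batch['labels'].size(0)
print(f"Validation accuracy: {correct / total:.2f}")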
# Persist artifacts for later inference; create the target directory first so
# torch.save does not fail. (tokenizer.save_pretrained would be the more
# idiomatic route; pickle is kept here to match the loading code below.)
os.makedirs("app", exist_ok=True)
torch.save(model.state_dict(), "app/bert_model.pth")
with open("app/tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)
with open("app/label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)
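
# Inference sketch: reload the artifacts saved above and classify a new
# sentence. The loaded_* names and the sample text are illustrative; this
# assumes the files written above exist on disk.
loaded_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
loaded_model.load_state_dict(torch.load("app/bert_model.pth"))
loaded_model.eval()
with open("app/tokenizer.pkl", "rb") as f:
    loaded_tokenizer = pickle.load(f)
with open("app/label_encoder.pkl", "rb") as f:
    loaded_le = pickle.load(f)
enc = loaded_tokenizer("What a great movie", truncation=True, padding='max_length',
                       max_length=32, return_tensors='pt')
with torch.no_grad():
    pred = loaded_model(**enc).logits.argmax(dim=-1)
print(loaded_le.inverse_transform(pred.numpy()))  # e.g. ['positive']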