Hugging Face Spaces status: Runtime error.
(The app fails at startup — the script below uses AutoModelForCausalLM and
AutoTokenizer without importing them from transformers, raising NameError.)
import streamlit as st
import torch
from torch.utils.data import Dataset
# Bug fix: AutoModelForCausalLM / AutoTokenizer were used below without being
# imported, which raises NameError at startup (the Space's "Runtime error").
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): "emilyalsentzer/Bio_ClinicalBERT" is a masked-LM (BERT)
# checkpoint; loading it with AutoModelForCausalLM emits a warning and leaves
# the causal-LM head randomly initialised — confirm a causal-LM objective is
# really intended (a sequence-classification head may be what was meant,
# given that integer labels are attached below).
model = AutoModelForCausalLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Tokenizer matching the checkpoint; used by the Dataset and the data collator.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
class ClinicalDataset(Dataset):
    """PyTorch dataset pairing clinical texts with integer labels.

    Each item is tokenized to a fixed length so that downstream collators
    that stack tensors (e.g. ``torch.stack``) see uniformly shaped examples.

    Args:
        texts: sequence of raw text strings.
        labels: sequence of integer labels, parallel to ``texts``.
        tokenizer: a HuggingFace-style tokenizer callable.
        max_length: fixed padded/truncated length per example (default 512,
            BERT's maximum — backward-compatible addition).
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Bug fix: the original used padding=True, which on a *single* example
        # only pads to that example's own length — items then have different
        # lengths and torch.stack in the collator crashes on the first batch.
        # Padding to a fixed max_length makes every item the same shape.
        encoding = self.tokenizer(
            self.texts[idx],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx]),
        }
# NOTE(review): train_texts, train_labels and val_dataset are not defined
# anywhere in this file — as written these lines raise NameError. Confirm
# where the training/validation data is meant to be loaded from.
dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)

# Fine-tune the pre-trained model on the clinical dataset.
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding

training_args = TrainingArguments(
    output_dir='./results',          # checkpoints and final model go here
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # linear warmup steps for the LR scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=val_dataset,
    # Bug fix: the original lambda collator called torch.stack on per-example
    # tensors; with variable-length examples that crashes on the first batch.
    # DataCollatorWithPadding pads each batch to its longest sequence and
    # assembles input_ids / attention_mask / labels correctly.
    data_collator=DataCollatorWithPadding(tokenizer),
)
trainer.train()