# Author: Canstralian
# Commit 2b6ee92 (verified): "Update app.py"
import os

import gradio as gr
import pandas as pd
import torch
from transformers import pipeline, Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer
# Initialize model and tokenizer.
# "huggingface/transformer_model" is only a placeholder checkpoint name: set the
# MODEL_NAME environment variable (or edit the default below) to the actual
# Hugging Face model id you want to fine-tune.
model_name = os.environ.get("MODEL_NAME", "huggingface/transformer_model")
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Dataset wrapper and Gradio callback for fine-tuning.
class _CausalLMDataset(torch.utils.data.Dataset):
    """Expose tokenizer output as per-example tensors for ``Trainer``.

    Each item contains every field the tokenizer produced plus ``labels``:
    a copy of ``input_ids`` with padding positions set to -100. A causal-LM
    model only returns a loss (which ``Trainer`` requires) when ``labels``
    are supplied.
    """

    def __init__(self, encodings):
        # Tokenizer output holding plain Python lists (no ``return_tensors``),
        # so each row is converted to a fresh tensor exactly once.
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        labels = item["input_ids"].clone()
        if "attention_mask" in item:
            # -100 is the loss ignore-index, so padded positions are not trained on.
            labels[item["attention_mask"] == 0] = -100
        item["labels"] = labels
        return item


def upload_and_finetune(file):
    """Fine-tune the module-level ``model`` on the ``text`` column of an uploaded CSV.

    Args:
        file: The Gradio upload. Depending on the Gradio version and the
            ``gr.File`` ``type``, this is either a path string or an object
            whose ``.name`` attribute is the path of the uploaded temp file.

    Returns:
        A status message naming the uploaded file.

    Raises:
        gr.Error: If no file was uploaded, the CSV has no ``text`` column, or
            that column contains no non-blank rows.

    Side effects:
        Trains the module-level ``model`` in place (so repeated calls keep
        fine-tuning the same weights), may set ``tokenizer.pad_token``, writes
        Trainer output to ``./results`` and ``./logs``, and saves the model and
        tokenizer to ``./fine_tuned_model``.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file before fine-tuning.")
    file_path = file if isinstance(file, str) else file.name

    # Read the uploaded file (assumed to be a CSV with a 'text' column).
    data = pd.read_csv(file_path)
    if "text" not in data.columns:
        raise gr.Error("The uploaded CSV must contain a 'text' column.")
    # Drop NaN and blank rows: NaN is not a string the tokenizer accepts, and a
    # blank row would produce an all-padding example with no trainable labels.
    texts = [t for t in data["text"].dropna().astype(str).tolist() if t.strip()]
    if not texts:
        raise gr.Error("The 'text' column contains no non-empty rows.")

    # Causal-LM tokenizers (e.g. GPT-2) often ship without a pad token, in
    # which case padding=True raises; reuse the end-of-sequence token instead.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    encodings = tokenizer(texts, truncation=True, padding=True)
    train_dataset = _CausalLMDataset(encodings)

    training_args = TrainingArguments(
        output_dir="./results",  # output directory
        num_train_epochs=3,  # number of training epochs
        per_device_train_batch_size=4,  # batch size for training
        logging_dir="./logs",  # directory for storing logs
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()

    # Save the tokenizer alongside the weights so the directory can be
    # reloaded on its own with from_pretrained.
    model.save_pretrained("./fine_tuned_model")
    tokenizer.save_pretrained("./fine_tuned_model")
    return f"File {file_path} uploaded and model fine-tuned successfully!"
# Create the Gradio interface.
# Gradio 4 removed type="file" from gr.File (only "filepath"/"binary" remain),
# while Gradio 3 expects "file". Choose the value the installed version accepts.
# In both cases the callback receives something whose ``.name`` is the upload path:
# a temp-file wrapper on 3.x, a path-like string on 4.x.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_FILE_TYPE = "filepath" if _GRADIO_MAJOR >= 4 else "file"

interface = gr.Interface(
    fn=upload_and_finetune,
    inputs=[
        gr.File(
            label="Upload Dataset for Fine-Tuning",
            file_count="single",
            type=_FILE_TYPE,
        )
    ],
    outputs="text",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()