ShynBui committed on
Commit
070d008
·
verified ·
1 Parent(s): 07a2715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -8
app.py CHANGED
@@ -60,19 +60,133 @@ def train_batch(dataloader):
60
 
61
  return True, "Batch training completed."
62
 
63
- def train_step(file=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if file:
65
  load_data(file)
66
  print(global_data)
 
 
 
 
 
67
 
68
- start_idx = 0
69
  batch_size = 8
70
  total_samples = len(global_data)
71
 
72
  counting = 0
73
  while start_idx < total_samples:
74
  print("Step:", counting)
75
- print("Percent:", (start_idx) / total_samples* 100, "%")
76
  counting += 1
77
  end_idx = min(start_idx + (batch_size * 10), total_samples) # 10 batches per loop
78
  dataloader = get_dataloader(start_idx, end_idx, batch_size)
@@ -80,24 +194,25 @@ def train_step(file=None):
80
  try:
81
  success, message = train_batch(dataloader)
82
  if not success:
83
- return message
84
 
85
  except HTMLError as e:
86
  print("Exceeded GPU quota, retrying in 10 seconds...")
87
  time.sleep(10)
88
- continue
89
 
90
  start_idx = end_idx
91
-
92
  if not os.path.exists('./checkpoint'):
93
  os.makedirs('./checkpoint')
94
  torch.save(model.state_dict(), "./checkpoint/model.pt")
95
- return "Training completed and model saved."
 
96
 
97
  if __name__ == "__main__":
98
  iface = gr.Interface(
99
  fn=train_step,
100
- inputs=gr.File(label="Upload CSV"),
101
  outputs="text"
102
  )
103
  iface.launch()
 
60
 
61
  return True, "Batch training completed."
62
 
63
+ Hugging Face's logo
64
+ Hugging Face
65
+ Search models, datasets, users...
66
+ Models
67
+ Datasets
68
+ Spaces
69
+ Posts
70
+ Docs
71
+ Solutions
72
+ Pricing
73
+
74
+
75
+
76
+ Hugging Face is way more fun with friends and colleagues! 🤗 Join an organization
77
+ Spaces:
78
+
79
+ ShynBui
80
+ /
81
+ train_for_fun
82
+
83
+ private
84
+
85
+ Logs
86
+ App
87
+ Files
88
+ Community
89
+ Settings
90
+ train_for_fun
91
+ /
92
+ app.py
93
+
94
+ ShynBui's picture
95
+ ShynBui
96
+ Update app.py
97
+ 07a2715
98
+ verified
99
+ 15 minutes ago
100
+ raw
101
+
102
+ Copy download link
103
+ history
104
+ blame
105
+ edit
106
+ delete
107
+ No virus
108
+
109
+ 3.25 kB
110
+ import time
111
+ import torch
112
+ from transformers import BertForSequenceClassification, AdamW
113
+ from torch.utils.data import DataLoader, TensorDataset
114
+ from transformers import BertTokenizer
115
+ import gradio as gr
116
+ import pandas as pd
117
+ import os
118
+ import spaces
119
+ from spaces.zero.gradio import HTMLError
120
+
121
# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Pretrained BERT classifier (moved onto the selected device) and its tokenizer.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# A single optimizer instance is shared by every training call in this module.
optimizer = AdamW(model.parameters(), lr=1e-5)

# Tokenized dataset cache; filled in by load_data().
global_data = None
131
+
132
def load_data(file):
    """Read a labelled CSV and cache it in ``global_data`` as a tokenized dataset.

    Args:
        file: Anything ``pandas.read_csv`` accepts (a path or a file-like
            object). The CSV must contain a ``text`` and a ``label`` column.

    Raises:
        ValueError: If a required column is missing or the CSV has no rows.
    """
    global global_data
    df = pd.read_csv(file)

    # Fail early with a readable message instead of a bare KeyError (missing
    # column) or an opaque tokenizer error (no rows).
    missing = [col for col in ("text", "label") if col not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError("CSV contains no rows to train on.")

    # Encode the text.
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")
    # The label column must be named 'label'.
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)

    print(global_data)
140
+
141
def get_dataloader(start, end, batch_size=8):
    """Build a ``DataLoader`` over samples ``[start, end)`` of the cached dataset.

    Args:
        start: Index of the first sample to include.
        end: Index one past the last sample to include.
        batch_size: Number of samples per batch.

    Returns:
        DataLoader: Iterates the selected window in order, unshuffled.

    Raises:
        RuntimeError: If no dataset has been loaded into ``global_data`` yet.
    """
    if global_data is None:
        raise RuntimeError("No dataset loaded; call load_data() first.")

    # Clamp the window to the dataset so an out-of-range request produces a
    # shorter (possibly empty) loader instead of failing during iteration.
    total = len(global_data)
    start = max(0, min(start, total))
    end = max(start, min(end, total))

    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)
145
+
146
@spaces.GPU(duration=20)
def train_batch(dataloader):
    """Run one optimisation pass over ``dataloader``, pausing if it takes too long.

    Args:
        dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.

    Returns:
        tuple[bool, str]: ``(True, message)`` when every batch was processed,
        or ``(False, message)`` when the time budget ran out and a checkpoint
        was written to ``./checkpoint/model.pt`` instead.
    """
    # Stop after 10 s; the decorator above requests duration=20, so this
    # presumably leaves headroom inside that allocation -- TODO confirm.
    time_budget_s = 10

    model.train()
    started = time.monotonic()  # monotonic: immune to wall-clock adjustments

    for batch in dataloader:
        input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()

        if time.monotonic() - started > time_budget_s:
            print('Save checkpoint')
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs('./checkpoint', exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")

            return False, "Checkpoint saved. Training paused."

    return True, "Batch training completed."
171
+
172
+
173
def train_step(file=None, start_idx=0):
    """Train from ``start_idx`` onwards in chunks and report where training stopped.

    Args:
        file: Optional CSV upload. When given it is passed to ``load_data`` and
            replaces the cached dataset.
        start_idx: Index of the first sample to train on. Accepts an int or the
            text typed into the UI textbox; blank or ``None`` means 0.

    Returns:
        int: The index to resume from. It equals the dataset size when every
        chunk finished; otherwise it is the start of the interrupted chunk.
        Because ``train_batch`` can pause part-way through a chunk, batches of
        that chunk that were already processed are repeated on resume.

    Raises:
        ValueError: If ``start_idx`` is not an integer, or if no dataset has
            been loaded and no ``file`` was supplied.
    """
    if file:
        load_data(file)
        print(global_data)

    # An untouched textbox submits "", which int() rejects; treat blank as 0.
    if start_idx is None or str(start_idx).strip() == "":
        start_idx = 0
    # Negative offsets would index from the end of the dataset, so clamp them.
    start_idx = max(0, int(start_idx))

    if global_data is None:
        raise ValueError("No dataset loaded; upload a CSV file first.")

    # Reload the checkpoint if one exists.
    if os.path.exists("./checkpoint/model.pt"):
        print("Loading checkpoint...")
        # map_location lets a checkpoint written on the GPU load on a CPU host.
        model.load_state_dict(torch.load("./checkpoint/model.pt", map_location=device))

    batch_size = 8
    total_samples = len(global_data)

    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        print("Percent:", (start_idx) / total_samples * 100, "%")
        counting += 1
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # 10 batches per loop
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        # NOTE(review): the diff this was recovered from elides one unchanged
        # line at this point (presumably blank) -- confirm against app.py.

        try:
            success, message = train_batch(dataloader)
            if not success:
                print(message)
                return start_idx  # Hand back start_idx when the chunk did not finish.
        except HTMLError:
            # The original log line attributed this error to GPU quota and said
            # "retrying in 10 seconds", but the code returns without retrying,
            # so the pointless sleep is dropped and the message says what happens.
            print("Exceeded GPU quota, stopping at index", start_idx)
            return start_idx  # Hand back start_idx so the position is kept.

        start_idx = end_idx

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('./checkpoint', exist_ok=True)
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return start_idx
210
+
211
 
212
  if __name__ == "__main__":
213
  iface = gr.Interface(
214
  fn=train_step,
215
+ inputs=[gr.File(label="Upload CSV"), gr.Textbox()],
216
  outputs="text"
217
  )
218
  iface.launch()