yagnik12 committed on
Commit
c847ada
·
verified ·
1 Parent(s): b7b9e5f

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +15 -6
train.py CHANGED
@@ -2,26 +2,33 @@ from datasets import load_dataset
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
3
  import os
4
 
 
5
  dataset = load_dataset("HanxiGuo/BiScope_Data")
6
- model_name = "distilbert-base-uncased"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
8
 
9
  def tokenize(batch):
10
  return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)
11
 
12
  tokenized = dataset.map(tokenize, batched=True)
13
 
14
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
 
15
 
 
16
  training_args = TrainingArguments(
17
  output_dir="./results",
18
  evaluation_strategy="epoch",
19
  save_strategy="epoch",
20
- num_train_epochs=1,
21
  per_device_train_batch_size=16,
22
  per_device_eval_batch_size=16,
23
  push_to_hub=True,
24
- hub_model_id="yagnik12/AI_Text_Detecter_HanxiGuo_BiScope-Data", # ✅ model repo, not Space
25
  hub_token=os.getenv("HF_TOKEN"),
26
  )
27
 
@@ -33,5 +40,7 @@ trainer = Trainer(
33
  tokenizer=tokenizer,
34
  )
35
 
 
36
  trainer.train()
37
- trainer.push_to_hub()
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
3
  import os
4
 
5
+ # ✅ Dataset
6
  dataset = load_dataset("HanxiGuo/BiScope_Data")
7
+
8
+ # Base model
9
+ BASE_MODEL = "distilbert-base-uncased"
10
+ MODEL_REPO = "yagnik12/AI_Text_Detecter_HanxiGuo_BiScope-Data"
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
13
 
14
  def tokenize(batch):
15
  return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)
16
 
17
  tokenized = dataset.map(tokenize, batched=True)
18
 
19
+ # Model
20
+ model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=2)
21
 
22
+ # ✅ Training setup
23
  training_args = TrainingArguments(
24
  output_dir="./results",
25
  evaluation_strategy="epoch",
26
  save_strategy="epoch",
27
+ num_train_epochs=1, # start small for demo
28
  per_device_train_batch_size=16,
29
  per_device_eval_batch_size=16,
30
  push_to_hub=True,
31
+ hub_model_id=MODEL_REPO,
32
  hub_token=os.getenv("HF_TOKEN"),
33
  )
34
 
 
40
  tokenizer=tokenizer,
41
  )
42
 
43
+ # ✅ Train & push
44
  trainer.train()
45
+ trainer.push_to_hub()
46
+ print(f"✅ Model pushed to https://huggingface.co/{MODEL_REPO}")