---
base_model: x2bee/ModernBert_MLM_kotoken_v01
model-index:
- name: plateer_classifier_ModernBERT_v01
  results: []
---

# plateer_classifier_ModernBERT_v01

This model is a fine-tuned version of [x2bee/ModernBert_MLM_kotoken_v01](https://huggingface.co/x2bee/ModernBert_MLM_kotoken_v01) on the [x2bee/plateer_category_data](https://huggingface.co/datasets/x2bee/plateer_category_data) dataset.
It achieves the following results on the evaluation set:
- Loss: 0.3379
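
The fine-tuning data is hosted as a Hub dataset. A minimal sketch for loading it, assuming it parses with the standard `datasets` API and that your token has access to the x2bee repositories:

```python
from datasets import load_dataset

# Assumption: the dataset is stored in a format load_dataset can parse directly.
dataset = load_dataset("x2bee/plateer_category_data")
print(dataset)
```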

### Example Use
```python
import joblib
import torch
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline

# A Hugging Face access token is needed to reach the x2bee repositories.
with open("./api_key/HGF_TOKEN.txt", "r") as hgf:
    login(token=hgf.read().strip())

repo_id = "x2bee/plateer_classifier_ModernBERT_v01"
data_id = "x2bee/plateer_category_data"

# Load the tokenizer and the label encoder that maps class ids back to category names.
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="last-checkpoint")
label_encoder_file = hf_hub_download(repo_id=data_id, repo_type="dataset", filename="label_encoder.joblib")
label_encoder = joblib.load(label_encoder_file)

# Load the fine-tuned classifier.
model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="last-checkpoint")

class TopKTextClassificationPipeline(TextClassificationPipeline):
    """Pipeline that returns the top-k classes together with their decoded category names."""

    def __call__(self, inputs, top_k=5, **kwargs):
        inputs = self.tokenizer(inputs, return_tensors="pt", truncation=True, padding=True, max_length=512, **kwargs)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        scores, indices = torch.topk(probs, top_k, dim=-1)

        results = []
        for batch_idx in range(indices.shape[0]):
            batch_results = []
            for score, idx in zip(scores[batch_idx], indices[batch_idx]):
                # id2label yields e.g. "LABEL_2"; recover the integer id, then decode it
                # back to the original category name with the label encoder.
                label = int(self.model.config.id2label[idx.item()].split("_")[1])
                predicted_class = label_encoder.inverse_transform([label])[0]
                batch_results.append({
                    "label": label,
                    "label_decode": predicted_class,
                    "score": score.item(),
                })
            results.append(batch_results)

        return results

classifier_model = TopKTextClassificationPipeline(tokenizer=tokenizer, model=model)

def plateer_classifier(text, top_k=3):
    return classifier_model(text, top_k=top_k)

# Run on a sample query ("clothes for winter hiking").
result = plateer_classifier("겨울 등산에서 사용할 옷")[0]
print(result)

# Example output:
# {'label': 2, 'label_decode': '기능성의류', 'score': 0.9214227795600891}
# {'label': 8, 'label_decode': '스포츠', 'score': 0.07054771482944489}
# {'label': 15, 'label_decode': '패션/의류/잡화', 'score': 0.0036312134470790625}
```
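
Note that the custom `__call__` tokenizes with `padding=True` and iterates over the batch dimension, so the pipeline also accepts a list of strings and returns one top-k list per input, e.g. `plateer_classifier(["text one", "text two"], top_k=3)`.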

### Training hyperparameters

The following hyperparameters were used during training (a `TrainingArguments` sketch follows the list):
- learning_rate: 2e-4
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 64
- optimizer: AdamW (torch) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 10000
- num_epochs: 3
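
For reference, here is a minimal sketch of how these values could map onto `transformers.TrainingArguments`. The actual training script is not published on this card, so the mapping and the `output_dir` value are assumptions, not the authors' code:

```python
from transformers import TrainingArguments

# Sketch only: reproduces the hyperparameters listed above, not the original script.
training_args = TrainingArguments(
    output_dir="plateer_classifier_ModernBERT_v01",  # assumed; any local path works
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,  # 16 * 4 = 64 effective train batch size
    seed=42,
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    lr_scheduler_type="linear",
    warmup_steps=10000,
    num_train_epochs=3,
)
```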

### Framework versions

- Transformers 4.48