agaresd commited on
Commit
e2272ff
·
verified ·
1 Parent(s): 8d8d0b6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -5
README.md CHANGED
@@ -9,14 +9,47 @@ pipeline_tag: text-classification
9
 
10
  This is the repo for Gen AI final project
11
 
 
 
 
 
 
 
12
 
13
  ## Usage
14
  ```python
15
- from transformers import AutoModel, AutoTokenizer
 
 
 
 
 
 
16
 
17
- model = AutoModel.from_pretrained("agaresd/GEN-AI-Final-project")
18
- tokenizer = AutoTokenizer.from_pretrained("agaresd/GEN-AI-Final-project")
 
 
 
 
 
 
 
 
19
 
20
- inputs = tokenizer("Hello!", return_tensors="pt")
 
 
 
 
21
  outputs = model(**inputs)
22
- print(outputs)
 
 
 
 
 
 
 
 
 
 
9
 
10
  This is the repo for Gen AI final project
11
 
12
+ ## Info
13
+ Original code: https://github.com/hyunwoongko/transformer
14
+
15
+ My version: https://github.com/Agaresd47/transformer_SAE
16
+
17
+
18
 
19
  ## Usage
20
  ```python
21
+ import torch
22
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
23
+ import torch.nn.functional as F
24
+
25
+ # Load the model and tokenizer
26
+ model = AutoModelForSequenceClassification.from_pretrained("agaresd/your-model-name")
27
+ tokenizer = AutoTokenizer.from_pretrained("agaresd/your-model-name")
28
 
29
+ # Define the label mapping
30
+ label_mapping = {
31
+ 0: "no emotion",
32
+ 1: "anger ",
33
+ 2: "disgust ",
34
+ 3: "fear ",
35
+ 4: "Emotion: Happy",
36
+ 5: "Emotion: Sad",
37
+ 6: "Emotion: surprise"
38
+ }
39
 
40
+ # Input text
41
+ input_text = "happy"
42
+
43
+ # Tokenize and get model outputs
44
+ inputs = tokenizer(input_text, return_tensors="pt")
45
  outputs = model(**inputs)
46
+
47
+ # Get logits, apply softmax, and find the predicted class
48
+ logits = outputs.logits
49
+ probabilities = F.softmax(logits, dim=-1)
50
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
51
+
52
+ # Map the predicted class to a word
53
+ predicted_label = label_mapping[predicted_class]
54
+ print(f"Input: {input_text}")
55
+ print(f"Predicted Label: {predicted_label}")