cnmoro commited on
Commit
d6611f1
·
verified ·
1 Parent(s): 7d05efd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -3
README.md CHANGED
@@ -1,3 +1,60 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - cnmoro/QuestionClassification-v2
5
+ language:
6
+ - en
7
+ - pt
8
+ tags:
9
+ - classification
10
+ - questioning
11
+ - directed
12
+ - generic
13
+ pipeline_tag: text-classification
14
+ base_model:
15
+ - ibm-granite/granite-embedding-30m-english
16
+ library_name: transformers
17
+ ---
18
+ A finetuned version of ibm-granite/granite-embedding-30m-english.
19
+
20
+ The goal is to classify questions into "Directed" or "Generic".
21
+
22
+ If a question is not directed, we would change the actions we perform on a RAG pipeline (if it is generic, semantic search wouldn't be useful directly; e.g. asking for a summary).
23
+
24
+ (Class 0 is Generic; Class 1 is Directed)
25
+
26
+ The accuracy achieved during training was 94%.
27
+
28
+ ```python
29
+ import torch
30
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
31
+
32
+ model_id = "cnmoro/granite-question-classifier"
33
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
34
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
35
+ model.eval()
36
+
37
+ def predict_question_category(question):
38
+ inputs = tokenizer.encode_plus(
39
+ question,
40
+ add_special_tokens=True,
41
+ max_length=512,
42
+ return_tensors="pt",
43
+ truncation=True
44
+ )
45
+
46
+ input_ids = inputs["input_ids"]
47
+ attention_mask = inputs["attention_mask"]
48
+
49
+ with torch.no_grad():
50
+ outputs = model(input_ids, attention_mask=attention_mask)
51
+ logits = outputs.logits.squeeze(-1)
52
+ print(logits)
53
+ prediction = (logits > 0).float().item()
54
+
55
+ # Map prediction to category
56
+ return "directed" if prediction == 1.0 else "generic"
57
+
58
+ predict_question_category("Qual o resumo do texto?") # generic
59
+ predict_question_category("Qual foi a crítica que o autor recebeu do jornal, em relação a sua opinião?") # directed
60
+ ```