unijoh commited on
Commit
ac86473
·
verified ·
1 Parent(s): 293c12b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -118
app.py CHANGED
@@ -1,118 +1,118 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- import pandas as pd
5
- from transformers import AutoTokenizer, AutoModelForTokenClassification
6
-
7
- MODEL_ID = "YOUR_USERNAME/YOUR_MODEL_REPO"
8
- TAGS_FILEPATH = "Sosialurin-GOLD_tags.csv"
9
-
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
11
- model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
12
-
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- model.to(device)
15
- model.eval()
16
-
17
- def load_tag_mappings(tags_filepath):
18
- tags_df = pd.read_csv(tags_filepath)
19
- features_to_tag = {
20
- tuple(row[1:].values.astype(int)): row["Original Tag"]
21
- for _, row in tags_df.iterrows()
22
- }
23
- vec_len = len(tags_df.columns) - 1
24
- return features_to_tag, vec_len
25
-
26
- features_to_tag, VEC_LEN = load_tag_mappings(TAGS_FILEPATH)
27
-
28
- # Use the SAME intervals as your demo.py (keep these consistent!)
29
- intervals = (
30
- (15, 28),
31
- (29, 32),
32
- (33, 35),
33
- (36, 40),
34
- (41, 42),
35
- (43, 44),
36
- (45, 49),
37
- (50, 52),
38
- (53, 58),
39
- (59, 61),
40
- (62, 64),
41
- (65, 68),
42
- (69, 70),
43
- )
44
-
45
- def vector_to_tag(vec):
46
- return features_to_tag.get(tuple(vec.int().tolist()), "Unknown Tag")
47
-
48
- def tag_sentence(sentence: str):
49
- sentence = sentence.strip()
50
- if not sentence:
51
- return ""
52
-
53
- tokens = sentence.split()
54
-
55
- enc = tokenizer(
56
- tokens,
57
- is_split_into_words=True,
58
- add_special_tokens=True,
59
- max_length=128,
60
- padding="max_length",
61
- truncation=True,
62
- return_attention_mask=True,
63
- return_tensors="pt"
64
- )
65
-
66
- input_ids = enc["input_ids"].to(device)
67
- attention_mask = enc["attention_mask"].to(device)
68
- word_ids = enc.word_ids(batch_index=0)
69
-
70
- # begin token mask
71
- begin = []
72
- last = None
73
- for wid in word_ids:
74
- if wid is None:
75
- begin.append(0)
76
- elif wid != last:
77
- begin.append(1)
78
- else:
79
- begin.append(0)
80
- last = wid
81
-
82
- with torch.no_grad():
83
- out = model(input_ids=input_ids, attention_mask=attention_mask)
84
- logits = out.logits[0] # [seq_len, num_labels]
85
-
86
- lines = []
87
- for i in range(logits.shape[0]):
88
- if attention_mask[0, i].item() != 1 or begin[i] != 1:
89
- continue
90
-
91
- pred = logits[i]
92
- vec = torch.zeros(VEC_LEN, device=logits.device)
93
-
94
- # Word type in [0..14]
95
- wt = torch.argmax(pred[0:15]).item()
96
- vec[wt] = 1
97
-
98
- # Interval decoding
99
- for a, b in intervals:
100
- seg = pred[a:b+1]
101
- k = torch.argmax(seg).item()
102
- vec[a + k] = 1
103
-
104
- wid = word_ids[i]
105
- word = tokens[wid] if wid is not None and wid < len(tokens) else "<UNK>"
106
- lines.append(f"{word}\t{vector_to_tag(vec)}")
107
-
108
- return "\n".join(lines)
109
-
110
- demo = gr.Interface(
111
- fn=tag_sentence,
112
- inputs=gr.Textbox(lines=2, label="Sentence"),
113
- outputs=gr.Textbox(lines=12, label="Token\\tTag"),
114
- title="Faroese POS Tagger (Demo)"
115
- )
116
-
117
- if __name__ == "__main__":
118
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
6
+
7
# Hugging Face model repo for the Faroese POS tagger, and the CSV file that
# maps predicted binary feature vectors back to human-readable tags.
MODEL_ID = "Setur/BRAGD"
TAGS_FILEPATH = "Sosialurin-GOLD_tags.csv"

# Load the tokenizer and token-classification model once at import time so
# every Gradio request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

# Run on GPU when available. Inference only, so switch to eval mode
# (disables dropout / other train-time behavior).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
16
+
17
def load_tag_mappings(tags_filepath):
    """Load the feature-vector -> tag lookup table from a CSV file.

    The CSV must contain an "Original Tag" column (the human-readable tag)
    plus one column per binary feature.

    Args:
        tags_filepath: path to the CSV file.

    Returns:
        (features_to_tag, vec_len) where features_to_tag maps a tuple of
        int feature values to the tag string, and vec_len is the number of
        feature columns (the binary vector length).
    """
    tags_df = pd.read_csv(tags_filepath)
    # Drop the tag column by LABEL rather than slicing `row[1:]`: the
    # positional integer slice assumed "Original Tag" was the first column
    # and relied on deprecated/ambiguous integer indexing of a Series.
    # `.tolist()` yields plain Python ints, so keys match
    # tuple(vec.int().tolist()) lookups exactly.
    features_to_tag = {
        tuple(row.drop("Original Tag").astype(int).tolist()): row["Original Tag"]
        for _, row in tags_df.iterrows()
    }
    vec_len = len(tags_df.columns) - 1
    return features_to_tag, vec_len
25
+
26
# Build the lookup table once at startup; VEC_LEN is the number of binary
# feature columns in the CSV (the length of each decoded vector).
features_to_tag, VEC_LEN = load_tag_mappings(TAGS_FILEPATH)

# Use the SAME intervals as your demo.py (keep these consistent!)
# Each (start, end) pair is an INCLUSIVE index range into the model's logits
# forming one mutually-exclusive feature group; indices 0..14 (the word
# type) are handled separately in tag_sentence.
intervals = (
    (15, 28),
    (29, 32),
    (33, 35),
    (36, 40),
    (41, 42),
    (43, 44),
    (45, 49),
    (50, 52),
    (53, 58),
    (59, 61),
    (62, 64),
    (65, 68),
    (69, 70),
)
44
+
45
def vector_to_tag(vec):
    """Translate a binary feature vector (torch tensor) into its tag string.

    Returns "Unknown Tag" when the vector has no entry in the lookup table.
    """
    key = tuple(vec.int().tolist())
    return features_to_tag.get(key, "Unknown Tag")
47
+
48
def tag_sentence(sentence: str) -> str:
    """Tag a whitespace-tokenized sentence with Faroese POS tags.

    Runs the token-classification model, decodes the per-token logits into
    a binary feature vector, and maps each vector to a tag string via
    features_to_tag.

    Args:
        sentence: raw input sentence; split on whitespace into tokens.

    Returns:
        One "word<TAB>tag" line per input word, joined by newlines.
        Empty string for blank input.
    """
    sentence = sentence.strip()
    if not sentence:
        return ""

    tokens = sentence.split()

    # Fixed-length encoding; sentences longer than 128 subword tokens are
    # truncated, so trailing words of very long inputs get no tag.
    enc = tokenizer(
        tokens,
        is_split_into_words=True,
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt"
    )

    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    # word_ids[i] is the input-word index for subword position i, or None
    # for special/padding tokens.
    word_ids = enc.word_ids(batch_index=0)

    # begin token mask
    # begin[i] == 1 only for the FIRST subword of each word; predictions are
    # read from that position and continuation subwords are skipped.
    begin = []
    last = None
    for wid in word_ids:
        if wid is None:
            begin.append(0)
        elif wid != last:
            begin.append(1)
        else:
            begin.append(0)
        last = wid

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = out.logits[0]  # [seq_len, num_labels]

    lines = []
    for i in range(logits.shape[0]):
        # Only decode real (attended) positions that start a word.
        if attention_mask[0, i].item() != 1 or begin[i] != 1:
            continue

        pred = logits[i]
        # NOTE(review): decoding assumes num_labels >= 71 and VEC_LEN >= 71
        # so every interval index is valid — confirm against the model config.
        vec = torch.zeros(VEC_LEN, device=logits.device)

        # Word type in [0..14]
        wt = torch.argmax(pred[0:15]).item()
        vec[wt] = 1

        # Interval decoding
        # Within each inclusive (a, b) group, set exactly one feature: the
        # argmax of that slice of the logits.
        for a, b in intervals:
            seg = pred[a:b+1]
            k = torch.argmax(seg).item()
            vec[a + k] = 1

        wid = word_ids[i]
        # Defensive fallback; with a valid begin mask wid should always be a
        # valid word index.
        word = tokens[wid] if wid is not None and wid < len(tokens) else "<UNK>"
        lines.append(f"{word}\t{vector_to_tag(vec)}")

    return "\n".join(lines)
109
+
110
# Gradio UI: one sentence in, one "word<TAB>tag" line per token out.
demo = gr.Interface(
    fn=tag_sentence,
    inputs=gr.Textbox(lines=2, label="Sentence"),
    outputs=gr.Textbox(lines=12, label="Token\\tTag"),
    title="Faroese POS Tagger (Demo)"
)

if __name__ == "__main__":
    demo.launch()