solon committed
Commit 19d03cc · 0 Parent(s)

Initial commit: kf-deberta-gen Gradio demo

Files changed (3)
  1. README.md +50 -0
  2. app.py +143 -0
  3. requirements.txt +3 -0
README.md ADDED
@@ -0,0 +1,50 @@
+ ---
+ title: kf-deberta-gen
+ emoji: 🌀
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # 🌀 kf-deberta-gen Demo
+
+ **Generative Diffusion BERT** - a demo of a Korean diffusion-based generative language model
+
+ [![Model](https://img.shields.io/badge/🤗%20Model-kf--deberta--gen-blue)](https://huggingface.co/solonsophy/kf-deberta-gen)
+ [![GitHub](https://img.shields.io/badge/GitHub-Repository-black)](https://github.com/hong-seongmin/GenerativeDiffusionBERT)
+
+
+ ---
+
+ ## Overview
+
+ This Space is a demo of the [solonsophy/kf-deberta-gen](https://huggingface.co/solonsophy/kf-deberta-gen) model.
+
+ The model is trained with **discrete diffusion** and answers a question through **iterative denoising**: the answer starts as a span of mask tokens that is progressively filled in.
+
+ ## How to Use
+
+ 1. Enter a question.
+ 2. Adjust the generation parameters:
+    - **Steps**: number of denoising steps (more steps give higher quality but slower generation)
+    - **Temperature**: sampling diversity (higher values are more creative)
+    - **Top-K**: number of candidate tokens kept at each position
+ 3. Click the "Generate" button.
+
+ ## Tech Stack
+
+ - **Base Model**: [kakaobank/kf-deberta-base](https://huggingface.co/kakaobank/kf-deberta-base)
+ - **Method**: MDLM (Masked Diffusion Language Model)
+ - **Framework**: Transformers, Gradio
+
+ ## Example Questions
+
+ - What is artificial intelligence?
+ - How's the weather today?
+ - How do I learn Python?
+
+ ---
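The overview above only describes generation at a high level, so here is a minimal sketch of what the masked-answer setup looks like in plain Transformers. It performs a single greedy fill of the masked answer span rather than the iterative, confidence-based schedule that app.py below implements, so output will be correspondingly rough; the model ID and the `[CLS] question [SEP] [MASK]*N` layout are taken from app.py.

```python
# Minimal sketch: load kf-deberta-gen and fill the masked answer span in one
# greedy pass. app.py implements the full iterative denoising schedule; this
# is only a quick way to see the masked-answer layout in action.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_ID = "solonsophy/kf-deberta-gen"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID).eval()

question = "What is artificial intelligence?"  # the demo targets Korean prompts
answer_len = 32

# Build [CLS] question [SEP] followed by a fully masked answer span
q_ids = tokenizer.encode(question, add_special_tokens=False)
ids = [tokenizer.cls_token_id] + q_ids + [tokenizer.sep_token_id] \
      + [tokenizer.mask_token_id] * answer_len
input_ids = torch.tensor([ids])

with torch.no_grad():
    logits = model(input_ids=input_ids).logits

# Replace every [MASK] with its argmax prediction (a single "denoising" step)
answer_start = len(q_ids) + 2
pred = logits[0, answer_start:].argmax(dim=-1)
print(tokenizer.decode(pred, skip_special_tokens=True))
```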
app.py ADDED
@@ -0,0 +1,143 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ # Configuration
+ MODEL_ID = "solonsophy/kf-deberta-gen"
+ MAX_LEN = 256
+ Q_MAX_LEN = 100
+
+ # Load model
+ print("🔄 Loading model...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+ model.eval()
+ print(f"✅ Model loaded on {device}")
+
+ MASK_ID = tokenizer.mask_token_id
+ PAD_ID = tokenizer.pad_token_id
+ CLS_ID = tokenizer.cls_token_id
+ SEP_ID = tokenizer.sep_token_id
+
+
+ def generate_response(question, num_steps, temperature, top_k, max_answer_len):
+     """Generate an answer via diffusion-style iterative denoising."""
+     if not question.strip():
+         return "Please enter a question."
+
+     # Tokenize the question
+     q_tokens = tokenizer.encode(question, add_special_tokens=False)[:Q_MAX_LEN]
+
+     # Initial sequence: [CLS] Q [SEP] [MASK]*N
+     input_ids = [CLS_ID] + q_tokens + [SEP_ID] + [MASK_ID] * max_answer_len
+     input_ids = input_ids[:MAX_LEN]
+
+     answer_start = len(q_tokens) + 2
+     answer_end = len(input_ids)
+
+     input_ids = torch.tensor([input_ids], device=device)
+     attention_mask = torch.ones_like(input_ids)
+
+     # Iterative denoising
+     for step in range(num_steps):
+         with torch.no_grad():
+             outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+             logits = outputs.logits
+
+         # Find the remaining [MASK] positions in the answer span
+         mask_positions = (input_ids[0, answer_start:answer_end] == MASK_ID).nonzero(as_tuple=True)[0]
+         mask_positions = mask_positions + answer_start
+
+         if len(mask_positions) == 0:
+             break
+
+         # Number of tokens to unmask in this step
+         remaining_steps = num_steps - step
+         tokens_per_step = max(1, len(mask_positions) // remaining_steps)
+
+         # Temperature-scaled logits at the masked positions
+         mask_logits = logits[0, mask_positions] / temperature
+
+         # Top-k filtering
+         if top_k > 0:
+             top_k_values, _ = torch.topk(mask_logits, min(top_k, mask_logits.size(-1)), dim=-1)
+             threshold = top_k_values[:, -1].unsqueeze(-1)
+             mask_logits = torch.where(mask_logits < threshold, float('-inf'), mask_logits)
+
+         # Sampling
+         probs = F.softmax(mask_logits, dim=-1)
+         sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+         # Confidence of each sampled token
+         confidences = probs.gather(1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+
+         # Unmask the most confident positions this step
+         _, top_indices = torch.topk(confidences, min(tokens_per_step, len(confidences)))
+
+         selected_positions = mask_positions[top_indices]
+         selected_tokens = sampled_tokens[top_indices]
+         input_ids[0, selected_positions] = selected_tokens
+
+     # Extract the result
+     answer_tokens = input_ids[0, answer_start:answer_end]
+     valid_mask = (answer_tokens != MASK_ID) & (answer_tokens != PAD_ID)
+     answer_tokens = answer_tokens[valid_mask]
+
+     answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+     return answer.strip() if answer.strip() else "(generation failed)"
+
+
+ # Gradio UI
+ with gr.Blocks(title="kf-deberta-gen Demo", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🌀 kf-deberta-gen Demo
+
+     **Generative Diffusion BERT** - a Korean diffusion-based generative language model (experimental)
+
+     > ⚠️ This model is a proof of concept. Generation quality is unstable and issues such as repetitive output may occur.
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             question_input = gr.Textbox(
+                 label="Question",
+                 placeholder="Enter your question...",
+                 lines=2
+             )
+             submit_btn = gr.Button("🚀 Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
+             temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="Temperature")
+             top_k = gr.Slider(1, 50, value=10, step=1, label="Top-K")
+             max_len = gr.Slider(20, 150, value=80, step=10, label="Max Answer Length")
+
+     output = gr.Textbox(label="Answer", lines=5)
+
+     gr.Examples(
+         examples=[
+             ["What is artificial intelligence?"],
+             ["How's the weather today?"],
+             ["How do I learn Python?"],
+             ["Hello"],
+         ],
+         inputs=question_input
+     )
+
+     submit_btn.click(
+         fn=generate_response,
+         inputs=[question_input, num_steps, temperature, top_k, max_len],
+         outputs=output
+     )
+
+     question_input.submit(
+         fn=generate_response,
+         inputs=[question_input, num_steps, temperature, top_k, max_len],
+         outputs=output
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
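Because app.py loads the model at import time and keeps generate_response a plain function, the generator can also be exercised without the web UI. Below is a small smoke-test sketch; it assumes app.py is importable from the working directory, and the argument values mirror the slider defaults above.

```python
# Call the Space's generation function directly, bypassing the Gradio UI.
# Importing app builds the Blocks object but does not start the server,
# since demo.launch() is guarded by the __main__ check.
from app import generate_response

answer = generate_response(
    question="What is artificial intelligence?",
    num_steps=50,        # slider default
    temperature=0.5,     # slider default
    top_k=10,            # slider default
    max_answer_len=80,   # slider default
)
print(answer)
```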
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ gradio>=4.0.0