import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Configuration
MODEL_ID = "solonsophy/kf-deberta-gen"       # fine-tuned model
BASE_MODEL_ID = "kakaobank/kf-deberta-base"  # base model (for the tokenizer)
MAX_LEN = 256
Q_MAX_LEN = 100

# Load the model
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)  # tokenizer from the base model
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)    # fine-tuned weights
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
print(f"Model loaded on {device}")

MASK_ID = tokenizer.mask_token_id
PAD_ID = tokenizer.pad_token_id
CLS_ID = tokenizer.cls_token_id
SEP_ID = tokenizer.sep_token_id
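
# Sanity check (an illustrative guard added here, not part of the original app):
# kf-deberta uses a BERT-style vocabulary, so all four special tokens should
# resolve to real ids; fail fast if a different checkpoint is swapped in.
assert None not in (MASK_ID, PAD_ID, CLS_ID, SEP_ID), "tokenizer lacks expected special tokens"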

def generate_response(question, num_steps, temperature, top_k, max_answer_len):
    """Generate an answer by iterative diffusion-style denoising."""
    if not question.strip():
        return "Please enter a question."

    # Gradio sliders may deliver floats; the loop bound and top-k need ints
    num_steps, top_k, max_answer_len = int(num_steps), int(top_k), int(max_answer_len)

    # Tokenize the question
    q_tokens = tokenizer.encode(question, add_special_tokens=False)[:Q_MAX_LEN]

    # Initial sequence: [CLS] Q [SEP] [MASK]*N
    input_ids = [CLS_ID] + q_tokens + [SEP_ID] + [MASK_ID] * max_answer_len
    input_ids = input_ids[:MAX_LEN]
    answer_start = len(q_tokens) + 2
    answer_end = len(input_ids)

    input_ids = torch.tensor([input_ids], device=device)
    attention_mask = torch.ones_like(input_ids)
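    # Sequence layout at this point (illustrative):
    #   [CLS] q_1 ... q_k [SEP] [MASK] [MASK] ... [MASK]
    # answer_start (= k + 2) indexes the first [MASK]; answer_end is the end of the
    # possibly truncated sequence, so the answer region is [answer_start, answer_end).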

    # Iterative denoising: re-predict every masked slot each step, commit only the
    # most confident predictions, and repeat until nothing is masked
    for step in range(num_steps):
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

        # Find the positions still masked within the answer region
        mask_positions = (input_ids[0, answer_start:answer_end] == MASK_ID).nonzero(as_tuple=True)[0]
        mask_positions = mask_positions + answer_start
        if len(mask_positions) == 0:
            break

        # Number of tokens to unmask on this step
        remaining_steps = num_steps - step
        tokens_per_step = max(1, len(mask_positions) // remaining_steps)
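        # Schedule arithmetic (illustrative): with 80 masks and 50 steps the quota
        # is max(1, 80 // 50) = 1 early on; as remaining_steps shrinks the quota
        # grows, and on the final step (remaining_steps == 1) every leftover mask
        # is committed, so the answer is always fully unmasked when the loop ends.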

        # Temperature-scale the logits at the masked positions
        mask_logits = logits[0, mask_positions] / temperature

        # Top-k filtering: suppress everything below the k-th best logit per position
        if top_k > 0:
            top_k_values, _ = torch.topk(mask_logits, min(top_k, mask_logits.size(-1)), dim=-1)
            threshold = top_k_values[:, -1].unsqueeze(-1)
            mask_logits = mask_logits.masked_fill(mask_logits < threshold, float("-inf"))

        # Sample one candidate token per masked position
        probs = F.softmax(mask_logits, dim=-1)
        sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Confidence = probability the model assigns to its own sample
        confidences = probs.gather(1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        # Unmask only the highest-confidence candidates
        _, top_indices = torch.topk(confidences, min(tokens_per_step, len(confidences)))
        selected_positions = mask_positions[top_indices]
        selected_tokens = sampled_tokens[top_indices]
        input_ids[0, selected_positions] = selected_tokens

    # Extract the answer, dropping any leftover mask/pad tokens
    answer_tokens = input_ids[0, answer_start:answer_end]
    valid_mask = (answer_tokens != MASK_ID) & (answer_tokens != PAD_ID)
    answer_tokens = answer_tokens[valid_mask]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer.strip() if answer.strip() else "(generation failed)"
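
# Quick smoke test (illustrative; uncomment to try the pipeline before launching the UI):
# print(generate_response("안녕하세요", num_steps=20, temperature=0.5, top_k=10,
#                         max_answer_len=40))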

# Gradio UI
with gr.Blocks(title="kf-deberta-gen Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # kf-deberta-gen Demo
    **Generative Diffusion BERT** - an experimental Korean diffusion-based generative language model
    > ⚠️ This model is a proof of concept. Generation quality is imperfect, and issues such as repetitive output may occur.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Question",
                placeholder="Enter your question...",
                lines=2
            )
            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 50, value=10, step=1, label="Top-K")
            max_len = gr.Slider(20, 150, value=80, step=10, label="Max Answer Length")

    output = gr.Textbox(label="Answer", lines=5)

    gr.Examples(
        examples=[
            ["오늘 날씨 어때?"],                        # "How's the weather today?"
            ["파이썬을 배우려면 어떻게 해야 하나요?"],  # "How do I learn Python?"
            ["안녕하세요"],                             # "Hello"
        ],
        inputs=question_input
    )

    submit_btn.click(
        fn=generate_response,
        inputs=[question_input, num_steps, temperature, top_k, max_len],
        outputs=output
    )
    question_input.submit(
        fn=generate_response,
        inputs=[question_input, num_steps, temperature, top_k, max_len],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
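
# To run locally (assumed setup): pip install gradio torch transformers, then
# `python app.py`; on a Hugging Face Space, app.py is the default entry point.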