# gpt2-bert / app.py — Hugging Face Space (author: openfree, commit 2aa4818 verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import numpy as np
class MergedModelTester:
    """Load and exercise the merged openfree/gpt2-bert causal language model.

    The model and tokenizer are loaded lazily via ``load_model``; every
    public method returns UI-ready strings for the Gradio front-end and
    reports failures through its return values rather than raising.
    """

    def __init__(self):
        self.model = None        # set by load_model()
        self.tokenizer = None    # set by load_model()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_id="openfree/gpt2-bert", progress=gr.Progress()):
        """Load the merged model and tokenizer; return a status string for the UI."""
        try:
            progress(0.2, desc="ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์ค‘...")
            # The merged model keeps GPT-2's vocabulary, so the stock GPT-2
            # tokenizer is used regardless of model_id.
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            # GPT-2 ships without a pad token; alias it to EOS so padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            progress(0.5, desc="๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                # Half precision and automatic device placement only on GPU.
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                device_map="auto" if self.device.type == 'cuda' else None
            )
            if self.device.type == 'cpu':
                # device_map is None on CPU, so move the weights explicitly.
                self.model = self.model.to(self.device)
            self.model.eval()
            progress(1.0, desc="์™„๋ฃŒ!")
            # Model info for the status box.
            num_params = sum(p.numel() for p in self.model.parameters())
            return f"""โœ… ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต!
- ๋ชจ๋ธ: {model_id}
- ํŒŒ๋ผ๋ฏธํ„ฐ: {num_params:,}
- ๋””๋ฐ”์ด์Šค: {self.device}"""
        except Exception as e:
            return f"โŒ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {str(e)}"

    def generate_text(self, prompt, max_length=100, temperature=0.8,
                      top_p=0.9, repetition_penalty=1.2, progress=gr.Progress()):
        """Generate a completion for *prompt*.

        Returns (generated_text, stats_text, error). ``error`` is None on
        success; on failure the first element carries the message and the
        other two degrade to None / the error string.
        """
        if self.model is None:
            return "๋จผ์ € ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜์„ธ์š”!", None, None
        try:
            progress(0.3, desc="ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘...")
            # Tokenize the prompt and move every tensor to the model device.
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            start_time = time.time()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            generation_time = time.time() - start_time
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Generation statistics.
            input_tokens = len(inputs['input_ids'][0])
            output_tokens = len(outputs[0])
            new_tokens = output_tokens - input_tokens
            # Guard against a zero elapsed time (coarse clocks, instant runs)
            # which would raise ZeroDivisionError in the speed figure.
            tokens_per_sec = new_tokens / max(generation_time, 1e-6)
            stats = f"""๐Ÿ“Š ์ƒ์„ฑ ํ†ต๊ณ„:
- ์ž…๋ ฅ ํ† ํฐ: {input_tokens}
- ์ƒ์„ฑ ํ† ํฐ: {new_tokens}
- ์ „์ฒด ํ† ํฐ: {output_tokens}
- ์ƒ์„ฑ ์‹œ๊ฐ„: {generation_time:.2f}์ดˆ
- ์†๋„: {tokens_per_sec:.1f} tokens/sec"""
            progress(1.0, desc="์™„๋ฃŒ!")
            return generated_text, stats, None
        except Exception as e:
            return f"โŒ ์ƒ์„ฑ ์‹คํŒจ: {str(e)}", None, str(e)

    def compare_with_parents(self, prompt, max_length=50, progress=gr.Progress()):
        """Generate with parent GPT-2 and the merged model for side-by-side comparison."""
        results = {}
        # GPT-2 (parent 1)
        try:
            progress(0.1, desc="GPT-2 ๋กœ๋“œ ์ค‘...")
            gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
            gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
            gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2").to(self.device)
            progress(0.3, desc="GPT-2 ์ƒ์„ฑ ์ค‘...")
            inputs = gpt2_tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = gpt2_model.generate(**inputs, max_new_tokens=max_length, do_sample=True)
            results['gpt2'] = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Drop the temporary model and release cached GPU blocks so the
            # merged model below has the VRAM available.
            del gpt2_model
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
        except Exception as e:
            results['gpt2'] = f"๋กœ๋“œ ์‹คํŒจ: {str(e)}"
        # BERT is encoder-only; it cannot generate text.
        results['bert'] = "BERT๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด ์•„๋‹™๋‹ˆ๋‹ค (์ธ์ฝ”๋” ์ „์šฉ)"
        # Merged model
        try:
            progress(0.6, desc="๋ณ‘ํ•ฉ ๋ชจ๋ธ ์ƒ์„ฑ ์ค‘...")
            if self.model is None:
                self.load_model()
            if self.model is None:
                # load_model() reports failures only via its return string;
                # raise here so the except below shows a clear message instead
                # of "'NoneType' object is not callable".
                raise RuntimeError("model load failed")
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(**inputs, max_new_tokens=max_length, do_sample=True)
            results['merged'] = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            results['merged'] = f"์ƒ์„ฑ ์‹คํŒจ: {str(e)}"
        progress(1.0, desc="์™„๋ฃŒ!")
        # Format the three results for the UI.
        comparison = f"""๐Ÿ”„ ๋ชจ๋ธ ๋น„๊ต ๊ฒฐ๊ณผ:
**GPT-2 (๋ถ€๋ชจ 1):**
{results['gpt2']}
**BERT (๋ถ€๋ชจ 2):**
{results['bert']}
**๋ณ‘ํ•ฉ ๋ชจ๋ธ (openfree/gpt2-bert):**
{results['merged']}"""
        return comparison
# ์ „์—ญ ์ธ์Šคํ„ด์Šค
tester = MergedModelTester()
# Gradio ์ธํ„ฐํŽ˜์ด์Šค
with gr.Blocks(title="GPT2-BERT ๋ณ‘ํ•ฉ ๋ชจ๋ธ ํ…Œ์Šคํ„ฐ") as demo:
gr.Markdown("""
# ๐Ÿงฌ GPT2-BERT ๋ณ‘ํ•ฉ ๋ชจ๋ธ ํ…Œ์Šคํ„ฐ
์ง„ํ™”์  ์•Œ๊ณ ๋ฆฌ์ฆ˜์œผ๋กœ ๋ณ‘ํ•ฉ๋œ [openfree/gpt2-bert](https://huggingface.co/openfree/gpt2-bert) ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.
## ๐Ÿ“Š ๋ชจ๋ธ ์ •๋ณด
- **๋ถ€๋ชจ 1**: openai-community/gpt2
- **๋ถ€๋ชจ 2**: google-bert/bert-base-uncased
- **๋ณ‘ํ•ฉ ๋ฐฉ๋ฒ•**: SLERP (์ง„ํ™”์  ์ตœ์ ํ™”)
- **์ตœ์ข… ์„ฑ๋Šฅ**: 82-84% accuracy
""")
with gr.Tab("๐Ÿš€ ๋น ๋ฅธ ํ…Œ์ŠคํŠธ"):
with gr.Row():
with gr.Column():
load_btn = gr.Button("๐Ÿ“ฅ ๋ชจ๋ธ ๋กœ๋“œ", variant="primary")
load_status = gr.Textbox(label="๋กœ๋“œ ์ƒํƒœ", lines=4)
prompt_input = gr.Textbox(
label="ํ”„๋กฌํ”„ํŠธ",
placeholder="ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
value="The future of AI is",
lines=3
)
with gr.Row():
max_length = gr.Slider(20, 200, 100, label="์ตœ๋Œ€ ๊ธธ์ด")
temperature = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
with gr.Row():
top_p = gr.Slider(0.1, 1.0, 0.9, label="Top-p")
rep_penalty = gr.Slider(1.0, 2.0, 1.2, label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ")
generate_btn = gr.Button("โœจ ํ…์ŠคํŠธ ์ƒ์„ฑ", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="์ƒ์„ฑ๋œ ํ…์ŠคํŠธ", lines=10)
stats_text = gr.Textbox(label="์ƒ์„ฑ ํ†ต๊ณ„", lines=6)
with gr.Tab("๐Ÿ”ฌ ๋ชจ๋ธ ๋น„๊ต"):
compare_prompt = gr.Textbox(
label="๋น„๊ตํ•  ํ”„๋กฌํ”„ํŠธ",
value="Once upon a time",
lines=2
)
compare_length = gr.Slider(20, 100, 50, label="์ƒ์„ฑ ๊ธธ์ด")
compare_btn = gr.Button("๐Ÿ”„ ๋ถ€๋ชจ ๋ชจ๋ธ๊ณผ ๋น„๊ต", variant="primary")
comparison_output = gr.Textbox(label="๋น„๊ต ๊ฒฐ๊ณผ", lines=20)
with gr.Tab("๐Ÿงช ๊ณ ๊ธ‰ ํ…Œ์ŠคํŠธ"):
gr.Markdown("### ๋‹ค์–‘ํ•œ ํƒœ์Šคํฌ ํ…Œ์ŠคํŠธ")
task_type = gr.Radio(
["์ด์•ผ๊ธฐ ์ƒ์„ฑ", "์งˆ๋ฌธ ๋‹ต๋ณ€", "์ฝ”๋“œ ์ƒ์„ฑ", "์‹œ ์ž‘์„ฑ"],
label="ํƒœ์Šคํฌ ์„ ํƒ",
value="์ด์•ผ๊ธฐ ์ƒ์„ฑ"
)
task_prompts = {
"์ด์•ผ๊ธฐ ์ƒ์„ฑ": "In a distant galaxy, a young explorer discovered",
"์งˆ๋ฌธ ๋‹ต๋ณ€": "Q: What is machine learning?\nA:",
"์ฝ”๋“œ ์ƒ์„ฑ": "# Python function to calculate fibonacci\ndef fibonacci(n):",
"์‹œ ์ž‘์„ฑ": "Roses are red,\nViolets are blue,"
}
def update_prompt(task):
return task_prompts.get(task, "")
task_prompt = gr.Textbox(label="ํƒœ์Šคํฌ ํ”„๋กฌํ”„ํŠธ", lines=3)
task_output = gr.Textbox(label="๊ฒฐ๊ณผ", lines=10)
task_btn = gr.Button("๐ŸŽฏ ํƒœ์Šคํฌ ์‹คํ–‰", variant="primary")
task_type.change(update_prompt, task_type, task_prompt)
with gr.Tab("๐Ÿ“ˆ ์„ฑ๋Šฅ ๋ถ„์„"):
gr.Markdown("""
### ์ง„ํ™” ์‹คํ—˜ ๊ฒฐ๊ณผ
| ๋ฉ”ํŠธ๋ฆญ | ๊ฐ’ |
|--------|-----|
| ์ดˆ๊ธฐ ์„ฑ๋Šฅ | 10.56% |
| ์ตœ์ข… ์„ฑ๋Šฅ | 82-84% |
| ๊ฐœ์„ ์œจ | +700% |
| ์ด ๊ฐœ์„  ํšŸ์ˆ˜ | 2,136ํšŒ |
| ํ•™์Šต ์‹œ๊ฐ„ | 7.7๋ถ„ |
### ์„ธ๋Œ€๋ณ„ ์„ฑ๋Šฅ
- **์ดˆ๊ธฐ (0-2000)**: ํฐ ๊ฐœ์„  (+20-30%/์„ธ๋Œ€)
- **์ค‘๊ธฐ (2000-5000)**: ์ค‘๊ฐ„ ๊ฐœ์„  (+10-15%/์„ธ๋Œ€)
- **ํ›„๊ธฐ (5000-10000)**: ๋ฏธ์„ธ ์กฐ์ • (+2-5%/์„ธ๋Œ€)
""")
test_suite_btn = gr.Button("๐Ÿ” ์ „์ฒด ํ…Œ์ŠคํŠธ ์Šค์œ„ํŠธ ์‹คํ–‰", variant="primary")
test_results = gr.Textbox(label="ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ", lines=15)
# ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
load_btn.click(
lambda: tester.load_model("openfree/gpt2-bert"),
outputs=load_status
)
generate_btn.click(
tester.generate_text,
inputs=[prompt_input, max_length, temperature, top_p, rep_penalty],
outputs=[output_text, stats_text, gr.Textbox(visible=False)]
)
compare_btn.click(
tester.compare_with_parents,
inputs=[compare_prompt, compare_length],
outputs=comparison_output
)
task_btn.click(
lambda p: tester.generate_text(p, 100, 0.8, 0.9, 1.2),
inputs=task_prompt,
outputs=[task_output, gr.Textbox(visible=False), gr.Textbox(visible=False)]
)
def run_test_suite(progress=gr.Progress()):
"""์ „์ฒด ํ…Œ์ŠคํŠธ ์Šค์œ„ํŠธ ์‹คํ–‰"""
results = []
test_prompts = [
"The meaning of life is",
"import numpy as np\n",
"Scientists have discovered",
"def hello_world():",
"Breaking news:"
]
for i, prompt in enumerate(test_prompts):
progress((i+1)/len(test_prompts), desc=f"ํ…Œ์ŠคํŠธ {i+1}/{len(test_prompts)}")
try:
output, stats, _ = tester.generate_text(prompt, 30)
results.append(f"โœ… ํ”„๋กฌํ”„ํŠธ: {prompt[:30]}...\n ์ƒ์„ฑ ์„ฑ๊ณต")
except:
results.append(f"โŒ ํ”„๋กฌํ”„ํŠธ: {prompt[:30]}...\n ์ƒ์„ฑ ์‹คํŒจ")
return "\n".join(results)
test_suite_btn.click(
run_test_suite,
outputs=test_results
)
# ์‹คํ–‰
if __name__ == "__main__":
demo.launch(share=False)