epinfomax committed on
Commit
190ca83
·
verified ·
1 Parent(s): f88ea6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -20
app.py CHANGED
@@ -1,31 +1,79 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
 
4
 
5
- model_id = "epinfomax/BizFlow-Summarizer-Ko"
 
 
6
 
7
- tokenizer = AutoTokenizer.from_pretrained(model_id)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_id,
10
- torch_dtype=torch.float16,
11
- device_map="auto"
12
- )
 
 
 
 
 
 
13
 
14
  def summarize(text):
15
- prompt = f"๋‹ค์Œ ๊ธ€์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”:\n\n{text}"
16
- messages = [{"role": "user", "content": prompt}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
19
- outputs = model.generate(input_ids, max_new_tokens=512, do_sample=True, temperature=0.7)
20
- response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
21
- return response
22
 
23
- demo = gr.Interface(
 
24
  fn=summarize,
25
- inputs=gr.Textbox(lines=10, label="์›๋ฌธ", placeholder="์š”์•ฝํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”..."),
26
- outputs=gr.Textbox(lines=5, label="์š”์•ฝ"),
27
- title="BizFlow Summarizer Ko",
28
- description="ํ•œ๊ตญ์–ด ๋‰ด์Šค/๋ฌธ์„œ ์š”์•ฝ ๋ชจ๋ธ"
 
 
 
 
 
29
  )
30
 
31
- demo.launch()
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from peft import PeftModel
5
 
6
+ # 1. ๋ชจ๋ธ ID ์„ค์ •
7
+ base_id = "Qwen/Qwen2.5-7B-Instruct"
8
+ adapter_id = "epinfomax/BizFlow-Summarizer-Ko"
9
 
10
+ # 2. ํ•˜๋“œ์›จ์–ด ์„ค์ • (GPU/CPU ์ž๋™ ๊ฐ์ง€)
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ dtype = torch.float16 if device == "cuda" else torch.float32
13
+
14
+ print(f"๐Ÿš€ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘... (Device: {device})")
15
+
16
+ # 3. ๋ชจ๋ธ ๋กœ๋“œ
17
+ tokenizer = AutoTokenizer.from_pretrained(base_id)
18
+ model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=dtype)
19
+ model = PeftModel.from_pretrained(model, adapter_id)
20
+ model.to(device)
21
+ model.eval()
22
 
23
  def summarize(text):
24
+ # โ˜… ์ˆ˜์ •๋œ ๋ถ€๋ถ„: messages ๋ฆฌ์ŠคํŠธ ๊ตฌ์กฐํ™”
25
+ # ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ(์ง€์‹œ์‚ฌํ•ญ)์™€ ์‚ฌ์šฉ์ž ์ž…๋ ฅ(text)์„ ๋”•์…”๋„ˆ๋ฆฌ ๋ฆฌ์ŠคํŠธ๋กœ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
26
+ messages = [
27
+ {
28
+ "role": "system",
29
+ "content": "๋‹น์‹ ์€ ๋น„์ฆˆ๋‹ˆ์Šค ๋ฌธ์„œ๋ฅผ ์ „๋ฌธ์ ์œผ๋กœ ์š”์•ฝํ•˜๋Š” AI ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค. ํ•ต์‹ฌ ๋‚ด์šฉ์„ ๋ช…ํ™•ํ•˜๊ฒŒ ์š”์•ฝํ•ด ์ฃผ์„ธ์š”."
30
+ },
31
+ {
32
+ "role": "user",
33
+ "content": text
34
+ }
35
+ ]
36
+
37
+ # ์ž…๋ ฅ ํ…์ŠคํŠธ ํฌ๋งทํŒ… (Chat Template ์ ์šฉ)
38
+ input_text = tokenizer.apply_chat_template(
39
+ messages,
40
+ tokenize=False,
41
+ add_generation_prompt=True
42
+ )
43
+
44
+ # ํ† ํฌ๋‚˜์ด์ง• ๋ฐ GPU ์ด๋™
45
+ inputs = tokenizer([input_text], return_tensors="pt").to(device)
46
+
47
+ # ์ถ”๋ก 
48
+ with torch.no_grad():
49
+ outputs = model.generate(
50
+ **inputs,
51
+ max_new_tokens=512,
52
+ temperature=0.3,
53
+ repetition_penalty=1.1
54
+ )
55
+
56
+ # ๊ฒฐ๊ณผ ๋””์ฝ”๋”ฉ (์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ ์ œ์™ธ)
57
+ generated_tokens = outputs[:, inputs.input_ids.shape[1]:]
58
+ result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
59
 
60
+ # batch_decode๋Š” ๋ฆฌ์ŠคํŠธ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋ฏ€๋กœ ์ฒซ ๋ฒˆ์งธ ์š”์†Œ๋งŒ ๋ฐ˜ํ™˜ํ•˜์—ฌ ๊น”๋”ํ•˜๊ฒŒ ์ถœ๋ ฅ
61
+ return result[0]
 
 
62
 
63
+ # 4. ์›น ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
64
+ iface = gr.Interface(
65
  fn=summarize,
66
+ inputs=gr.Textbox(
67
+ lines=15,
68
+ placeholder="์š”์•ฝํ•  ๋ฌธ์„œ๋ฅผ ์—ฌ๊ธฐ์— ๋ถ™์—ฌ๋„ฃ์œผ์„ธ์š”...",
69
+ label="์ž…๋ ฅ ๋ฌธ์„œ"
70
+ ),
71
+ outputs=gr.Textbox(label="์š”์•ฝ ๊ฒฐ๊ณผ"),
72
+ title="BizFlow ๋ฌธ์„œ ์š”์•ฝ๊ธฐ",
73
+ description="Qwen2.5-7B + ํŒŒ์ธํŠœ๋‹(LoRA) ๋ชจ๋ธ ํ…Œ์ŠคํŠธ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.",
74
+ examples=[["์‚ผ์„ฑ์ „์ž๊ฐ€ ์˜ค๋Š˜ ์ปจํผ๋Ÿฐ์Šค์ฝœ์„ ํ†ตํ•ด ์ง€๋‚œํ•ด 4๋ถ„๊ธฐ ํ™•์ • ์‹ค์ ์„ ๋ฐœํ‘œํ–ˆ๋‹ค. ์—ฐ๊ฒฐ ๊ธฐ์ค€ ๋งค์ถœ์€ 67์กฐ 7800์–ต ์›์œผ๋กœ ์ „๋…„ ๋™๊ธฐ ๋Œ€๋น„ 3.8% ๊ฐ์†Œํ–ˆ์œผ๋‚˜, ์˜์—…์ด์ต์€ 2์กฐ 8200์–ต ์›์œผ๋กœ..."]]
75
  )
76
 
77
+ # ์•ฑ ์‹คํ–‰
78
+ if __name__ == "__main__":
79
+ iface.launch()