qewrufda committed on
Commit
a41052f
ยท
verified ยท
1 Parent(s): d26cf1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -70
app.py CHANGED
@@ -1,35 +1,42 @@
1
- import gradio as gr
2
- import torch
3
  import json
4
  import threading
 
 
 
 
 
 
5
  from transformers import (
6
  AutoTokenizer,
7
  AutoModelForCausalLM,
8
- TextIteratorStreamer
9
  )
10
  from peft import PeftModel
11
- from sentence_transformers import SentenceTransformer
12
- import faiss
13
- import numpy as np
14
-
15
- # ============================================================
16
- # 1. ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ •
17
- # ============================================================
18
- BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
19
- LORA_DIR = "peft_lora"
20
- DOC_PATH = "rule.json"
21
 
 
 
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
-
24
-
25
- # ============================================================
26
- # 2. ๋ฌธ์„œ ๋กœ๋“œ + ์ž„๋ฒ ๋”ฉ + FAISS
27
- # ============================================================
 
 
 
 
 
 
 
28
  with open(DOC_PATH, "r", encoding="utf-8") as f:
29
  documents = json.load(f)
30
 
31
  doc_texts = [d["text"] for d in documents]
32
 
 
33
  embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)
34
  doc_embs = embedding_model.encode(doc_texts, convert_to_numpy=True).astype("float32")
35
 
@@ -41,17 +48,24 @@ def retrieve(query, k=3):
41
  D, I = index.search(q, k)
42
  return [documents[i] for i in I[0]]
43
 
 
44
 
45
- # ============================================================
46
- # 3. ๋ชจ๋ธ ๋กœ๋“œ
47
- # ============================================================
48
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
 
 
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
  BASE_MODEL,
51
  device_map="auto",
52
  torch_dtype=torch.float16,
53
- trust_remote_code=True
54
  )
 
 
55
  model = PeftModel.from_pretrained(
56
  model,
57
  LORA_DIR,
@@ -59,14 +73,13 @@ model = PeftModel.from_pretrained(
59
  torch_dtype=torch.float16,
60
  )
61
  model.eval()
 
62
 
63
-
64
- # ============================================================
65
- # 4. ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
66
- # ============================================================
67
  def build_prompt(persona, instruction, query, retrieved_docs):
68
  context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
69
-
70
  return f"""
71
  ### ํŽ˜๋ฅด์†Œ๋‚˜:
72
  {persona}
@@ -83,43 +96,68 @@ def build_prompt(persona, instruction, query, retrieved_docs):
83
  ### ๋‹ต๋ณ€:
84
  """
85
 
86
-
87
- # ============================================================
88
- # 5. Streaming generator (Gradio ์šฉ)
89
- # ============================================================
90
- def generate_stream(persona, instruction, query):
91
- retrieved = retrieve(query)
92
  prompt = build_prompt(persona, instruction, query, retrieved)
93
-
94
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
95
 
96
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
97
 
98
- def run():
99
- model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  **inputs,
101
- max_new_tokens=256,
102
  do_sample=True,
103
  top_p=0.9,
104
  temperature=0.7,
105
  repetition_penalty=1.2,
106
- pad_token_id=tokenizer.eos_token_id,
107
  eos_token_id=tokenizer.eos_token_id,
108
- streamer=streamer
109
  )
 
 
 
110
 
111
- thread = threading.Thread(target=run)
112
- thread.start()
113
-
114
- partial = ""
115
- for text in streamer:
116
- partial += text
117
- yield partial
118
-
119
-
120
- # ============================================================
121
- # 6. ํŽ˜๋ฅด์†Œ๋‚˜ 6๊ฐœ ์ž๋™ ์‹คํ–‰ ํ•จ์ˆ˜
122
- # ============================================================
123
  persona_group = [
124
  ("๋‹น์‹ ์€ ์›์น™์„ ์ง€ํ‚ค๋˜ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์œ ์—ฐํ•˜๊ฒŒ ํŒ๋‹จํ•˜๋Š” ์‹œ๊ฐ์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค. ๊ฐœ์ธ์˜ ๋Šฅ๋ ฅ๊ณผ ๊ธฐ์—ฌ๋„๋ฅผ ์ค‘์š”ํ•˜๊ฒŒ ์ƒ๊ฐํ•˜๋ฉฐ...", "๋ฐ•์„ธ์—ฐ"),
125
  ("๋‹น์‹ ์€ ๊ณต์ •ํ•œ ๊ทœ์น™๊ณผ ์›์น™์„ ์ค‘์‹œํ•˜๋ฉด์„œ, ๊ฐœ์ธ์˜ ์„ฑ๊ณผ์™€ ๋Šฅ๋ ฅ์„ ์ธ์ •ํ•ด ์ฐจ๋“ฑ์„ ๋‘๊ณ  ๋ฐฐ๋ถ„ํ•ฉ๋‹ˆ๋‹ค...", "๊น€์ฐฝ์ค€"),
@@ -137,29 +175,51 @@ instruction_text = """
137
  ๋ฐ˜๋ณต ๊ธˆ์ง€, ํŒ๋‹จ ๊ทผ๊ฑฐ ํ•„์ˆ˜.
138
  """
139
 
140
- def run_all_personas(query):
 
 
 
 
141
  for persona, name in persona_group:
142
- yield f"## ๐Ÿ‘ค {name}\n"
143
- stream = generate_stream(persona, instruction_text, query)
144
- for chunk in stream:
145
- yield chunk
 
146
  yield "\n\n---\n\n"
147
 
148
-
149
- # ============================================================
150
- # 7. Gradio UI
151
- # ============================================================
 
 
 
 
 
 
 
 
 
 
152
  with gr.Blocks() as demo:
153
- gr.Markdown("# ๐Ÿ”ฅ KORMo 10B + LoRA Streaming Judge")
 
 
 
 
 
154
 
155
- user_input = gr.Textbox(label="์งˆ๋ฌธ ์ž…๋ ฅ", value="3๋ฒˆ ์ด์ƒ ๊ฒฐ์„ํ–ˆ์ง€๋งŒ ์‹ค๋ ฅ์€ ๋›ฐ์–ด๋‚œ ์ •ํšŒ์›์„ ์–ด๋–ป๊ฒŒ ํ•ด์•ผ ํ• ๊นŒ?")
156
- output = gr.Markdown()
 
 
157
 
158
- def start(query):
159
- for chunk in run_all_personas(query):
160
- yield chunk
161
-
162
- run_btn = gr.Button("๐Ÿš€ ์‹คํ–‰ํ•˜๊ธฐ")
163
- run_btn.click(start, inputs=user_input, outputs=output)
164
 
 
165
  demo.launch()
 
1
+ # app.py
2
+ import os
3
  import json
4
  import threading
5
+ import gradio as gr
6
+ import torch
7
+ import faiss
8
+ import numpy as np
9
+
10
+ from sentence_transformers import SentenceTransformer
11
  from transformers import (
12
  AutoTokenizer,
13
  AutoModelForCausalLM,
14
+ TextIteratorStreamer,
15
  )
16
  from peft import PeftModel
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # -----------------------------
19
+ # 0. ํ™˜๊ฒฝ ๊ฒ€์‚ฌ
20
+ # -----------------------------
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ print("Device:", device)
23
+
24
+ # -----------------------------
25
+ # 1. ๋ชจ๋ธ / ๊ฒฝ๋กœ ์„ค์ •
26
+ # -----------------------------
27
+ BASE_MODEL = "KORMo-Team/KORMo-10B-sft" # ์˜ˆ์‹œ
28
+ LORA_DIR = "peft_lora" # Space ๋‚ด ์—…๋กœ๋“œ๋œ LoRA ํด๋”(๋˜๋Š” ๋กœ์ปฌ ๊ฒฝ๋กœ)
29
+ DOC_PATH = "rule.json" # Space ๋‚ด ์—…๋กœ๋“œ๋œ ๊ทœ์ • JSON
30
+
31
+ # -----------------------------
32
+ # 2. RAG ๋ฌธ์„œ ๋กœ๋“œ + FAISS ์ค€๋น„
33
+ # -----------------------------
34
  with open(DOC_PATH, "r", encoding="utf-8") as f:
35
  documents = json.load(f)
36
 
37
  doc_texts = [d["text"] for d in documents]
38
 
39
+ # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ (ํ•œ๊ตญ์–ด)
40
  embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)
41
  doc_embs = embedding_model.encode(doc_texts, convert_to_numpy=True).astype("float32")
42
 
 
48
  D, I = index.search(q, k)
49
  return [documents[i] for i in I[0]]
50
 
51
+ print("FAISS ready, docs:", index.ntotal)
52
 
53
+ # -----------------------------
54
+ # 3. ํ† ํฌ๋‚˜์ด์ €ยท๋ชจ๋ธ ๋กœ๋“œ (LoRA ํฌํ•จ)
55
+ # -----------------------------
56
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
57
+ if tokenizer.pad_token is None:
58
+ tokenizer.pad_token = tokenizer.eos_token
59
+
60
+ # ๋ณธ์ฒด ๋ชจ๋ธ (device_map="auto" ์‚ฌ์šฉํ•˜๋ฉด accelerate๊ฐ€ ์ž๋™ ๋ถ„๋ฐฐ)
61
  model = AutoModelForCausalLM.from_pretrained(
62
  BASE_MODEL,
63
  device_map="auto",
64
  torch_dtype=torch.float16,
65
+ trust_remote_code=True,
66
  )
67
+
68
+ # LoRA (PEFT) ์ ์šฉ
69
  model = PeftModel.from_pretrained(
70
  model,
71
  LORA_DIR,
 
73
  torch_dtype=torch.float16,
74
  )
75
  model.eval()
76
+ print("Model + LoRA loaded")
77
 
78
+ # -----------------------------
79
+ # 4. ํ”„๋กฌํ”„ํŠธ ๋นŒ๋” (์›๋ณธ ๊ทธ๋Œ€๋กœ)
80
+ # -----------------------------
 
81
  def build_prompt(persona, instruction, query, retrieved_docs):
82
  context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
 
83
  return f"""
84
  ### ํŽ˜๋ฅด์†Œ๋‚˜:
85
  {persona}
 
96
  ### ๋‹ต๋ณ€:
97
  """
98
 
99
+ # -----------------------------
100
+ # 5. ์ŠคํŠธ๋ฆฌ๋ฐ generator (UI์šฉ)
101
+ # - TextIteratorStreamer + ์Šค๋ ˆ๋“œ ๋ฐฉ์‹
102
+ # -----------------------------
103
def generate_stream(persona, instruction, query, max_new_tokens=256):
    """Stream a RAG-grounded answer for one persona.

    Retrieves the top-k rule documents for *query*, builds the persona
    prompt, runs ``model.generate`` on a worker thread, and yields the
    progressively accumulated generated text (Gradio streaming expects
    the partial string to grow on every yield).

    Args:
        persona: Persona description injected into the prompt.
        instruction: Task instruction text injected into the prompt.
        query: User question, used both for retrieval and generation.
        max_new_tokens: Generation budget (default 256).

    Yields:
        str: All text generated so far, growing token by token.
    """
    retrieved = retrieve(query, k=3)
    prompt = build_prompt(persona, instruction, query, retrieved)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt=True: only newly generated tokens are streamed back.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def run_generate():
        # generate() blocks until completion, so it runs on a worker
        # thread while this generator consumes the streamer.
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer,
                use_cache=True,
            )

    # daemon=True so an abandoned request cannot keep the process alive.
    thread = threading.Thread(target=run_generate, daemon=True)
    thread.start()

    accumulated = ""
    for token in streamer:
        accumulated += token
        yield accumulated  # Gradio streaming output receives growing partial strings

    # Fix: the original never joined the worker thread, so the generator
    # could finish while generation cleanup was still in flight.
    thread.join()
132
+
133
+ # -----------------------------
134
+ # 6. ๋™๊ธฐ ์ƒ์„ฑ (API์šฉ) โ€” ์ „์ฒด ํ…์Šค๏ฟฝ๏ฟฝ ๋ฐ˜ํ™˜
135
+ # - model.generate๋ฅผ ๋ธ”๋กํ‚น์œผ๋กœ ์‹คํ–‰ํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ๋””์ฝ”๋“œ
136
+ # -----------------------------
137
def generate_once(persona, instruction, query, max_new_tokens=256):
    """Generate a complete answer synchronously (for API callers).

    Same retrieval + prompting as ``generate_stream``, but blocks until
    generation finishes and returns the whole decoded answer at once.

    Args:
        persona: Persona description injected into the prompt.
        instruction: Task instruction text injected into the prompt.
        query: User question, used both for retrieval and generation.
        max_new_tokens: Generation budget (default 256).

    Returns:
        str: The generated answer text, with the prompt excluded.
    """
    retrieved = retrieve(query, k=3)
    prompt = build_prompt(persona, instruction, query, retrieved)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    # Fix: decode only the newly generated tokens by slicing at the
    # prompt length. The original used text.replace(prompt, ""), which
    # breaks whenever decode() does not reproduce the prompt
    # byte-for-byte (special tokens, whitespace normalization) and
    # silently leaks the prompt into the returned answer.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
157
 
158
+ # -----------------------------
159
+ # 7. ํŽ˜๋ฅด์†Œ๋‚˜ ๊ทธ๋ฃน (์›๋ณธ ์œ ์ง€)
160
+ # -----------------------------
 
 
 
 
 
 
 
 
 
161
  persona_group = [
162
  ("๋‹น์‹ ์€ ์›์น™์„ ์ง€ํ‚ค๋˜ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์œ ์—ฐํ•˜๊ฒŒ ํŒ๋‹จํ•˜๋Š” ์‹œ๊ฐ์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค. ๊ฐœ์ธ์˜ ๋Šฅ๋ ฅ๊ณผ ๊ธฐ์—ฌ๋„๋ฅผ ์ค‘์š”ํ•˜๊ฒŒ ์ƒ๊ฐํ•˜๋ฉฐ...", "๋ฐ•์„ธ์—ฐ"),
163
  ("๋‹น์‹ ์€ ๊ณต์ •ํ•œ ๊ทœ์น™๊ณผ ์›์น™์„ ์ค‘์‹œํ•˜๋ฉด์„œ, ๊ฐœ์ธ์˜ ์„ฑ๊ณผ์™€ ๋Šฅ๋ ฅ์„ ์ธ์ •ํ•ด ์ฐจ๋“ฑ์„ ๋‘๊ณ  ๋ฐฐ๋ถ„ํ•ฉ๋‹ˆ๋‹ค...", "๊น€์ฐฝ์ค€"),
 
175
  ๋ฐ˜๋ณต ๊ธˆ์ง€, ํŒ๋‹จ ๊ทผ๊ฑฐ ํ•„์ˆ˜.
176
  """
177
 
178
+ # -----------------------------
179
+ # 8. UI์šฉ: ๋ชจ๋“  ํŽ˜๋ฅด์†Œ๋‚˜์— ๋Œ€ํ•ด ์ŠคํŠธ๋ฆฌ๋ฐ ์ถœ๋ ฅ (Gradio Blocks)
180
+ # -----------------------------
181
def run_all_streaming(query):
    """Stream answers from every persona as one growing Markdown string.

    Each yield REPLACES the bound Gradio component's value, so the full
    transcript so far must be yielded every time. Fix: the original
    yielded the bare header and then bare partials, which erased the
    header and every previously finished persona from the UI on each
    update; here a cumulative transcript is maintained instead.
    """
    transcript = ""
    for persona, name in persona_group:
        transcript += f"## ๐Ÿ‘ค {name}\n"
        yield transcript  # show the persona header immediately
        latest = ""
        for partial in generate_stream(persona, instruction_text, query):
            latest = partial
            yield transcript + partial
        # Keep the finished answer, then a separator, before the next persona.
        transcript += latest + "\n\n---\n\n"
        yield transcript
 
191
+ # -----------------------------
192
+ # 9. API์šฉ: ๋ชจ๋“  ํŽ˜๋ฅด์†Œ๋‚˜๋ฅผ ๋™๊ธฐ์ ์œผ๋กœ ์‹คํ–‰ํ•˜๊ณ  ํ•˜๋‚˜์˜ ๋ฌธ์ž์—ด๋กœ ๋ฐ˜ํ™˜
193
+ # -----------------------------
194
def run_all_api(query):
    """Run every persona synchronously and return one combined string.

    For each persona in ``persona_group``, generates a complete answer
    via ``generate_once`` and concatenates a Markdown header, the
    answer, and a horizontal-rule separator.

    Args:
        query: User question forwarded to every persona.

    Returns:
        str: All persona sections joined into a single Markdown string.
    """
    sections = []
    for persona, name in persona_group:
        answer = generate_once(persona, instruction_text, query)
        sections.append(f"## ๐Ÿ‘ค {name}\n" + answer + "\n\n---\n\n")
    return "".join(sections)
201
+
202
+ # -----------------------------
203
+ # 10. Gradio ์•ฑ ๊ตฌ์„ฑ
204
+ # -----------------------------
205
# Gradio app wiring: one streaming path for the UI, one blocking path
# that returns a single string (suitable for gradio_client callers).
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ”ฅ KORMo LoRA + RAG (Streaming UI + API)")

    user_input = gr.Textbox(label="์งˆ๋ฌธ ์ž…๋ ฅ", value="3๋ฒˆ ์ด์ƒ์˜ ๊ฒฐ์„์„ ํ–ˆ์ง€๋งŒ ์‹ค๋ ฅ์€ ๋™์•„๋ฆฌ์—์„œ ๋›ฐ์–ด๋‚œ ์ •ํšŒ์›์„ ์–ด๋–ป๊ฒŒ ํ•ด์•ผ ํ• ๊นŒ?")
    # Markdown renders the partial strings re-yielded by the streaming generator.
    output_stream = gr.Markdown()

    run_btn = gr.Button("๐Ÿš€ ์‹คํ–‰(Streaming UI)")
    run_btn.click(fn=run_all_streaming, inputs=[user_input], outputs=[output_stream])

    # Blocking path: same personas, synchronous generation, one string back.
    run_btn_api = gr.Button("๐Ÿ” ์‹คํ–‰(API, ๋™๊ธฐ)")
    api_output = gr.Textbox(label="API ๋ฐ˜ํ™˜ ๊ฒฐ๊ณผ", lines=10)
    run_btn_api.click(fn=run_all_api, inputs=[user_input], outputs=[api_output])

    # NOTE(review): to expose the blocking path to gradio_client under the
    # name "start_api", add api_name="start_api" to the click() call above.

# Default launch settings work on Spaces.
demo.launch()