qewrufda committed on
Commit
c155e6a
·
verified ·
1 Parent(s): 76b3d37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -115
app.py CHANGED
@@ -1,45 +1,29 @@
1
- import os
2
  import torch
3
  import json
4
  import threading
5
- import time
6
- import faiss
7
- import numpy as np
8
- import gradio as gr
9
-
10
- from sentence_transformers import SentenceTransformer
11
  from transformers import (
12
- AutoModelForCausalLM,
13
  AutoTokenizer,
14
- TextIteratorStreamer,
 
15
  )
16
  from peft import PeftModel
17
- from huggingface_hub import login
 
 
18
 
19
  # ============================================================
20
- # 1. ๋กœ๊ทธ์ธ (Colab ์ „์šฉ ์ฝ”๋“œ ์ œ๊ฑฐ, Space ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์‚ฌ์šฉ)
21
  # ============================================================
22
- HF_TOKEN = os.environ.get("HF_TOKEN")
23
- if HF_TOKEN:
24
- login(token=HF_TOKEN)
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
- print("Device:", device)
28
 
29
- # ============================================================
30
- # 2. ๋ชจ๋ธ ๋ฐ ๊ฒฝ๋กœ ์„ค์ •
31
- # ============================================================
32
- BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
33
- LORA_DIR = "./peft_lora" # Space์— ์—…๋กœ๋“œํ•œ LoRA ํด๋”
34
- DOC_PATH = "./rule.json" # Space์— ์—…๋กœ๋“œํ•œ rule.json
35
-
36
- print("Paths:")
37
- print("Model:", BASE_MODEL)
38
- print("LoRA:", LORA_DIR)
39
- print("Documents:", DOC_PATH)
40
 
41
  # ============================================================
42
- # 3. RAG ๋ฌธ์„œ ๋กœ๋“œ + ์ž„๋ฒ ๋”ฉ + FAISS ๊ตฌ์ถ•
43
  # ============================================================
44
  with open(DOC_PATH, "r", encoding="utf-8") as f:
45
  documents = json.load(f)
@@ -47,58 +31,42 @@ with open(DOC_PATH, "r", encoding="utf-8") as f:
47
  doc_texts = [d["text"] for d in documents]
48
 
49
  embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)
 
50
 
51
- doc_embs = embedding_model.encode(
52
- doc_texts,
53
- convert_to_numpy=True,
54
- show_progress_bar=True
55
- ).astype("float32")
56
-
57
- dim = doc_embs.shape[1]
58
- index = faiss.IndexFlatL2(dim)
59
  index.add(doc_embs)
60
 
61
- print("FAISS index built. Total docs =", index.ntotal)
 
 
 
 
62
 
63
  # ============================================================
64
- # 4. LLM + LoRA ๋กœ๋“œ
65
  # ============================================================
 
66
  model = AutoModelForCausalLM.from_pretrained(
67
  BASE_MODEL,
68
- torch_dtype=torch.float16,
69
  device_map="auto",
 
70
  trust_remote_code=True
71
  )
72
-
73
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
74
-
75
- if tokenizer.pad_token is None:
76
- tokenizer.pad_token = tokenizer.eos_token
77
-
78
- print("Loading LoRA...")
79
  model = PeftModel.from_pretrained(
80
  model,
81
  LORA_DIR,
82
- torch_dtype=torch.float16,
83
  device_map="auto",
 
84
  )
85
  model.eval()
86
 
87
- print("Model + LoRA loaded successfully.")
88
-
89
- # ============================================================
90
- # 5. RAG ๊ฒ€์ƒ‰ ํ•จ์ˆ˜
91
- # ============================================================
92
- def retrieve(query, k=3):
93
- q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
94
- D, I = index.search(q_emb, k)
95
- return [documents[i] for i in I[0]]
96
 
97
  # ============================================================
98
- # 6. ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ (๋„ค ์ฝ”๋“œ ๊ทธ๋Œ€๋กœ)
99
  # ============================================================
100
  def build_prompt(persona, instruction, query, retrieved_docs):
101
  context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
 
102
  return f"""
103
  ### ํŽ˜๋ฅด์†Œ๋‚˜:
104
  {persona}
@@ -115,73 +83,82 @@ def build_prompt(persona, instruction, query, retrieved_docs):
115
  ### ๋‹ต๋ณ€:
116
  """
117
 
 
118
  # ============================================================
119
- # 7. Streaming Chat ํ•จ์ˆ˜
120
  # ============================================================
121
- def stream_chat(persona, instruction, user_query, k=3, max_new_tokens=256):
 
 
122
 
123
- retrieved = retrieve(user_query, k=k)
124
- prompt = build_prompt(persona, instruction, user_query, retrieved)
125
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
126
 
127
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
128
- END_TOKENS = ["End of Answer", "###", "---", "โ€ป"]
129
-
130
- def run_generation():
131
- with torch.no_grad():
132
- model.generate(
133
- **inputs,
134
- max_new_tokens=max_new_tokens,
135
- do_sample=True,
136
- top_p=0.9,
137
- temperature=0.7,
138
- repetition_penalty=1.2,
139
- pad_token_id=tokenizer.pad_token_id,
140
- eos_token_id=tokenizer.eos_token_id,
141
- streamer=streamer,
142
- use_cache=True
143
- )
144
-
145
- thread = threading.Thread(target=run_generation)
146
  thread.start()
147
 
148
- full = ""
149
- for token in streamer:
150
- full += token
151
- yield token
 
152
 
153
  # ============================================================
154
- # 8. Gradio UI
155
  # ============================================================
156
- def gradio_answer(persona, query):
157
-
158
- instruction = ("""
159
- ๋‹น์‹ ์€ ํ•ด๋‹น ํŽ˜๋ฅด์†Œ๋‚˜์˜ ์„ฑ๊ฒฉ์„ ๊ฐ€์ง„ ์‹ฌํŒ๊ด€์ž…๋‹ˆ๋‹ค.
160
- ๋ฐ˜๋“œ์‹œ 3๋ฌธ์žฅ๋งŒ ๋งํ•˜์‹ญ์‹œ์˜ค.
161
- ๊ฐ ๋ฌธ์žฅ์€ 30์ž ์ด๋‚ด๋กœ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.
162
- ๊ทœ์ •์— ์šฐ์„ ์œผ๋กœ ๊ทผ๊ฑฐํ•˜์—ฌ ๋‹ตํ•˜์‹œ์˜ค.
163
- ํŒ๋‹จ ๊ทผ๊ฑฐ๊ฐ€ ํฌํ•จ๋œ ๋‹ต์•ˆ๋งŒ ์ƒ์„ฑํ•˜์‹œ์˜ค.
164
- ๊ฐ™์€ ๋ง์„ ๋ฐ˜๋ณตํ•˜๋Š” ๊ฒƒ์„ ์ ˆ๋Œ€ ๊ธˆํ•จ.
165
- ์˜๋ฌธ์— ํ™•์‹คํ•˜๊ฒŒ ์ž…์žฅ์„ ๋ฐํž ๊ฒƒ.
166
- ๋ฐ˜๋“œ์‹œ 3๋ฌธ์žฅ๋งŒ ๋งํ•˜์‹ญ์‹œ์˜ค.
167
- ๊ฐ ๋ฌธ์žฅ์€ 30์ž ์ด๋‚ด๋กœ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.
168
- """)
169
-
170
- return stream_chat(persona, instruction, query)
171
-
172
- with gr.Blocks() as app:
173
- gr.Markdown("## ๐Ÿ”ฅ KORMo LoRA + RAG Streaming Judge")
174
-
175
- persona = gr.Textbox(label="ํŽ˜๋ฅด์†Œ๋‚˜ ์ž…๋ ฅ")
176
- query = gr.Textbox(label="์งˆ๋ฌธ ์ž…๋ ฅ")
177
-
178
- output = gr.Textbox(label="์‘๋‹ต", lines=8)
179
- btn = gr.Button("์ƒ์„ฑ")
180
-
181
- btn.click(
182
- gradio_answer,
183
- inputs=[persona, query],
184
- outputs=output
185
- )
186
-
187
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
  import json
4
  import threading
 
 
 
 
 
 
5
  from transformers import (
 
6
  AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ TextIteratorStreamer
9
  )
10
  from peft import PeftModel
11
+ from sentence_transformers import SentenceTransformer
12
+ import faiss
13
+ import numpy as np
14
 
15
# ============================================================
# 1. Model and data path configuration
# ============================================================
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"        # base LLM on the Hugging Face Hub
LORA_DIR = "kormo_lora_checkpoints/peft_lora"  # LoRA adapter dir — relative to CWD; TODO confirm at deploy
DOC_PATH = "rule.json"                         # RAG rule documents; a JSON list whose items carry a "text" key

# Select GPU when available; the embedder and tokenized inputs follow this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
# ============================================================
# 2. Load rule documents, embed them, and build the FAISS index
# ============================================================
with open(DOC_PATH, "r", encoding="utf-8") as f:
    documents = json.load(f)

# Each document is expected to hold its body under the "text" key.
doc_texts = [d["text"] for d in documents]

# Korean sentence encoder; the same model embeds both documents and queries.
embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)
# FAISS requires float32 input.
doc_embs = embedding_model.encode(doc_texts, convert_to_numpy=True).astype("float32")

# Exact (brute-force) L2 index — appropriate for a small rule set.
index = faiss.IndexFlatL2(doc_embs.shape[1])
index.add(doc_embs)
38
 
39
def retrieve(query, k=3):
    """Return the k documents whose embeddings are nearest to *query* (L2)."""
    query_vec = embedding_model.encode([query], convert_to_numpy=True)
    query_vec = query_vec.astype("float32")
    distances, neighbor_ids = index.search(query_vec, k)
    return [documents[doc_id] for doc_id in neighbor_ids[0]]
43
+
44
 
45
# ============================================================
# 3. Load tokenizer, base model, and LoRA adapter
# ============================================================
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",          # let accelerate shard/place the weights
    torch_dtype=torch.float16,  # half precision to fit the 10B model
    trust_remote_code=True
)
# Wrap the base model with the fine-tuned LoRA weights.
# NOTE(review): no `pad_token` fallback is set here; generate() elsewhere in
# this file passes eos_token_id as pad_token_id explicitly — verify that
# covers every call site.
model = PeftModel.from_pretrained(
    model,
    LORA_DIR,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
62
 
 
 
 
 
 
 
 
 
 
63
 
64
  # ============================================================
65
+ # 4. ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
66
  # ============================================================
67
  def build_prompt(persona, instruction, query, retrieved_docs):
68
  context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
69
+
70
  return f"""
71
  ### ํŽ˜๋ฅด์†Œ๋‚˜:
72
  {persona}
 
83
  ### ๋‹ต๋ณ€:
84
  """
85
 
86
+
87
  # ============================================================
88
+ # 5. Streaming generator (Gradio ์šฉ)
89
  # ============================================================
90
def generate_stream(persona, instruction, query):
    """Stream a RAG-grounded answer for one persona.

    Retrieves the top documents for *query*, builds the judge prompt, and
    runs generation on a background thread so tokens can be yielded as they
    arrive.  Yields the *cumulative* text after each new token fragment
    (Gradio replaces the output component with every yielded value).
    """
    retrieved = retrieve(query)
    prompt = build_prompt(persona, instruction, query, retrieved)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt: don't echo the prompt back; skip_special_tokens: drop EOS etc.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def run():
        # Inference only — disable autograd so no graph is retained.
        # (Recent transformers wrap generate() in no_grad already; this is a
        # cheap safeguard for older versions.)
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,  # model may lack a pad token
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer,
            )

    thread = threading.Thread(target=run)
    thread.start()

    partial = ""
    for text in streamer:
        partial += text
        yield partial

    # The streamer is exhausted only when generation ends, but join anyway so
    # the worker thread is fully finished before this generator returns.
    thread.join()
118
+
119
 
120
  # ============================================================
121
+ # 6. ํŽ˜๋ฅด์†Œ๋‚˜ 6๊ฐœ ์ž๋™ ์‹คํ–‰ ํ•จ์ˆ˜
122
  # ============================================================
123
# Six judge personas as (persona description, display name) pairs.
# NOTE(review): the non-ASCII text below appears mojibake'd by an earlier
# encoding round-trip — strings are kept verbatim; verify against the
# original UTF-8 file before shipping.
persona_group = [
    ("๋‹น์‹ ์€ ์›์น™์„ ์ง€ํ‚ค๋˜ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์œ ์—ฐํ•˜๊ฒŒ ํŒ๋‹จํ•˜๋Š” ์‹œ๊ฐ์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค. ๊ฐœ์ธ์˜ ๋Šฅ๋ ฅ๊ณผ ๊ธฐ์—ฌ๋„๋ฅผ ์ค‘์š”ํ•˜๊ฒŒ ์ƒ๊ฐํ•˜๋ฉฐ...", "๋ฐ•์„ธ์—ฐ"),
    ("๋‹น์‹ ์€ ๊ณต์ •ํ•œ ๊ทœ์น™๊ณผ ์›์น™์„ ์ค‘์‹œํ•˜๋ฉด์„œ, ๊ฐœ์ธ์˜ ์„ฑ๊ณผ์™€ ๋Šฅ๋ ฅ์„ ์ธ์ •ํ•ด ์ฐจ๋“ฑ์„ ๋‘๊ณ  ๋ฐฐ๋ถ„ํ•ฉ๋‹ˆ๋‹ค...", "๊น€์ฐฝ์ค€"),
    ("๊ทœ์œจ๊ณผ ์ž์œจ์˜ ๊ท ํ˜•์„ ์ง€ํ‚ค๋ฉฐ, ๋Šฅ๋ ฅ๊ณผ ์„ฑ๊ณผ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ํŒ๋‹จํ•œ๋‹ค...", "์ด์ƒ๊ธฐ"),
    ("๊ทœ์œจ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•˜์ง€๋งŒ ์œ ์—ฐํ•˜๋ฉฐ, ๋ถ„๋ฐฐ๋Š” ์ค‘๋ฆฝ์ ์ด๊ณ  ๊ฐœ์„ ์„ ์ถ”๊ตฌํ•œ๋‹ค...", "์ฑ„ํ›ˆ"),
    ("์ž์œจ์„ ์กด์ค‘ํ•˜๋˜ ์ตœ์†Œํ•œ์˜ ๊ทœ์œจ์„ ์œ ์ง€ํ•˜๋ฉฐ, ๊ธฐ์—ฌ๋„์™€ ๊ฐœ์„ ์„ ๊ท ํ˜• ์žˆ๊ฒŒ ๋ฐ˜์˜ํ•œ๋‹ค...", "์šฉ์šฐ"),
    ("๊ทœ์œจ๊ณผ ๊ณต์ •์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์•ˆ์ •์ ์ธ ์šด์˜์„ ์ถ”๊ตฌํ•˜๋ฉฐ...", "ํ˜•์ง„")
]

# Shared instruction prepended to every persona's prompt (verbatim).
instruction_text = """
๋‹น์‹ ์€ ํ•ด๋‹น ํŽ˜๋ฅด์†Œ๋‚˜์˜ ์„ฑ๊ฒฉ์„ ๊ฐ€์ง„ ์‹ฌํŒ๊ด€์ž…๋‹ˆ๋‹ค.
๋ฐ˜๋“œ์‹œ 3๋ฌธ์žฅ๋งŒ ๋งํ•˜์‹ญ์‹œ์˜ค.
๊ฐ ๋ฌธ์žฅ์€ 30์ž ์ด๋‚ด๋กœ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.
๊ทœ์ •์— ์šฐ์„ ์œผ๋กœ ๊ทผ๊ฑฐํ•˜์—ฌ ๋‹ตํ•˜์‹œ์˜ค.
๋ฐ˜๋ณต ๊ธˆ์ง€, ํŒ๋‹จ ๊ทผ๊ฑฐ ํ•„์ˆ˜.
"""
139
+
140
def run_all_personas(query):
    """Stream answers from all six personas as one growing Markdown document.

    generate_stream() yields the *cumulative* text of the current answer,
    and Gradio replaces the Markdown component with every yielded value.
    The previous implementation yielded the header, the raw stream values,
    and the separator independently, so each yield erased the header and
    every previously finished persona from the screen.  A transcript
    accumulator fixes that: every yield contains the full history plus the
    in-progress answer.
    """
    transcript = ""
    for persona, name in persona_group:
        transcript += f"## ๐Ÿ‘ค {name}\n"
        yield transcript
        latest = ""
        for latest in generate_stream(persona, instruction_text, query):
            yield transcript + latest
        # Freeze the finished answer and add a separator before the next persona.
        transcript += latest + "\n\n---\n\n"
        yield transcript
147
+
148
+
149
# ============================================================
# 7. Gradio UI
# ============================================================
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ”ฅ KORMo 10B + LoRA Streaming Judge")

    user_input = gr.Textbox(label="์งˆ๋ฌธ ์ž…๋ ฅ", value="3๋ฒˆ ์ด์ƒ ๊ฒฐ์„ํ–ˆ์ง€๋งŒ ์‹ค๋ ฅ์€ ๋›ฐ์–ด๋‚œ ์ •ํšŒ์›์„ ์–ด๋–ป๊ฒŒ ํ•ด์•ผ ํ• ๊นŒ?")
    output = gr.Markdown()

    # Delegate with `yield from` so the handler is a *generator function*:
    # the original returned a bare generator object from a plain function,
    # which not every Gradio release iterates for streaming output.
    def start(query):
        yield from run_all_personas(query)

    run_btn = gr.Button("๐Ÿš€ ์‹คํ–‰ํ•˜๊ธฐ")
    run_btn.click(start, inputs=user_input, outputs=output)

demo.launch()