qewrufda committed (verified)
Commit 90e0b65 · 1 Parent(s): f205dde

Upload app.py

Files changed (1):
  app.py  +76 -76
app.py CHANGED
@@ -1,40 +1,32 @@
 import os
 import json
-import threading
 import torch
-import gradio as gr
 from huggingface_hub import login
 from sentence_transformers import SentenceTransformer
 import faiss
-
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer
-)
+import numpy as np
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from peft import PeftModel
-
+import threading
 
 # ============================================================
-# 0. Environment setup
+# 1. Environment setup + login
 # ============================================================
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-if HF_TOKEN:
-    login(token=HF_TOKEN)
+HF_TOKEN = os.getenv("HF_TOKEN")  # ← loaded from the secret variable
+login(token=HF_TOKEN)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
+print("Device:", device)
 
 # ============================================================
-# 1. Path setup
+# 2. Path setup
 # ============================================================
 BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
-LORA_DIR = "peft_lora"
-DOC_PATH = "rule.json"
-
+LORA_DIR = "./peft_lora"   # server path
+DOC_PATH = "./rule.json"   # document file
 
 # ============================================================
-# 2. Load RAG documents
+# 3. Load RAG documents + build FAISS index
 # ============================================================
 with open(DOC_PATH, "r", encoding="utf-8") as f:
     documents = json.load(f)
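A note on the login change: the old guard only called login() when HF_TOKEN was present, while the new code always calls login(token=HF_TOKEN). If the secret is missing, huggingface_hub receives token=None and typically falls back to an interactive prompt, which fails in a headless Space. A minimal guarded variant (a sketch, not part of this commit):

```python
import os
from huggingface_hub import login

# Sketch: authenticate only when the HF_TOKEN secret is actually set.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
else:
    print("HF_TOKEN not set; continuing unauthenticated (public repos only)")
```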
@@ -46,23 +38,19 @@ embedding_model = SentenceTransformer(
     device=device
 )
 
-import numpy as np
 doc_embs = embedding_model.encode(
-    doc_texts,
-    convert_to_numpy=True,
-    show_progress_bar=True
+    doc_texts, convert_to_numpy=True
 ).astype("float32")
 
 dim = doc_embs.shape[1]
 index = faiss.IndexFlatL2(dim)
 index.add(doc_embs)
 
+print("FAISS index built:", index.ntotal)
 
 # ============================================================
-# 3. Load LLM + LoRA
+# 4. Load LLM + LoRA
 # ============================================================
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
-
 model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
     torch_dtype=torch.float16,
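The retrieval index is an exact IndexFlatL2 over raw sentence-transformer embeddings; the new code also drops show_progress_bar and moves the stray mid-file numpy import to the top. If cosine similarity is ever preferred over L2 distance, the usual FAISS recipe is inner product over L2-normalized vectors. The sketch below is illustrative and uses placeholder data, not the commit's rule.json embeddings:

```python
import faiss
import numpy as np

# Placeholder for the (n_docs, dim) float32 matrix that app.py builds.
doc_embs = np.random.rand(8, 384).astype("float32")

normed = doc_embs.copy()
faiss.normalize_L2(normed)                  # in-place row-wise normalization
index_ip = faiss.IndexFlatIP(normed.shape[1])
index_ip.add(normed)

q = np.random.rand(1, 384).astype("float32")
faiss.normalize_L2(q)
scores, ids = index_ip.search(q, 3)         # inner product == cosine after normalization
```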
@@ -70,27 +58,30 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True
 )
 
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
 model = PeftModel.from_pretrained(
     model,
     LORA_DIR,
     torch_dtype=torch.float16,
-    device_map="auto"
+    device_map="auto",
 )
 
+model = model.to(device)
 model.eval()
 
-
 # ============================================================
-# 4. RAG retrieval
+# 5. RAG retrieval function
 # ============================================================
 def retrieve(query, k=3):
     q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
     D, I = index.search(q_emb, k)
     return [documents[i] for i in I[0]]
 
-
 # ============================================================
-# 5. Prompt construction
+# 6. Prompt construction
 # ============================================================
 def build_prompt(persona, instruction, query, retrieved_docs):
     context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
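The reworked load order builds the base model first, then the tokenizer with a standard pad-token fallback, then wraps the base in the LoRA adapter. One caution: model = model.to(device) after loading with device_map="auto" is usually unnecessary, since accelerate has already placed the weights, and it can raise on dispatched models. Since the adapter is inference-only here (model.eval(), no training), the LoRA weights could also be merged into the base model to avoid the per-forward adapter overhead; a sketch using PEFT's merge_and_unload (illustrative, not what the commit does):

```python
# Sketch: fold the LoRA deltas into the base weights for faster inference.
# `model` refers to the PeftModel constructed in app.py.
merged = model.merge_and_unload()   # returns the base model with W + (B @ A) baked in
merged.eval()
```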
@@ -110,21 +101,25 @@ def build_prompt(persona, instruction, query, retrieved_docs):
 ### 답변:
 """
 
-
 # ============================================================
-# 6. Streaming LLM (print only up to "End of Answer")
+# 7. Streaming Chat
 # ============================================================
-def stream_generate(prompt, max_new_tokens=256):
-    streamer = TextIteratorStreamer(
-        tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True
-    )
+def stream_chat(persona, instruction, user_query, max_new_tokens=256):
+
+    retrieved = retrieve(user_query, k=3)
+    prompt = build_prompt(persona, instruction, user_query, retrieved)
 
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
-    thread = threading.Thread(
-        target=lambda:
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    END_TOKENS = [
+        "End of Answer", "### 검토 결과:", "### 최종 답변",
+        "※", ">", "**답변**", "---", "###", "**"
+    ]
+
+    def run_gen():
+        with torch.no_grad():
             model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
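The Korean strings here are functional: the prompt template's "### 답변:" header means "### Answer:", and the END_TOKENS markers "### 검토 결과:" and "### 최종 답변" translate roughly as "### Review result:" and "### Final answer"; ※, "**답변**" ("**Answer**"), and the bare markdown fragments are boilerplate the model tends to emit after its answer. Matching them in the consumer loop only truncates what gets printed; generate() keeps running on its worker thread. To stop generation itself at a marker, one option is a custom StoppingCriteria (a sketch reusing app.py's names, not part of the commit):

```python
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnMarkers(StoppingCriteria):
    """Halt generate() once any marker string appears in the generated text."""

    def __init__(self, tokenizer, markers, prompt_len):
        self.tokenizer = tokenizer
        self.markers = markers
        self.prompt_len = prompt_len          # prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        new_text = self.tokenizer.decode(
            input_ids[0, self.prompt_len:], skip_special_tokens=True
        )
        return any(m in new_text for m in self.markers)

# Usage sketch inside stream_chat:
#   criteria = StoppingCriteriaList(
#       [StopOnMarkers(tokenizer, END_TOKENS, inputs["input_ids"].shape[1])]
#   )
#   model.generate(**inputs, stopping_criteria=criteria, ...)
```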
@@ -134,51 +129,56 @@ def stream_generate(prompt, max_new_tokens=256):
                 repetition_penalty=1.2,
                 streamer=streamer
             )
-    )
-    thread.start()
 
-    partial_text = ""
-
-    for token in streamer:
-        partial_text += token
-
-        # cut the stream at "End of Answer"
-        if "End of Answer" in partial_text:
-            partial_text = partial_text.split("End of Answer")[0]
-            yield partial_text.strip()
-            return
+    thread = threading.Thread(target=run_gen)
+    thread.start()
 
-        yield partial_text
+    full = ""
+    for tok in streamer:
+        print(tok, end="", flush=True)
+        full += tok
+        for e in END_TOKENS:
+            if e in full:
+                print()
+                return
 
+    print()
 
 # ============================================================
-# 7. Gradio interface function
+# 8. Persona list
 # ============================================================
-def gradio_reply(persona, instruction, query):
-    retrieved = retrieve(query, k=3)
-    prompt = build_prompt(persona, instruction, query, retrieved)
-    return stream_generate(prompt)
-
+persona_group = [
+    ("당신은 원칙을 지키되 상황에 따라 유연하게 판단하는 시각을 가지고 있다...", "박세연"),
+    ("당신은 공정한 규칙과 원칙을 중시하면서, 개인의 성과와 능력을 인정해 차등...", "김창준"),
+    ("규율과 자율의 균형을 지키며, 능력과 성과를 기준으로 판단한다...", "이상기"),
+    ("규율을 기반으로 하지만 유연하며, 분배는 중립적이고 개선을 추구한다...", "채훈"),
+    ("자율을 존중하되 최소한의 규율을 유지하며, 기여도와 개선을 균형 있게 반영...", "용우"),
+    ("규율과 공정을 기반으로 안정적인 운영을 추구하며, 균등·개선·친목 간의 균형...", "형진")
+]
 
 # ============================================================
-# 8. Gradio UI
+# 9. Program entry (reads the question from stdin)
 # ============================================================
-with gr.Blocks() as demo:
-
-    gr.Markdown("KORMo-10B + LoRA + RAG Streaming Demo (End-of-Answer Truncated)")
-
-    persona = gr.Textbox(label="페르소나")
-    instruction = gr.Textbox(label="규칙/지침")
-    query = gr.Textbox(label="질문")
+if __name__ == "__main__":
 
-    output = gr.Textbox(label="응답", lines=12)
+    query = input("질문을 입력하세요: ")
 
-    btn = gr.Button("Generate")
+    instruction = """
+    당신은 해당 페르소나의 성격을 가진 심판관입니다.
+    반드시 3문장만 말하십시오.
+    각 문장은 30자 이내.
+    규정을 우선하여 답하세요.
+    판단 근거 포함.
+    반복 금지.
+    """
 
-    btn.click(
-        fn=gradio_reply,
-        inputs=[persona, instruction, query],
-        outputs=output
-    )
+    for persona_text, persona_name in persona_group:
+        print("\n====================")
+        print(f"### {persona_name} ###")
+        print("====================")
 
-    demo.launch()
+        stream_chat(
+            persona=persona_text,
+            instruction=instruction,
+            user_query=query
+        )
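For reference, the remaining Korean strings are functional prompt text for the Korean-language model: the input() call asks "질문을 입력하세요:" ("Enter a question:"), and the instruction block tells the judge persona to answer in exactly three sentences of at most 30 characters each, to prioritize the rules, to include its reasoning, and not to repeat itself. The persona_group tuples pair a judging-style description (truncated with a literal "..." in the commit) with a judge's name (박세연, 김창준, 이상기, 채훈, 용우, 형진). One loose end in stream_chat: returning from the consumer loop on a marker hit abandons the generation thread, which keeps producing tokens until it reaches max_new_tokens. A hedged cleanup sketch that drains the streamer and joins the thread before moving on to the next persona:

```python
import threading

def consume_stream(streamer, thread: threading.Thread, end_tokens):
    """Print tokens until an end marker appears, then let generate() finish."""
    full = ""
    for tok in streamer:
        print(tok, end="", flush=True)
        full += tok
        if any(e in full for e in end_tokens):
            break                # stop printing; keep draining below
    for _ in streamer:           # discard the rest so the worker can exit
        pass
    thread.join()
    print()
    return full
```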