qewrufda committed on
Commit
cabbb68
·
verified ·
1 Parent(s): 84c6896

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -70
app.py CHANGED
@@ -1,85 +1,188 @@
 
1
  import torch
 
 
 
 
 
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
- # ======================
6
- # 1. λͺ¨λΈ λ‘œλ”©
7
- # ======================
8
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2" # μ›ν•˜λŠ” λͺ¨λΈλ‘œ λ³€κ²½ κ°€λŠ₯
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
- MODEL_NAME,
16
- torch_dtype=torch.float16, # μ–‘μžν™” μ•„λ‹˜
17
- device_map="auto"
 
18
  )
19
 
20
- # ======================
21
- # 2. 슀트리밍 ν•¨μˆ˜
22
- # ======================
23
- def generate_stream(prompt):
24
- """
25
- 이 ν•¨μˆ˜λŠ” Gradioμ—μ„œ 슀트리밍이 λ˜λ„λ‘
26
- yield 둜 토큰 λ‹¨μœ„ 좜λ ₯ν•˜λŠ” μ œλ„ˆλ ˆμ΄ν„°μž…λ‹ˆλ‹€.
27
- """
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
30
 
31
- # 슀트리밍 생성기 호좜
32
- streamer = tokenizer.decode
33
-
34
- # λͺ¨λΈμ˜ generate()λŠ” streaming 지원 X β†’ μ•„λž˜λŠ” μˆ˜λ™ 슀트리밍 κ΅¬ν˜„
35
- # (HuggingFace TextStreamer λŒ€μ‹  토큰 λ‹¨μœ„ μˆ˜λ™ 처리)
36
-
37
- generated = inputs["input_ids"]
38
- past_key_values = None
39
-
40
- for _ in range(512): # μ΅œλŒ€ 생성 토큰 수
41
- outputs = model(
42
- input_ids=generated if past_key_values is None else generated[:, -1:],
43
- past_key_values=past_key_values,
44
- use_cache=True
45
- )
46
- logits = outputs.logits[:, -1, :]
47
- past_key_values = outputs.past_key_values
48
-
49
- next_token = torch.argmax(logits, dim=-1, keepdim=True)
50
- generated = torch.cat([generated, next_token], dim=-1)
51
-
52
- # λ””μ½”λ“œ ν›„ μ‚¬μš©μžμ—κ²Œ 슀트리밍 전달
53
- text = tokenizer.decode(generated[0], skip_special_tokens=True)
54
- yield text
55
-
56
- # μ’…λ£Œ 토큰 발견 μ‹œ stop
57
- if next_token.item() in tokenizer.eos_token_id if isinstance(tokenizer.eos_token_id, list) else [tokenizer.eos_token_id]:
58
- break
59
-
60
-
61
- # ======================
62
- # 3. Gradio UI
63
- # ======================
64
- with gr.Blocks() as demo:
65
- gr.Markdown("# πŸš€ Custom LLM Streaming Demo (No Quantization)")
66
-
67
- with gr.Row():
68
- input_box = gr.Textbox(label="Prompt μž…λ ₯", lines=4)
69
-
70
- output_box = gr.Textbox(label="응닡 (Streaming)")
71
-
72
- generate_button = gr.Button("생성")
73
-
74
- # λ²„νŠΌ ν΄λ¦­μ‹œ 슀트리밍 μ—°κ²°
75
- generate_button.click(
76
- fn=generate_stream,
77
- inputs=input_box,
78
- outputs=output_box
 
 
 
 
 
 
 
 
 
 
79
  )
80
 
81
- # ======================
82
- # 4. μ‹€ν–‰
83
- # ======================
84
- if __name__ == "__main__":
85
- demo.launch()
 
1
+ import os
2
  import torch
3
+ import json
4
+ import threading
5
+ import time
6
+ import faiss
7
+ import numpy as np
8
  import gradio as gr
 
9
 
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import (
12
+ AutoModelForCausalLM,
13
+ AutoTokenizer,
14
+ TextIteratorStreamer,
15
+ )
16
+ from peft import PeftModel
17
+ from huggingface_hub import login
18
+
19
# ============================================================
# 1. Login (Space environment variable; no Colab-only code)
# ============================================================
# The token is optional: a public Space without HF_TOKEN still runs,
# it just cannot access gated/private repos.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Resolve the compute device once and reuse it everywhere below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
28
+
29
# ============================================================
# 2. Model and asset locations
# ============================================================
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
LORA_DIR = "./kormo_lora"  # LoRA adapter folder uploaded to the Space
DOC_PATH = "./rule.json"   # rule document uploaded to the Space

# Echo the configuration at startup (same output as individual prints).
print("Paths:")
for _label, _value in (("Model", BASE_MODEL), ("LoRA", LORA_DIR), ("Documents", DOC_PATH)):
    print(f"{_label}: {_value}")
40
+
41
# ============================================================
# 3. RAG corpus: load documents, embed them, build a FAISS index
# ============================================================
with open(DOC_PATH, "r", encoding="utf-8") as f:
    documents = json.load(f)

# assumes each entry is a dict carrying its content under "text" — TODO confirm schema
doc_texts = [d["text"] for d in documents]

embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)

doc_embs = embedding_model.encode(
    doc_texts,
    convert_to_numpy=True,
    show_progress_bar=True,
).astype("float32")

# Exact (brute-force) L2 search; fine for a small rule corpus.
index = faiss.IndexFlatL2(doc_embs.shape[1])
index.add(doc_embs)

print("FAISS index built. Total docs =", index.ntotal)
62
+
63
# ============================================================
# 4. Base LLM + LoRA adapter
# ============================================================
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

# generate() needs a pad token; fall back to EOS when the checkpoint ships none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading LoRA...")
model = PeftModel.from_pretrained(
    model,
    LORA_DIR,
    torch_dtype=torch.float16,
    device_map="auto",
)
# NOTE(review): calling .to(device) on a model already dispatched with
# device_map="auto" is redundant and can conflict with accelerate's
# placement — confirm this is actually needed.
model = model.to(device)
model.eval()

print("Model + LoRA loaded successfully.")
89
+
90
# ============================================================
# 5. RAG retrieval
# ============================================================
def retrieve(query, k=3):
    """Return the k documents whose embeddings are closest (L2) to the query."""
    query_vec = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
    _dists, hit_ids = index.search(query_vec, k)
    # hit_ids has shape (1, k): one row per query, column per neighbour.
    return [documents[doc_id] for doc_id in hit_ids[0]]
97
+
98
# ============================================================
# 6. Prompt assembly
# ============================================================
def build_prompt(persona, instruction, query, retrieved_docs):
    """Assemble the judge prompt: persona, guidance, retrieved rules, question.

    Returns the exact sectioned template the model was tuned on, ending with
    an open answer header for the model to complete.
    """
    rule_lines = "\n".join(f"- {doc['text']}" for doc in retrieved_docs)
    sections = (
        "",
        "### 페λ₯΄μ†Œλ‚˜:",
        persona,
        "",
        "### 참고사항:",
        instruction,
        "",
        "### κ·œμ •:",
        rule_lines,
        "",
        "### 질문:",
        query,
        "",
        "### λ‹΅λ³€:",
        "",
    )
    return "\n".join(sections)
118
+
119
# ============================================================
# 7. Streaming chat
# ============================================================
def stream_chat(persona, instruction, user_query, k=3, max_new_tokens=256):
    """Retrieve rules for the query, build the prompt, and stream the answer.

    Args:
        persona: judge persona text inserted into the prompt.
        instruction: fixed guidance text inserted into the prompt.
        user_query: the question to answer.
        k: number of RAG documents to retrieve.
        max_new_tokens: generation cap.

    Yields:
        The ACCUMULATED answer text after each new chunk. A Gradio Textbox
        output replaces its content on every yield, so yielding per-token
        deltas (as before) would display only the last token.
    """
    retrieved = retrieve(user_query, k=k)
    prompt = build_prompt(persona, instruction, user_query, retrieved)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Markers that signal the answer is complete; generation output is
    # truncated at the first occurrence (previously defined but unused).
    END_TOKENS = ["End of Answer", "###", "---", "β€»"]

    def run_generation():
        # Runs in a worker thread; the streamer feeds the main thread.
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer,
                use_cache=True,
            )

    thread = threading.Thread(target=run_generation)
    thread.start()

    full = ""
    try:
        for chunk in streamer:
            full += chunk
            # Stop at the earliest end-of-answer marker, if any appeared.
            cut = min((full.find(m) for m in END_TOKENS if m in full), default=-1)
            if cut != -1:
                full = full[:cut]
                yield full
                break
            yield full
    finally:
        # Ensure the generation thread is not left dangling.
        thread.join()
153
+
154
# ============================================================
# 8. Gradio handler
# ============================================================
def gradio_answer(persona, query):
    """Gradio callback: pair the user's persona/question with the fixed judging
    instruction and stream the model's answer.

    Implemented with `yield from` so this is itself a generator function —
    Gradio only streams when `inspect.isgeneratorfunction(fn)` is true;
    the previous `return stream_chat(...)` handed Gradio a raw generator
    object, which it would render as a single value instead of streaming.
    """
    instruction = ("""
당신은 ν•΄λ‹Ή 페λ₯΄μ†Œλ‚˜μ˜ 성격을 κ°€μ§„ μ‹¬νŒκ΄€μž…λ‹ˆλ‹€.
λ°˜λ“œμ‹œ 3λ¬Έμž₯만 λ§ν•˜μ‹­μ‹œμ˜€.
각 λ¬Έμž₯은 30자 μ΄λ‚΄λ‘œ μ œν•œν•©λ‹ˆλ‹€.
κ·œμ •μ— μš°μ„ μœΌλ‘œ κ·Όκ±°ν•˜μ—¬ λ‹΅ν•˜μ‹œμ˜€.
νŒλ‹¨ κ·Όκ±°κ°€ ν¬ν•¨λœ λ‹΅μ•ˆλ§Œ μƒμ„±ν•˜μ‹œμ˜€.
같은 말을 λ°˜λ³΅ν•˜λŠ” 것을 μ ˆλŒ€ κΈˆν•¨.
μ˜λ¬Έμ— ν™•μ‹€ν•˜κ²Œ μž…μž₯을 밝힐 것.
λ°˜λ“œμ‹œ 3λ¬Έμž₯만 λ§ν•˜μ‹­μ‹œμ˜€.
각 λ¬Έμž₯은 30자 μ΄λ‚΄λ‘œ μ œν•œν•©λ‹ˆλ‹€.
""")

    yield from stream_chat(persona, instruction, query)
172
+
173
# ============================================================
# 9. Gradio UI
# ============================================================
with gr.Blocks() as app:
    gr.Markdown("## πŸ”₯ KORMo LoRA + RAG Streaming Judge")

    # Inputs: who the judge is, and what is being asked.
    persona = gr.Textbox(label="페λ₯΄μ†Œλ‚˜ μž…λ ₯")
    query = gr.Textbox(label="질문 μž…λ ₯")

    # Output textbox receives the streamed answer; button triggers generation.
    output = gr.Textbox(label="응닡", lines=8)
    btn = gr.Button("생성")

    btn.click(gradio_answer, inputs=[persona, query], outputs=output)

app.launch()