qewrufda committed on
Commit
caa1ded
·
verified ·
1 Parent(s): 90e0b65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -184
app.py CHANGED
@@ -1,184 +1,211 @@
1
- import os
2
- import json
3
- import torch
4
- from huggingface_hub import login
5
- from sentence_transformers import SentenceTransformer
6
- import faiss
7
- import numpy as np
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
- from peft import PeftModel
10
- import threading
11
-
12
- # ============================================================
13
- # 1. 환경 설정 + 로그인
14
- # ============================================================
15
- HF_TOKEN = os.getenv("HF_TOKEN") # ← secret variable에서 불러옴
16
- login(token=HF_TOKEN)
17
-
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
- print("Device:", device)
20
-
21
- # ============================================================
22
- # 2. 경로 설정
23
- # ============================================================
24
- BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
25
- LORA_DIR = "./peft_lora" # 서버 경로
26
- DOC_PATH = "./rule.json" # 문서 파일
27
-
28
- # ============================================================
29
- # 3. RAG 문서 로드 + FAISS 구축
30
- # ============================================================
31
- with open(DOC_PATH, "r", encoding="utf-8") as f:
32
- documents = json.load(f)
33
-
34
- doc_texts = [d["text"] for d in documents]
35
-
36
- embedding_model = SentenceTransformer(
37
- "jhgan/ko-sroberta-multitask",
38
- device=device
39
- )
40
-
41
- doc_embs = embedding_model.encode(
42
- doc_texts, convert_to_numpy=True
43
- ).astype("float32")
44
-
45
- dim = doc_embs.shape[1]
46
- index = faiss.IndexFlatL2(dim)
47
- index.add(doc_embs)
48
-
49
- print("FAISS index built:", index.ntotal)
50
-
51
- # ============================================================
52
- # 4. LLM + LoRA 로드
53
- # ============================================================
54
- model = AutoModelForCausalLM.from_pretrained(
55
- BASE_MODEL,
56
- torch_dtype=torch.float16,
57
- device_map="auto",
58
- trust_remote_code=True
59
- )
60
-
61
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
62
- if tokenizer.pad_token is None:
63
- tokenizer.pad_token = tokenizer.eos_token
64
-
65
- model = PeftModel.from_pretrained(
66
- model,
67
- LORA_DIR,
68
- torch_dtype=torch.float16,
69
- device_map="auto",
70
- )
71
-
72
- model = model.to(device)
73
- model.eval()
74
-
75
- # ============================================================
76
- # 5. RAG 검색 함수
77
- # ============================================================
78
- def retrieve(query, k=3):
79
- q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
80
- D, I = index.search(q_emb, k)
81
- return [documents[i] for i in I[0]]
82
-
83
- # ============================================================
84
- # 6. 프롬프트 생성
85
- # ============================================================
86
- def build_prompt(persona, instruction, query, retrieved_docs):
87
- context = "\n".join([f"- {d['text']}" for d in retrieved_docs])
88
- return f"""
89
- ### 페르소나:
90
- {persona}
91
-
92
- ### 참고사항:
93
- {instruction}
94
-
95
- ### 규정:
96
- {context}
97
-
98
- ### 질문:
99
- {query}
100
-
101
- ### 답변:
102
- """
103
-
104
- # ============================================================
105
- # 7. Streaming Chat
106
- # ============================================================
107
- def stream_chat(persona, instruction, user_query, max_new_tokens=256):
108
-
109
- retrieved = retrieve(user_query, k=3)
110
- prompt = build_prompt(persona, instruction, user_query, retrieved)
111
-
112
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
113
-
114
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
115
-
116
- END_TOKENS = [
117
- "End of Answer", "### 검토 결과:", "### 최종 답변",
118
- "※", ">", "**답변**", "---", "###", "**"
119
- ]
120
-
121
- def run_gen():
122
- with torch.no_grad():
123
- model.generate(
124
- **inputs,
125
- max_new_tokens=max_new_tokens,
126
- do_sample=True,
127
- top_p=0.9,
128
- temperature=0.7,
129
- repetition_penalty=1.2,
130
- streamer=streamer
131
- )
132
-
133
- thread = threading.Thread(target=run_gen)
134
- thread.start()
135
-
136
- full = ""
137
- for tok in streamer:
138
- print(tok, end="", flush=True)
139
- full += tok
140
- for e in END_TOKENS:
141
- if e in full:
142
- print()
143
- return
144
-
145
- print()
146
-
147
- # ============================================================
148
- # 8. 페르소나 목록
149
- # ============================================================
150
- persona_group = [
151
- ("당신은 원칙을 지키되 상황에 따라 유연하게 판단하는 시각을 가지고 있다...", "박세연"),
152
- ("당신은 공정한 규칙과 원칙을 중시하면서, 개인의 성과와 능력을 인정해 차등...", "김창준"),
153
- ("규율과 자율의 균형을 지키며, 능력과 성과를 기준으로 판단한다...", "이상기"),
154
- ("규율을 기반으로 하지만 유연하며, 분배는 중립적이고 개선을 추구한다...", "채훈"),
155
- ("자율을 존중하되 최소한의 규율을 유지하며, 기여도와 개선을 균형 있게 반영...", "용우"),
156
- ("규율과 공정을 기반으로 안정적인 운영을 추구하며, 균등·개선·친목 간의 균형...", "형진")
157
- ]
158
-
159
- # ============================================================
160
- # 9. 프로그램 실행 (입력 받는 부분)
161
- # ============================================================
162
- if __name__ == "__main__":
163
-
164
- query = input("질문을 입력하세요: ")
165
-
166
- instruction = """
167
- 당신은 해당 페르소나의 성격을 가진 심판관입니다.
168
- 반드시 3문장만 말하십시오.
169
- 문장은 30자 이내.
170
- 규정을 우선하여 답하세요.
171
- 판단 근거 포함.
172
- 반복 금지.
173
- """
174
-
175
- for persona_text, persona_name in persona_group:
176
- print("\n====================")
177
- print(f"### {persona_name} ###")
178
- print("====================")
179
-
180
- stream_chat(
181
- persona=persona_text,
182
- instruction=instruction,
183
- user_query=query
184
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import threading
5
+ import numpy as np
6
+ import faiss
7
+
8
+ from fastapi import FastAPI
9
+ from pydantic import BaseModel
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+
12
+ from huggingface_hub import login
13
+ from sentence_transformers import SentenceTransformer
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
15
+ from peft import PeftModel
16
+
17
+
18
# ============================================================
# FastAPI setup
# ============================================================
app = FastAPI()

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (a wildcard origin cannot be sent
# with credentialed responses) — confirm whether credentials are needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
30
+
31
# ============================================================
# 1. Environment setup + Hugging Face login
# ============================================================
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("환경 변수 HF_TOKEN이 설정되지 않음")

login(token=HF_TOKEN)

# Prefer the GPU when one is visible; everything downstream keys off this.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)
42
+
43
# ============================================================
# 2. Path configuration
# ============================================================
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"  # base LLM checkpoint on the HF Hub
LORA_DIR = "./peft_lora"                 # local LoRA adapter directory
DOC_PATH = "./rule.json"                 # RAG corpus: JSON list of {"text": ...} entries
49
+
50
+
51
# ============================================================
# 3. Load RAG documents + build the FAISS index
# ============================================================
with open(DOC_PATH, "r", encoding="utf-8") as fp:
    documents = json.load(fp)

# Korean sentence encoder used for both corpus and query embeddings.
embedding_model = SentenceTransformer(
    "jhgan/ko-sroberta-multitask",
    device=device,
)

corpus = [entry["text"] for entry in documents]
corpus_vectors = embedding_model.encode(corpus, convert_to_numpy=True).astype("float32")

# Exact L2 search over the document embeddings.
index = faiss.IndexFlatL2(corpus_vectors.shape[1])
index.add(corpus_vectors)

print("FAISS index built:", index.ntotal)
71
+
72
+
73
# ============================================================
# 4. Load the LLM + LoRA adapter
# ============================================================
# Base model: fp16 on GPU, fp32 on CPU. device_map="auto" lets accelerate
# place the weights, so no explicit .to(device) call is needed (or wanted).
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # Models shipped without a pad token reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

# Wrap the base model with the fine-tuned LoRA weights from LORA_DIR.
model = PeftModel.from_pretrained(
    model,
    LORA_DIR,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
)

model.eval()  # inference only: disables dropout etc.
95
+
96
+
97
+ # ============================================================
98
+ # 5. RAG 검색 함수
99
+ # ============================================================
100
+ def retrieve(query, k=3):
101
+ q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
102
+ D, I = index.search(q_emb, k)
103
+ return [documents[i] for i in I[0]]
104
+
105
+
106
# ============================================================
# 6. Prompt construction
# ============================================================
def build_prompt(persona, instruction, query, retrieved_docs):
    """Assemble the generation prompt: persona, instruction, retrieved rules, question."""
    context = "\n".join(f"- {doc['text']}" for doc in retrieved_docs)
    return f"""
### 페르소나:
{persona}

### 참고사항:
{instruction}

### 규정:
{context}

### 질문:
{query}

### 답변:
"""
126
+
127
+
128
# ============================================================
# 7. Streaming chat (drained synchronously, full text returned)
# ============================================================
def run_chat(persona, instruction, user_query, max_new_tokens=256):
    """Run RAG retrieval + generation for one persona, returning the answer text.

    Generation runs on a worker thread that feeds a TextIteratorStreamer;
    this function drains the streamer on the calling thread and returns the
    concatenated, stripped output.

    Args:
        persona: persona description injected into the prompt.
        instruction: system-style instruction text.
        user_query: the user's question.
        max_new_tokens: generation budget.

    Returns:
        The generated answer (str), whitespace-stripped.
    """
    retrieved = retrieve(user_query, k=3)
    prompt = build_prompt(persona, instruction, user_query, retrieved)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        # Inference only: no_grad avoids building autograd graphs.
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                repetition_penalty=1.2,
                streamer=streamer
            )

    thread = threading.Thread(target=generate)
    thread.start()

    full_result = ""
    for token in streamer:
        full_result += token

    # The streamer is exhausted only once generate() finishes, but join
    # explicitly so the worker thread never outlives this call.
    thread.join()

    return full_result.strip()
159
+
160
+
161
# ============================================================
# 8. Persona roster
# ============================================================
# Judge name -> persona description; the /ask route iterates all of them.
_PERSONA_BY_NAME = {
    "박세연": "당신은 원칙을 지키되 상황에 따라 유연하게 판단하는 시각을 가지고 있다...",
    "김창준": "당신은 공정한 규칙과 원칙을 중시하면서, 개인의 성과와 능력을 인정해 차등...",
    "이상기": "규율과 자율의 균형을 지키며, 능력과 성과를 기준으로 판단한다...",
    "채훈": "규율을 기반으로 하지만 유연하며, 분배는 중립적이고 개선을 추구한다...",
    "용우": "자율을 존중하되 최소한의 규율을 유지하며, 기여도와 개선을 균형 있게 반영...",
    "형진": "규율과 공정을 기반으로 안정적인 운영을 추구하며, 균등·개선·친목 간의 균형...",
}
# Kept as (persona_text, persona_name) tuples for existing consumers.
persona_group = [(text, name) for name, text in _PERSONA_BY_NAME.items()]
172
+
173
+
174
# ============================================================
# 9. API request model
# ============================================================
class UserQuery(BaseModel):
    """Request body for POST /ask."""
    query: str  # the user's question
179
+
180
+
181
# ============================================================
# 10. Main API route
# ============================================================
@app.post("/ask")
def ask_api(payload: UserQuery):
    """Answer *payload.query* once per persona and return all answers.

    Deliberately a plain (sync) ``def``: ``run_chat`` blocks on
    ``model.generate()``, and an ``async def`` handler would freeze the
    whole event loop for the duration of every generation. FastAPI runs
    sync handlers in its threadpool, keeping the server responsive.
    HTTP request/response shape is unchanged.
    """
    user_query = payload.query

    instruction = """
당신은 해당 페르소나의 성격을 가진 심판관입니다.
반드시 3문장만 말하십시오.
각 문장은 30자 이내.
규정을 우선하여 답하세요.
판단 근거 포함.
반복 금지.
"""

    results = {}
    for persona_text, persona_name in persona_group:
        results[persona_name] = run_chat(persona_text, instruction, user_query)

    return {"query": user_query, "answers": results}
204
+
205
+
206
# ============================================================
# Root Index
# ============================================================
@app.get("/")
async def root():
    """Health-check endpoint confirming the service is up."""
    return {"status": "running", "message": "KORMo + LoRA + RAG Persona API"}