Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# app.py
|
| 2 |
-
import os
|
| 3 |
import json
|
| 4 |
import threading
|
| 5 |
import gradio as gr
|
|
@@ -8,15 +7,11 @@ import faiss
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from sentence_transformers import SentenceTransformer
|
| 11 |
-
from transformers import
|
| 12 |
-
AutoTokenizer,
|
| 13 |
-
AutoModelForCausalLM,
|
| 14 |
-
TextIteratorStreamer,
|
| 15 |
-
)
|
| 16 |
from peft import PeftModel
|
| 17 |
|
| 18 |
# -----------------------------
|
| 19 |
-
# 0. 환경
|
| 20 |
# -----------------------------
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
print("Device:", device)
|
|
@@ -24,9 +19,9 @@ print("Device:", device)
|
|
| 24 |
# -----------------------------
|
| 25 |
# 1. 모델 / 경로 설정
|
| 26 |
# -----------------------------
|
| 27 |
-
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
|
| 28 |
-
LORA_DIR = "peft_lora"
|
| 29 |
-
DOC_PATH = "rule.json"
|
| 30 |
|
| 31 |
# -----------------------------
|
| 32 |
# 2. RAG 문서 로드 + FAISS 준비
|
|
@@ -50,7 +45,7 @@ def retrieve(query, k=3):
|
|
| 50 |
print("FAISS ready, docs:", index.ntotal)
|
| 51 |
|
| 52 |
# -----------------------------
|
| 53 |
-
# 3.
|
| 54 |
# -----------------------------
|
| 55 |
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
|
| 56 |
if tokenizer.pad_token is None:
|
|
@@ -60,14 +55,14 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 60 |
BASE_MODEL,
|
| 61 |
device_map="auto",
|
| 62 |
torch_dtype=torch.float16,
|
| 63 |
-
trust_remote_code=True
|
| 64 |
)
|
| 65 |
|
| 66 |
model = PeftModel.from_pretrained(
|
| 67 |
model,
|
| 68 |
LORA_DIR,
|
| 69 |
device_map="auto",
|
| 70 |
-
torch_dtype=torch.float16
|
| 71 |
)
|
| 72 |
model.eval()
|
| 73 |
print("Model + LoRA loaded")
|
|
@@ -94,7 +89,7 @@ def build_prompt(persona, instruction, query, retrieved_docs):
|
|
| 94 |
"""
|
| 95 |
|
| 96 |
# -----------------------------
|
| 97 |
-
# 5. 스트리밍
|
| 98 |
# -----------------------------
|
| 99 |
def generate_stream(persona, instruction, query, max_new_tokens=256):
|
| 100 |
retrieved = retrieve(query, k=3)
|
|
@@ -127,7 +122,7 @@ def generate_stream(persona, instruction, query, max_new_tokens=256):
|
|
| 127 |
yield accumulated
|
| 128 |
|
| 129 |
# -----------------------------
|
| 130 |
-
# 6. 동기 생성
|
| 131 |
# -----------------------------
|
| 132 |
def generate_once(persona, instruction, query, max_new_tokens=256):
|
| 133 |
retrieved = retrieve(query, k=3)
|
|
@@ -153,24 +148,24 @@ def generate_once(persona, instruction, query, max_new_tokens=256):
|
|
| 153 |
# 7. 페르소나 그룹
|
| 154 |
# -----------------------------
|
| 155 |
persona_group = [
|
| 156 |
-
("원칙을 지키되 유연하게
|
| 157 |
-
("공정한 규칙과
|
| 158 |
-
("규율과
|
| 159 |
-
("
|
| 160 |
-
("
|
| 161 |
-
("규율과
|
| 162 |
]
|
| 163 |
|
| 164 |
instruction_text = """
|
| 165 |
-
|
| 166 |
반드시 3문장만 말하십시오.
|
| 167 |
각 문장은 30자 이내로 제한합니다.
|
| 168 |
-
규정에
|
| 169 |
반복 금지, 판단 근거 필수.
|
| 170 |
"""
|
| 171 |
|
| 172 |
# -----------------------------
|
| 173 |
-
# 8. UI
|
| 174 |
# -----------------------------
|
| 175 |
def run_all_streaming(query):
|
| 176 |
for persona, name in persona_group:
|
|
@@ -180,7 +175,7 @@ def run_all_streaming(query):
|
|
| 180 |
yield "\n\n---\n\n"
|
| 181 |
|
| 182 |
# -----------------------------
|
| 183 |
-
# 9. API
|
| 184 |
# -----------------------------
|
| 185 |
def run_all_api(query):
|
| 186 |
out = ""
|
|
@@ -206,10 +201,12 @@ with gr.Blocks() as demo:
|
|
| 206 |
run_btn = gr.Button("🚀 실행(Streaming UI)")
|
| 207 |
run_btn.click(fn=run_all_streaming, inputs=[user_input], outputs=[output_stream])
|
| 208 |
|
| 209 |
-
# API
|
| 210 |
api_output = gr.Textbox(label="API 반환 결과", lines=15)
|
| 211 |
-
run_btn_api = gr.Button("🔁 실행(API
|
| 212 |
run_btn_api.click(fn=run_all_api, inputs=[user_input], outputs=[api_output], api_name="start_api")
|
| 213 |
|
| 214 |
-
#
|
|
|
|
|
|
|
| 215 |
demo.launch()
|
|
|
|
| 1 |
# app.py
|
|
|
|
| 2 |
import json
|
| 3 |
import threading
|
| 4 |
import gradio as gr
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from peft import PeftModel
|
| 12 |
|
| 13 |
# -----------------------------
|
| 14 |
+
# 0. 환경 설정
|
| 15 |
# -----------------------------
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
print("Device:", device)
|
|
|
|
| 19 |
# -----------------------------
|
| 20 |
# 1. 모델 / 경로 설정
|
| 21 |
# -----------------------------
|
| 22 |
+
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
|
| 23 |
+
LORA_DIR = "peft_lora"
|
| 24 |
+
DOC_PATH = "rule.json"
|
| 25 |
|
| 26 |
# -----------------------------
|
| 27 |
# 2. RAG 문서 로드 + FAISS 준비
|
|
|
|
| 45 |
print("FAISS ready, docs:", index.ntotal)
|
| 46 |
|
| 47 |
# -----------------------------
|
| 48 |
+
# 3. 토크나이저 + 모델 로드 (LoRA 포함)
|
| 49 |
# -----------------------------
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
|
| 51 |
if tokenizer.pad_token is None:
|
|
|
|
| 55 |
BASE_MODEL,
|
| 56 |
device_map="auto",
|
| 57 |
torch_dtype=torch.float16,
|
| 58 |
+
trust_remote_code=True
|
| 59 |
)
|
| 60 |
|
| 61 |
model = PeftModel.from_pretrained(
|
| 62 |
model,
|
| 63 |
LORA_DIR,
|
| 64 |
device_map="auto",
|
| 65 |
+
torch_dtype=torch.float16
|
| 66 |
)
|
| 67 |
model.eval()
|
| 68 |
print("Model + LoRA loaded")
|
|
|
|
| 89 |
"""
|
| 90 |
|
| 91 |
# -----------------------------
|
| 92 |
+
# 5. 스트리밍 UI용
|
| 93 |
# -----------------------------
|
| 94 |
def generate_stream(persona, instruction, query, max_new_tokens=256):
|
| 95 |
retrieved = retrieve(query, k=3)
|
|
|
|
| 122 |
yield accumulated
|
| 123 |
|
| 124 |
# -----------------------------
|
| 125 |
+
# 6. API용 동기 생성
|
| 126 |
# -----------------------------
|
| 127 |
def generate_once(persona, instruction, query, max_new_tokens=256):
|
| 128 |
retrieved = retrieve(query, k=3)
|
|
|
|
| 148 |
# 7. 페르소나 그룹
|
| 149 |
# -----------------------------
|
| 150 |
persona_group = [
|
| 151 |
+
("당신은 원칙을 지키되 상황에 따라 유연하게 판단하는 시각을 가지고 있다. 개인의 능력과 기여도를 중요하게 생각하여 성과에 따른 차등 대우를 정당하다고 판단하며, 변화와 혁신을 최우선으로 여겨 관습보다 개선을 선택한다. 또한 내부에 머무르기보다 외부와의 연계와 협업을 적극적으로 추구하며, 학술 활동과 친목 활동의 균형을 통해 건강한 공동체 문화를 지향한다. 대외적으로 보여줄 수 있는 확실한 성과와 완성도를 중시하면서도, 단기적 해결과 장기적 기반 마련 사이에서 균형을 유지하려 노력한다.", '박세연'),
|
| 152 |
+
("당신은 공정한 규칙과 원칙을 중시하면서, 개인의 성과와 능력을 인정해 차등을 두고 배분합니다. 전통을 존중하되 점진적인 개선을 수용하며, 내부 활동에 머무르지 않고 외부와의 협업과 네트워크를 적극적으로 추구합니다. 회원 간 유대와 즐거움을 중요시하고, 완성도 높은 결과물과 과정에서의 배움을 모두 중시하며, 당장의 문제 해결과 장기적 기반 구축을 동시에 고려합니다.",'김창준'),
|
| 153 |
+
("규율과 자율의 균형을 지키며, 능력과 성과를 기준으로 판단한다. 전통을 유지하되 점진적 개선을 추구하고, 외부와의 협업을 적극적으로 모색한다. 즐거운 분위기 속에서 학습하며 개인의 성장을 중시하고, 단기 성과보다 동아리의 장기적 기반을 우선한다.", '이상기'),
|
| 154 |
+
("규율을 기반으로 하지만 유연하며, 분배는 중립적이고 개선을 추구한다. 외부 연계를 적당히 활용하며 학술·친목 모두 상황에 따라 선택하고, 가시성과 장기 기반을 조화롭게 고려한다.", '채훈'),
|
| 155 |
+
("자율을 존중하되 최소한의 규율을 유지하���, 기여도와 개선을 균형 있게 반영한다. 내부와 외부 활동을 상황에 따라 조절하고 학술과 친목 모두를 포용하며, 성장과 장기 기반을 중시하는 실용적 운영을 선호한다.", '용우'),
|
| 156 |
+
("규율과 공정을 기반으로 안정적인 운영을 추구하며, 균등·개선·친목 간의 균형을 중시한다. 내부 중심이되 필요에 따라 외부 협력을 수용하고, 성장과 장기 기반을 함께 고려하는 실용적 판단을 지향한다.",'형진')
|
| 157 |
]
|
| 158 |
|
| 159 |
instruction_text = """
|
| 160 |
+
해당 페르소나의 성격을 가진 심판관입니다.
|
| 161 |
반드시 3문장만 말하십시오.
|
| 162 |
각 문장은 30자 이내로 제한합니다.
|
| 163 |
+
규정에 근거하여 답하시오.
|
| 164 |
반복 금지, 판단 근거 필수.
|
| 165 |
"""
|
| 166 |
|
| 167 |
# -----------------------------
|
| 168 |
+
# 8. 스트리밍 UI용
|
| 169 |
# -----------------------------
|
| 170 |
def run_all_streaming(query):
|
| 171 |
for persona, name in persona_group:
|
|
|
|
| 175 |
yield "\n\n---\n\n"
|
| 176 |
|
| 177 |
# -----------------------------
|
| 178 |
+
# 9. API용 동기 실행 (문자열 반환)
|
| 179 |
# -----------------------------
|
| 180 |
def run_all_api(query):
|
| 181 |
out = ""
|
|
|
|
| 201 |
run_btn = gr.Button("🚀 실행(Streaming UI)")
|
| 202 |
run_btn.click(fn=run_all_streaming, inputs=[user_input], outputs=[output_stream])
|
| 203 |
|
| 204 |
+
# API 버튼 (동기 반환)
|
| 205 |
api_output = gr.Textbox(label="API 반환 결과", lines=15)
|
| 206 |
+
run_btn_api = gr.Button("🔁 실행(API)")
|
| 207 |
run_btn_api.click(fn=run_all_api, inputs=[user_input], outputs=[api_output], api_name="start_api")
|
| 208 |
|
| 209 |
+
# -----------------------------
|
| 210 |
+
# 11. Launch
|
| 211 |
+
# -----------------------------
|
| 212 |
demo.launch()
|