|
|
|
|
|
import json |
|
|
import threading |
|
|
import gradio as gr |
|
|
import torch |
|
|
import faiss |
|
|
import numpy as np |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
print("Device:", device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BASE_MODEL = "KORMo-Team/KORMo-10B-sft" |
|
|
LORA_DIR = "peft_lora" |
|
|
DOC_PATH = "rule.json" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with open(DOC_PATH, "r", encoding="utf-8") as f: |
|
|
documents = json.load(f) |
|
|
|
|
|
doc_texts = [d["text"] for d in documents] |
|
|
|
|
|
embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device) |
|
|
doc_embs = embedding_model.encode(doc_texts, convert_to_numpy=True).astype("float32") |
|
|
|
|
|
index = faiss.IndexFlatL2(doc_embs.shape[1]) |
|
|
index.add(doc_embs) |
|
|
|
|
|
def retrieve(query, k=3): |
|
|
q = embedding_model.encode([query], convert_to_numpy=True).astype("float32") |
|
|
D, I = index.search(q, k) |
|
|
return [documents[i] for i in I[0]] |
|
|
|
|
|
print("FAISS ready, docs:", index.ntotal) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True) |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
BASE_MODEL, |
|
|
device_map="auto", |
|
|
torch_dtype=torch.float16, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
model = PeftModel.from_pretrained( |
|
|
model, |
|
|
LORA_DIR, |
|
|
device_map="auto", |
|
|
torch_dtype=torch.float16 |
|
|
) |
|
|
model.eval() |
|
|
print("Model + LoRA loaded") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_prompt(persona, instruction, query, retrieved_docs): |
|
|
context = "\n".join([f"- {d['text']}" for d in retrieved_docs]) |
|
|
return f""" |
|
|
### ํ๋ฅด์๋: |
|
|
{persona} |
|
|
|
|
|
### ์ฐธ๊ณ ์ฌํญ: |
|
|
{instruction} |
|
|
|
|
|
### ๊ท์ : |
|
|
{context} |
|
|
|
|
|
### ์ง๋ฌธ: |
|
|
{query} |
|
|
|
|
|
### ๋ต๋ณ: |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_stream(persona, instruction, query, max_new_tokens=256): |
|
|
retrieved = retrieve(query, k=3) |
|
|
prompt = build_prompt(persona, instruction, query, retrieved) |
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(device) |
|
|
|
|
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) |
|
|
|
|
|
def run_generate(): |
|
|
with torch.no_grad(): |
|
|
model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_new_tokens, |
|
|
do_sample=True, |
|
|
top_p=0.9, |
|
|
temperature=0.7, |
|
|
repetition_penalty=1.2, |
|
|
pad_token_id=tokenizer.pad_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id, |
|
|
streamer=streamer, |
|
|
use_cache=True |
|
|
) |
|
|
|
|
|
thread = threading.Thread(target=run_generate) |
|
|
thread.start() |
|
|
|
|
|
accumulated = "" |
|
|
for token in streamer: |
|
|
accumulated += token |
|
|
yield accumulated |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_once(persona, instruction, query, max_new_tokens=256): |
|
|
retrieved = retrieve(query, k=3) |
|
|
prompt = build_prompt(persona, instruction, query, retrieved) |
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_new_tokens, |
|
|
do_sample=True, |
|
|
top_p=0.9, |
|
|
temperature=0.7, |
|
|
repetition_penalty=1.2, |
|
|
pad_token_id=tokenizer.pad_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id, |
|
|
use_cache=True |
|
|
) |
|
|
text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
return text.replace(prompt, "").strip() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
persona_group = [ |
|
|
("๋น์ ์ ์์น์ ์งํค๋ ์ํฉ์ ๋ฐ๋ผ ์ ์ฐํ๊ฒ ํ๋จํ๋ ์๊ฐ์ ๊ฐ์ง๊ณ ์๋ค. ๊ฐ์ธ์ ๋ฅ๋ ฅ๊ณผ ๊ธฐ์ฌ๋๋ฅผ ์ค์ํ๊ฒ ์๊ฐํ์ฌ ์ฑ๊ณผ์ ๋ฐ๋ฅธ ์ฐจ๋ฑ ๋์ฐ๋ฅผ ์ ๋นํ๋ค๊ณ ํ๋จํ๋ฉฐ, ๋ณํ์ ํ์ ์ ์ต์ฐ์ ์ผ๋ก ์ฌ๊ฒจ ๊ด์ต๋ณด๋ค ๊ฐ์ ์ ์ ํํ๋ค. ๋ํ ๋ด๋ถ์ ๋จธ๋ฌด๋ฅด๊ธฐ๋ณด๋ค ์ธ๋ถ์์ ์ฐ๊ณ์ ํ์
์ ์ ๊ทน์ ์ผ๋ก ์ถ๊ตฌํ๋ฉฐ, ํ์ ํ๋๊ณผ ์น๋ชฉ ํ๋์ ๊ท ํ์ ํตํด ๊ฑด๊ฐํ ๊ณต๋์ฒด ๋ฌธํ๋ฅผ ์งํฅํ๋ค. ๋์ธ์ ์ผ๋ก ๋ณด์ฌ์ค ์ ์๋ ํ์คํ ์ฑ๊ณผ์ ์์ฑ๋๋ฅผ ์ค์ํ๋ฉด์๋, ๋จ๊ธฐ์ ํด๊ฒฐ๊ณผ ์ฅ๊ธฐ์ ๊ธฐ๋ฐ ๋ง๋ จ ์ฌ์ด์์ ๊ท ํ์ ์ ์งํ๋ ค ๋
ธ๋ ฅํ๋ค.", '๋ฐ์ธ์ฐ'), |
|
|
("๋น์ ์ ๊ณต์ ํ ๊ท์น๊ณผ ์์น์ ์ค์ํ๋ฉด์, ๊ฐ์ธ์ ์ฑ๊ณผ์ ๋ฅ๋ ฅ์ ์ธ์ ํด ์ฐจ๋ฑ์ ๋๊ณ ๋ฐฐ๋ถํฉ๋๋ค. ์ ํต์ ์กด์คํ๋ ์ ์ง์ ์ธ ๊ฐ์ ์ ์์ฉํ๋ฉฐ, ๋ด๋ถ ํ๋์ ๋จธ๋ฌด๋ฅด์ง ์๊ณ ์ธ๋ถ์์ ํ์
๊ณผ ๋คํธ์ํฌ๋ฅผ ์ ๊ทน์ ์ผ๋ก ์ถ๊ตฌํฉ๋๋ค. ํ์ ๊ฐ ์ ๋์ ์ฆ๊ฑฐ์์ ์ค์์ํ๊ณ , ์์ฑ๋ ๋์ ๊ฒฐ๊ณผ๋ฌผ๊ณผ ๊ณผ์ ์์์ ๋ฐฐ์์ ๋ชจ๋ ์ค์ํ๋ฉฐ, ๋น์ฅ์ ๋ฌธ์ ํด๊ฒฐ๊ณผ ์ฅ๊ธฐ์ ๊ธฐ๋ฐ ๊ตฌ์ถ์ ๋์์ ๊ณ ๋ คํฉ๋๋ค.",'๊น์ฐฝ์ค'), |
|
|
("๊ท์จ๊ณผ ์์จ์ ๊ท ํ์ ์งํค๋ฉฐ, ๋ฅ๋ ฅ๊ณผ ์ฑ๊ณผ๋ฅผ ๊ธฐ์ค์ผ๋ก ํ๋จํ๋ค. ์ ํต์ ์ ์งํ๋ ์ ์ง์ ๊ฐ์ ์ ์ถ๊ตฌํ๊ณ , ์ธ๋ถ์์ ํ์
์ ์ ๊ทน์ ์ผ๋ก ๋ชจ์ํ๋ค. ์ฆ๊ฑฐ์ด ๋ถ์๊ธฐ ์์์ ํ์ตํ๋ฉฐ ๊ฐ์ธ์ ์ฑ์ฅ์ ์ค์ํ๊ณ , ๋จ๊ธฐ ์ฑ๊ณผ๋ณด๋ค ๋์๋ฆฌ์ ์ฅ๊ธฐ์ ๊ธฐ๋ฐ์ ์ฐ์ ํ๋ค.", '์ด์๊ธฐ'), |
|
|
("๊ท์จ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ง๋ง ์ ์ฐํ๋ฉฐ, ๋ถ๋ฐฐ๋ ์ค๋ฆฝ์ ์ด๊ณ ๊ฐ์ ์ ์ถ๊ตฌํ๋ค. ์ธ๋ถ ์ฐ๊ณ๋ฅผ ์ ๋นํ ํ์ฉํ๋ฉฐ ํ์ ยท์น๋ชฉ ๋ชจ๋ ์ํฉ์ ๋ฐ๋ผ ์ ํํ๊ณ , ๊ฐ์์ฑ๊ณผ ์ฅ๊ธฐ ๊ธฐ๋ฐ์ ์กฐํ๋กญ๊ฒ ๊ณ ๋ คํ๋ค.", '์ฑํ'), |
|
|
("์์จ์ ์กด์คํ๋ ์ต์ํ์ ๊ท์จ์ ์ ์งํ๋ฉฐ, ๊ธฐ์ฌ๋์ ๊ฐ์ ์ ๊ท ํ ์๊ฒ ๋ฐ์ํ๋ค. ๋ด๋ถ์ ์ธ๋ถ ํ๋์ ์ํฉ์ ๋ฐ๋ผ ์กฐ์ ํ๊ณ ํ์ ๊ณผ ์น๋ชฉ ๋ชจ๋๋ฅผ ํฌ์ฉํ๋ฉฐ, ์ฑ์ฅ๊ณผ ์ฅ๊ธฐ ๊ธฐ๋ฐ์ ์ค์ํ๋ ์ค์ฉ์ ์ด์์ ์ ํธํ๋ค.", '์ฉ์ฐ'), |
|
|
("๊ท์จ๊ณผ ๊ณต์ ์ ๊ธฐ๋ฐ์ผ๋ก ์์ ์ ์ธ ์ด์์ ์ถ๊ตฌํ๋ฉฐ, ๊ท ๋ฑยท๊ฐ์ ยท์น๋ชฉ ๊ฐ์ ๊ท ํ์ ์ค์ํ๋ค. ๋ด๋ถ ์ค์ฌ์ด๋ ํ์์ ๋ฐ๋ผ ์ธ๋ถ ํ๋ ฅ์ ์์ฉํ๊ณ , ์ฑ์ฅ๊ณผ ์ฅ๊ธฐ ๊ธฐ๋ฐ์ ํจ๊ป ๊ณ ๋ คํ๋ ์ค์ฉ์ ํ๋จ์ ์งํฅํ๋ค.",'ํ์ง') |
|
|
] |
|
|
|
|
|
instruction_text = """ |
|
|
ํด๋น ํ๋ฅด์๋์ ์ฑ๊ฒฉ์ ๊ฐ์ง ์ฌํ๊ด์
๋๋ค. |
|
|
๋ฐ๋์ 3๋ฌธ์ฅ๋ง ๋งํ์ญ์์ค. |
|
|
๊ฐ ๋ฌธ์ฅ์ 30์ ์ด๋ด๋ก ์ ํํฉ๋๋ค. |
|
|
๊ท์ ์ ๊ทผ๊ฑฐํ์ฌ ๋ตํ์์ค. |
|
|
๋ฐ๋ณต ๊ธ์ง, ํ๋จ ๊ทผ๊ฑฐ ํ์. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_all_streaming(query): |
|
|
for persona, name in persona_group: |
|
|
yield f"## ๐ค {name}\n" |
|
|
for partial in generate_stream(persona, instruction_text, query): |
|
|
yield partial |
|
|
yield "\n\n---\n\n" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_all_api(query): |
|
|
out = "" |
|
|
for persona, name in persona_group: |
|
|
out += f"## ๐ค {name}\n" |
|
|
text = generate_once(persona, instruction_text, query) |
|
|
out += text + "\n\n---\n\n" |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
gr.Markdown("# ๐ฅ KORMo LoRA + RAG (Streaming UI + API)") |
|
|
|
|
|
user_input = gr.Textbox( |
|
|
label="์ง๋ฌธ ์
๋ ฅ", |
|
|
value="3๋ฒ ์ด์์ ๊ฒฐ์์ ํ์ง๋ง ์ค๋ ฅ์ ๋ฐ์ด๋ ์ ํ์์ ์ด๋ป๊ฒ ํด์ผ ํ ๊น?" |
|
|
) |
|
|
|
|
|
|
|
|
output_stream = gr.Markdown() |
|
|
run_btn = gr.Button("๐ ์คํ(Streaming UI)") |
|
|
run_btn.click(fn=run_all_streaming, inputs=[user_input], outputs=[output_stream]) |
|
|
|
|
|
|
|
|
api_output = gr.Textbox(label="API ๋ฐํ ๊ฒฐ๊ณผ", lines=15) |
|
|
run_btn_api = gr.Button("๐ ์คํ(API)") |
|
|
run_btn_api.click(fn=run_all_api, inputs=[user_input], outputs=[api_output], api_name="start_api") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.launch() |
|
|
|