# Gradio RAG chatbot for the Hyundai Staria PDF manual: FAISS retrieval over
# precomputed segment embeddings plus a LoRA-tuned causal LM for generation.
import json

import numpy as np
import faiss
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

from prompt import PROMPTS
def normalized_embedding(emb: np.ndarray) -> np.ndarray:
    """L2-normalize an embedding so inner product equals cosine similarity."""
    return emb / np.linalg.norm(emb)
def build_faiss_index_from_embeddings(emb_map: dict[str, np.ndarray]):
    """Stack normalized embeddings into a FAISS inner-product index; return (index, keys)."""
    keys = list(emb_map.keys())
    matrix = np.stack([normalized_embedding(emb_map[k]) for k in keys]).astype("float32")
    index = faiss.IndexFlatIP(matrix.shape[1])  # exact search; scores are cosine similarities
    index.add(matrix)
    return index, keys
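# Hedged sketch (not part of the original app): one plausible way the
# "staria_keys_embed.json" file loaded below could be produced offline. The
# assumed schema, {"<file>_<segment_id>": [float, ...]}, is inferred from how
# keys are split with rsplit("_", 1) later; the helper name is hypothetical.
def precompute_segment_embeddings(segments: dict[str, str], embedder) -> dict[str, list[float]]:
    """Embed each text segment into a JSON-serializable key -> vector map."""
    return {
        key: embedder.encode(text, convert_to_numpy=True).tolist()
        for key, text in segments.items()
    }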
def generate_answer(
    tokenizer, model,
    system_prompt: str, query: str,
    context: str = "", conversation_history=None
) -> str:
    # Llama-3-style chat template tokens, assembled by hand.
    B = "<|begin_of_text|>"
    SS = "<|start_header_id|>system<|end_header_id|>"
    SU = "<|start_header_id|>user<|end_header_id|>"
    SA = "<|start_header_id|>assistant<|end_header_id|>"
    E = "<|eot_id|>"

    system_block = f"{B}\n{SS}\n{system_prompt}\n{E}\n"

    conv = ""
    if conversation_history:
        for msg in conversation_history:
            role = msg["role"]
            content = msg["content"].strip()
            tag = SU if role == "user" else SA
            conv += f"{tag}\n{content}\n{E}\n"

    # Append retrieved context to the user turn when available.
    if context:
        user_block = f"{query}\n\n### 외부 지식 ###\n{context}"  # "### External knowledge ###"
    else:
        user_block = query
    conv += f"{SU}\n{user_block}\n{E}\n"
    conv += f"{SA}\n"

    prompt = system_block + conv
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.8,
        do_sample=True,
    )
    # Decode only the newly generated tokens; string-splitting the full decode
    # on the raw prompt is fragile because detokenization need not reproduce it.
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=False)
    for tok in [B, SS, SU, SA, E]:
        answer = answer.replace(tok, "")
    return answer.strip()
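# Hedged alternative: if the tokenizer ships a chat template matching the
# hand-built blocks above, transformers' apply_chat_template could render the
# prompt instead. A minimal sketch, assuming such a template exists:
def build_prompt_with_template(tokenizer, system_prompt: str, query: str) -> str:
    """Render a system + user conversation via the tokenizer's own chat template."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)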
MODEL_ID = "JLee0/staria-pdf-chatbot-lora"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)

# Sentence embedder used for retrieval queries.
rag_embedder = SentenceTransformer("JLee0/rag-embedder-staria-10epochs")

# Precomputed segment embeddings, keyed "<file>_<segment_id>".
with open("staria_keys_embed.json", encoding="utf-8") as f:
    emb_raw = json.load(f)
emb_map = {k: np.array(v, dtype="float32") for k, v in emb_raw.items()}
index, keys = build_faiss_index_from_embeddings(emb_map)

# Segment texts, keyed by file and then segment id.
with open("staria_values.json", encoding="utf-8") as f:
    value_map = json.load(f)
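# Assumed shape of staria_values.json, inferred from the lookup in
# chat_with_rag below (not confirmed by the source):
#   {"<file>": {"segments": {"<segment_id>": "<segment text>", ...}}, ...}
#
# Hedged sketch: a small retrieval smoke test over the loaded index. The
# helper name and top_k default are ours, not the app's.
def debug_search(query: str, top_k: int = 3) -> None:
    """Print the top-k nearest segment keys and their cosine scores."""
    q = normalized_embedding(rag_embedder.encode(query, convert_to_numpy=True))
    D, I = index.search(q.astype("float32").reshape(1, -1), top_k)
    for score, idx in zip(D[0], I[0]):
        print(f"{keys[int(idx)]}: {float(score):.3f}")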
SYSTEM_PROMPT = PROMPTS["staria_after"]  # index rather than .get(): fail fast instead of formatting None into the prompt
def chat_with_rag(query, history):
    """Answer a user query, grounding on the single nearest manual segment.

    `history` is supplied by gr.ChatInterface; this app answers each turn
    independently and does not feed it back into the model.
    """
    q_emb = rag_embedder.encode(query, convert_to_numpy=True)
    q_norm = normalized_embedding(q_emb).astype("float32").reshape(1, -1)
    D, I = index.search(q_norm, 1)  # cosine scores, since vectors are normalized
    score, idx = float(D[0, 0]), int(I[0, 0])
    if score >= 0.65:  # below this similarity threshold, answer without retrieved context
        full_key = keys[idx]
        file_key, seg_id = full_key.rsplit("_", 1)
        context = value_map[file_key]["segments"].get(seg_id, "")
    else:
        context = ""
    return generate_answer(
        tokenizer, model,
        system_prompt=SYSTEM_PROMPT,
        query=query,
        context=context,
        conversation_history=None,
    )
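# Hedged sketch: chat_with_rag discards the Gradio history, so every turn is
# answered independently. If multi-turn context were wanted, tuple-style
# history ([user, bot] pairs) could be mapped to the role/content dicts that
# generate_answer's conversation_history expects; the helper name is ours.
def history_to_messages(history) -> list[dict]:
    """Convert Gradio [user, bot] pairs into role/content message dicts."""
    messages = []
    for user_msg, bot_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages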
demo = gr.ChatInterface(
    fn=chat_with_rag,
    # gr.ChatInterface has no `system_message` parameter; the greeting goes in `description`.
    description="현대 스타리아 Q&A 챗봇에 오신 걸 환영합니다! 질문을 입력해 주세요.",  # "Welcome to the Hyundai Staria Q&A chatbot! Please enter your question."
    examples=[
        "엔진오일 교체 주기가 어떻게 되나요?",    # "What is the engine oil change interval?"
        "빌트인 캠 데이터는 어떻게 처리하면 돼?",  # "How should I handle the built-in cam data?"
    ],
)
if __name__ == "__main__":
    demo.launch()