import json
import re

import faiss
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from prompt import PROMPTS
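
# RAG chatbot for the Hyundai Staria owner's manual: embed the user's question,
# retrieve the most similar manual segment with FAISS, and pass it to a
# LoRA-tuned causal LM as extra context when the similarity clears a threshold.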
# ──────────────────────────────────────────────────────────────────────
def normalized_embedding(emb: np.ndarray) -> np.ndarray:
    return emb / np.linalg.norm(emb)

def load_embeddings(emb_path: str) -> dict[str, np.ndarray]:
    # Use a context manager so the file handle is closed deterministically.
    with open(emb_path, encoding="utf-8") as f:
        raw = json.load(f)
    return {k: np.array(v, dtype="float32") for k, v in raw.items()}
def build_faiss_index_from_embeddings(emb_map: dict[str, np.ndarray]):
    keys = list(emb_map.keys())
    matrix = np.stack([normalized_embedding(emb_map[k]) for k in keys]).astype("float32")
    index = faiss.IndexFlatIP(matrix.shape[1])
    index.add(matrix)
    return index, keys
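
# IndexFlatIP scores by inner product, so the vectors are L2-normalized first;
# the distances returned by index.search are then exactly cosine similarities
# in [-1, 1], which is what the retrieval threshold below compares against.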
def load_value_segments(value_path: str) -> dict[str, dict[str, str]]:
    with open(value_path, encoding="utf-8") as f:
        return json.load(f)
def generate_answer(tokenizer, model, system_prompt: str, query: str, context: str = "", conversation_history=None) -> str:
    # Hand-rolled Llama-3 chat template tokens.
    B = "<|begin_of_text|>"
    SS = "<|start_header_id|>system<|end_header_id|>"
    SU = "<|start_header_id|>user<|end_header_id|>"
    SA = "<|start_header_id|>assistant<|end_header_id|>"
    EOT = "<|eot_id|>"

    system_block = f"{B}\n{SS}\n{system_prompt}\n{EOT}\n"
    conv_text = ""
    if conversation_history:
        for msg in conversation_history:
            role = msg["role"]
            content = msg["content"].strip()
            tag = SU if role == "user" else SA
            conv_text += f"{tag}\n{content}\n{EOT}\n"
    if context:
        user_block = f"{query}\n\n### External Knowledge ###\n{context}"
    else:
        user_block = query
    conv_text += f"{SU}\n{user_block}\n{EOT}\n"
    conv_text += f"{SA}\n"
    prompt = system_block + conv_text
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
    )
    # Decode only the newly generated tokens; splitting the decoded string on
    # the prompt is brittle because decoding does not always round-trip exactly.
    answer = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    for tok in [B, SS, SU, SA, EOT]:
        answer = answer.replace(tok, "")
    return answer.strip()
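
# Note: the hand-built template above mirrors the Llama-3 chat format. If the
# tokenizer ships a chat template (an assumption, not verified here), the same
# prompt could be produced with:
#     prompt = tokenizer.apply_chat_template(messages, tokenize=False,
#                                            add_generation_prompt=True)
# where messages is the usual list of {"role": ..., "content": ...} dicts.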
def post_process_answer(raw: str, prev_answer: str = "") -> str:
    if not raw:
        return "No answer was provided."
    # Extract the assistant block if the special tokens survived decoding.
    # (The original pattern omitted the closing "\|" in each token, so it
    # could never match "<|start_header_id|>" etc.)
    m = re.search(
        r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>",
        raw, re.DOTALL
    )
    ans = m.group(1).strip() if m else raw.strip()
    # Strip any remaining <|...|> special tokens.
    ans = re.sub(r"<\|.*?\|>", "", ans).strip()
    # Reject degenerate outputs that loop on the assistant header.
    if ans.lower().count("assistant") >= 4:
        return "No answer was provided."
    if not ans or ans == prev_answer.strip():
        return "No answer was provided."
    return ans
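
# Illustration: given raw text like
#     "<|start_header_id|>assistant<|end_header_id|>\nCheck the oil level.<|eot_id|>"
# the regex above extracts "Check the oil level."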
def answer_query(
    query: str,
    emb_key_path: str,
    value_text_path: str,
    tokenizer,
    model,
    system_prompt: str,
    rag_model,
    conversation_history=None,
    threshold: float = 0.65
) -> str:
    emb_map = load_embeddings(emb_key_path)
    index, keys = build_faiss_index_from_embeddings(emb_map)
    value_map = load_value_segments(value_text_path)
    q_emb = rag_model.encode(query, convert_to_tensor=True).cpu().numpy().squeeze()
    q_norm = normalized_embedding(q_emb).astype("float32").reshape(1, -1)
    D, I = index.search(q_norm, 1)  # top-1 nearest manual segment
    score, idx = float(D[0, 0]), int(I[0, 0])
    if score >= threshold:
        full_key = keys[idx]
        file_key, seg_id = full_key.rsplit("_", 1)
        context = value_map[file_key]["segments"].get(seg_id, "")
        print(f"similarity: {score:.4f}, context ready → '{context[:30]}…'")
    else:
        context = ""
        print(f"similarity {score:.4f} < {threshold} → external knowledge not used")
    raw = generate_answer(
        tokenizer=tokenizer,
        model=model,
        system_prompt=system_prompt,
        query=query,
        context=context,
        conversation_history=conversation_history
    )
    answer_text = post_process_answer(raw)
    print(f"\nAnswer: {answer_text}\n")
    return answer_text
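
# Performance note: answer_query reloads both JSON files and rebuilds the FAISS
# index on every call. A minimal caching sketch (hypothetical helper, not part
# of the original app):
#     from functools import lru_cache
#     @lru_cache(maxsize=1)
#     def cached_index(emb_key_path: str):
#         return build_faiss_index_from_embeddings(load_embeddings(emb_key_path))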
# ──────────────────────────────────────────────────────────────────────
EMB_KEY_PATH = "staria_keys_embed.json"
VALUE_TEXT_PATH = "staria_values.json"
MODEL_ID = "JLee0/staria-pdf-chatbot-lora"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # or load_in_4bit=True
    device_map="auto"
)
rag_embedder = SentenceTransformer("JLee0/rag-embedder-staria-10epochs")
SYSTEM_PROMPT = PROMPTS["staria_after"]
def chat(query, history):
    conv = []
    for u, a in history or []:
        conv.append({"role": "user", "content": u})
        conv.append({"role": "assistant", "content": a})
    return answer_query(
        query=query,
        emb_key_path=EMB_KEY_PATH,
        value_text_path=VALUE_TEXT_PATH,
        tokenizer=tokenizer,
        model=model,
        system_prompt=SYSTEM_PROMPT,
        rag_model=rag_embedder,
        conversation_history=conv,
        threshold=0.65
    )
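
# Assumes Gradio's tuple-style chat history (a list of [user, assistant] pairs),
# the gr.ChatInterface default; with type="messages" the history would already
# arrive as role/content dicts and need no conversion.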
demo = gr.ChatInterface(
    fn=chat,
    examples=[
        ["What should I keep in mind when changing the engine oil?"],
        ["How should the built-in cam data be handled?"],
        ["After using the washer fluid spray, do I need to return the switch to its original position?"],
        ["Are the air-conditioner settings kept after the vehicle is turned off?"]
    ],
    title="Hyundai Staria Q&A Chatbot",
    description="Welcome to the chatbot! Please enter your question."
)
if __name__ == "__main__":
    demo.launch()
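
# When run directly, Gradio serves the UI at http://127.0.0.1:7860 by default;
# on Hugging Face Spaces the app is launched automatically.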