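"""Gradio chat app for a knowledge-graph RAG demo.

Loads a persisted LlamaIndex property graph index, a GGUF Qwen model served
through llama.cpp, and an optional hybrid retriever that mixes the original
node embeddings with LightGCN-enhanced embeddings.
"""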
import json
import os
import re

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.core import (
    QueryBundle,
    Settings,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from transformers import AutoTokenizer

DATA_DIR = os.environ.get("DATA_DIR", "packages/data_prep/generated")
MODEL_REPO = "Jackrong/Qwen3.5-4B-Neo-GGUF"
TOKENIZER_REPO = "Qwen/Qwen3.5-4B"
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
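# DATA_DIR is expected to hold the artifacts produced by the data-prep step:
# property_graph_store.json (nodes with text, label, and embedding),
# id_to_int.json (node id -> row index), and lightgcn_embeddings.npy
# (LightGCN-enhanced embeddings, one row per indexed node).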

def download_gguf_model(repo_id: str) -> str:
    files = list_repo_files(repo_id)
    gguf_files = [f for f in files if f.endswith(".gguf")]
    target = next((f for f in gguf_files if "Q4_K_M" in f.upper()), gguf_files[0])
    print(f"Downloading {target} from {repo_id}...")
    return hf_hub_download(repo_id=repo_id, filename=target)

class HybridGraphRetriever(BaseRetriever):
    """Combines original embedding similarity and LightGCN-enhanced embedding
    similarity using a weighted linear formulation for hybrid ranking."""

    def __init__(
        self,
        data_dir: str,
        embed_model: HuggingFaceEmbedding,
        alpha: float = 0.5,
        top_k: int = 5,
    ):
        super().__init__()
        self._embed_model = embed_model
        self._alpha = alpha
        self._top_k = top_k

        with open(os.path.join(data_dir, "property_graph_store.json")) as f:
            pg_data = json.load(f)
        with open(os.path.join(data_dir, "id_to_int.json")) as f:
            id_to_int = json.load(f)
        lightgcn_all = np.load(os.path.join(data_dir, "lightgcn_embeddings.npy"))

        node_ids: list[str] = []
        node_texts: list[str] = []
        node_labels: list[str] = []
        orig_list: list[list[float]] = []
        lgcn_list: list[np.ndarray] = []
        for node_id, node_data in pg_data["nodes"].items():
            if node_id not in id_to_int:
                continue
            emb = node_data.get("embedding")
            if not emb:
                continue
            idx = id_to_int[node_id]
            node_ids.append(node_id)
            node_texts.append(node_data.get("text", ""))
            node_labels.append(node_data.get("label", ""))
            orig_list.append(emb)
            lgcn_list.append(lightgcn_all[idx])

        self._node_ids = node_ids
        self._node_texts = node_texts
        self._node_labels = node_labels

        orig = np.array(orig_list, dtype=np.float32)
        lgcn = np.stack(lgcn_list).astype(np.float32)
        self._orig_normed = orig / (np.linalg.norm(orig, axis=1, keepdims=True) + 1e-8)
        self._lgcn_normed = lgcn / (np.linalg.norm(lgcn, axis=1, keepdims=True) + 1e-8)

        print(
            f"HybridGraphRetriever ready: {len(node_ids)} nodes, "
            f"alpha={alpha}, top_k={top_k}"
        )

    def _retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        query_emb = np.array(
            self._embed_model.get_query_embedding(query_bundle.query_str),
            dtype=np.float32,
        )
        query_normed = query_emb / (np.linalg.norm(query_emb) + 1e-8)
        sim_orig = self._orig_normed @ query_normed
        sim_lgcn = self._lgcn_normed @ query_normed
        # Weighted linear combination: score = alpha * sim_original + (1 - alpha) * sim_lightgcn
        scores = self._alpha * sim_orig + (1 - self._alpha) * sim_lgcn
        top_indices = np.argsort(scores)[::-1][: self._top_k]
        return [
            NodeWithScore(
                node=TextNode(
                    text=self._node_texts[i],
                    id_=self._node_ids[i],
                    metadata={"label": self._node_labels[i]},
                ),
                score=float(scores[i]),
            )
            for i in top_indices
        ]
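
# Minimal usage sketch (assumes the data-prep artifacts above exist and that an
# embedding model / LLM have been configured); the retriever can be plugged
# into a standard LlamaIndex query engine instead of the default chat engine:
#
#     from llama_index.core.query_engine import RetrieverQueryEngine
#
#     retriever = HybridGraphRetriever(DATA_DIR, embed_model, alpha=0.5, top_k=5)
#     query_engine = RetrieverQueryEngine.from_args(retriever)
#     print(query_engine.query("What does the knowledge graph say about X?"))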

def _strip_think_tags(text: str) -> str:
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    if "<think>" in text:
        text = text[: text.index("<think>")]
    return text.strip()
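# Illustrative behavior: "<think>plan...</think>Answer" -> "Answer"; if a
# "<think>" block is never closed, everything from the tag onward is dropped.
# (Currently unused by the default chat engine path below.)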

def main() -> None:
    print("Loading embedding model...")
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)

    print("Loading LLM...")
    model_path = download_gguf_model(MODEL_REPO)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
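
    # These callbacks format chat messages and completions with the Qwen chat
    # template from the Hugging Face tokenizer, so the llama.cpp model receives
    # prompts in the expected chat format.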
    def messages_to_prompt(messages):
        messages = [{"role": m.role.value, "content": m.content} for m in messages]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt

    def completion_to_prompt(completion):
        messages = [{"role": "user", "content": completion}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt

    llm = LlamaCPP(
        model_path=model_path,
        max_new_tokens=4000,
        context_window=16384,
        generate_kwargs={},
        model_kwargs={"n_gpu_layers": -1},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        verbose=True,
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    print("Loading property graph index...")
    storage_context = StorageContext.from_defaults(persist_dir=DATA_DIR)
    index = load_index_from_storage(storage_context)
    print(f"Index loaded: {index.index_id}")
    chat_engine = index.as_chat_engine()
| print("Building hybrid retriever...") | |
| # retriever = HybridGraphRetriever( | |
| # data_dir=DATA_DIR, | |
| # embed_model=embed_model, | |
| # alpha=0.5, | |
| # top_k=5, | |
| # ) | |
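    # NOTE: hybrid retrieval is currently disabled; responses come from the
    # default chat engine built from the property graph index above.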

    async def chat(message: str, history: list[dict]):
        # nodes = retriever.retrieve(message)
        # context = "\n\n".join(
        #     f"[{n.metadata.get('label', '')}] {n.text}" for n in nodes
        # )
        # messages: list[dict[str, str]] = [
        #     {
        #         "role": "system",
        #         "content": (
        #             "You are a helpful knowledge assistant. "
        #             "Answer questions based on the provided context from a knowledge graph. "
        #             "If the context doesn't contain relevant information, say so.\n\n"
        #             f"Context:\n{context}"
        #         ),
        #     }
        # ]
        # for msg in history:
        #     messages.append({"role": msg["role"], "content": msg["content"]})
        # messages.append({"role": "user", "content": message})
        output = await chat_engine.astream_chat(message)
        async for token in output.async_response_gen():
            yield token
| print("Starting Gradio app...") | |
| demo = gr.ChatInterface( | |
| fn=chat, | |
| title="Knowledge Graph Chat", | |
| description="Chat with an LLM powered by a knowledge graph with hybrid retrieval (original + LightGCN embeddings).", | |
| # type="messages", | |
| ) | |
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )


if __name__ == "__main__":
    main()