import json
import os
import re

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.core import (
    QueryBundle,
    Settings,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from transformers import AutoTokenizer

DATA_DIR = os.environ.get("DATA_DIR", "packages/data_prep/generated")
MODEL_REPO = "Jackrong/Qwen3.5-4B-Neo-GGUF"
TOKENIZER_REPO = "Qwen/Qwen3.5-4B"
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"


def download_gguf_model(repo_id: str) -> str:
    """Download a GGUF weights file from the repo, preferring Q4_K_M quantization."""
    files = list_repo_files(repo_id)
    gguf_files = [f for f in files if f.endswith(".gguf")]
    if not gguf_files:
        raise FileNotFoundError(f"No .gguf files found in {repo_id}")
    target = next((f for f in gguf_files if "Q4_K_M" in f.upper()), gguf_files[0])
    print(f"Downloading {target} from {repo_id}...")
    return hf_hub_download(repo_id=repo_id, filename=target)


class HybridGraphRetriever(BaseRetriever):
    """Ranks graph nodes by a weighted linear combination of query similarity
    under the original embeddings and the LightGCN-enhanced embeddings."""

    def __init__(
        self,
        data_dir: str,
        embed_model: HuggingFaceEmbedding,
        alpha: float = 0.5,
        top_k: int = 5,
    ):
        super().__init__()
        self._embed_model = embed_model
        self._alpha = alpha
        self._top_k = top_k

        with open(os.path.join(data_dir, "property_graph_store.json")) as f:
            pg_data = json.load(f)
        with open(os.path.join(data_dir, "id_to_int.json")) as f:
            id_to_int = json.load(f)
        lightgcn_all = np.load(os.path.join(data_dir, "lightgcn_embeddings.npy"))

        node_ids: list[str] = []
        node_texts: list[str] = []
        node_labels: list[str] = []
        orig_list: list[list[float]] = []
        lgcn_list: list[np.ndarray] = []

        # Keep only nodes that have both an original embedding and a
        # LightGCN embedding (i.e. appear in the id_to_int mapping).
        for node_id, node_data in pg_data["nodes"].items():
            if node_id not in id_to_int:
                continue
            emb = node_data.get("embedding")
            if not emb:
                continue
            idx = id_to_int[node_id]
            node_ids.append(node_id)
            node_texts.append(node_data.get("text", ""))
            node_labels.append(node_data.get("label", ""))
            orig_list.append(emb)
            lgcn_list.append(lightgcn_all[idx])

        self._node_ids = node_ids
        self._node_texts = node_texts
        self._node_labels = node_labels

        # Pre-normalize both embedding matrices so a dot product with a
        # normalized query vector yields cosine similarity.
        orig = np.array(orig_list, dtype=np.float32)
        lgcn = np.stack(lgcn_list).astype(np.float32)
        self._orig_normed = orig / (np.linalg.norm(orig, axis=1, keepdims=True) + 1e-8)
        self._lgcn_normed = lgcn / (np.linalg.norm(lgcn, axis=1, keepdims=True) + 1e-8)

        print(
            f"HybridGraphRetriever ready: {len(node_ids)} nodes, "
            f"alpha={alpha}, top_k={top_k}"
        )

    def _retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        query_emb = np.array(
            self._embed_model.get_query_embedding(query_bundle.query_str),
            dtype=np.float32,
        )
        query_normed = query_emb / (np.linalg.norm(query_emb) + 1e-8)

        sim_orig = self._orig_normed @ query_normed
        sim_lgcn = self._lgcn_normed @ query_normed

        # Weighted linear combination:
        # score = alpha * sim_original + (1 - alpha) * sim_lightgcn
        scores = self._alpha * sim_orig + (1 - self._alpha) * sim_lgcn
        top_indices = np.argsort(scores)[::-1][: self._top_k]

        return [
            NodeWithScore(
                node=TextNode(
                    text=self._node_texts[i],
                    id_=self._node_ids[i],
                    metadata={"label": self._node_labels[i]},
                ),
                score=float(scores[i]),
            )
            for i in top_indices
        ]


def _strip_think_tags(text: str) -> str:
    """Remove Qwen-style <think>...</think> reasoning blocks from model output."""
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # While streaming, a <think> block may still be open; drop everything
    # from the opening tag onward.
    if "<think>" in text:
        text = text[: text.index("<think>")]
    return text.strip()
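
# A minimal sketch of plugging HybridGraphRetriever into a standalone
# LlamaIndex query engine. This is an assumed usage pattern, not the path
# main() currently takes (main() uses the index's default chat engine and
# leaves the hybrid wiring commented out):
#
#     from llama_index.core.query_engine import RetrieverQueryEngine
#
#     retriever = HybridGraphRetriever(
#         data_dir=DATA_DIR,
#         embed_model=Settings.embed_model,
#         alpha=0.5,  # 1.0 = original similarity only, 0.0 = LightGCN only
#         top_k=5,
#     )
#     query_engine = RetrieverQueryEngine.from_args(retriever)
#     print(query_engine.query("Which entities are related to LightGCN?"))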


def main() -> None:
    print("Loading embedding model...")
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)

    print("Loading LLM...")
    model_path = download_gguf_model(MODEL_REPO)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)

    def messages_to_prompt(messages):
        messages = [{"role": m.role.value, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def completion_to_prompt(completion):
        messages = [{"role": "user", "content": completion}]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    llm = LlamaCPP(
        model_path=model_path,
        max_new_tokens=4000,
        context_window=16384,
        generate_kwargs={},
        model_kwargs={"n_gpu_layers": -1},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        verbose=True,
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    print("Loading property graph index...")
    storage_context = StorageContext.from_defaults(persist_dir=DATA_DIR)
    index = load_index_from_storage(storage_context)
    print(f"Index loaded: {index.index_id}")
    chat_engine = index.as_chat_engine()

    # Hybrid retrieval is not wired into the chat path yet; when enabled,
    # the retriever below supplies graph context to the system prompt in chat().
    # print("Building hybrid retriever...")
    # retriever = HybridGraphRetriever(
    #     data_dir=DATA_DIR,
    #     embed_model=embed_model,
    #     alpha=0.5,
    #     top_k=5,
    # )

    async def chat(message: str, history: list[dict]):
        # Planned hybrid path (currently disabled in favor of the default
        # chat engine):
        # nodes = retriever.retrieve(message)
        # context = "\n\n".join(
        #     f"[{n.metadata.get('label', '')}] {n.text}" for n in nodes
        # )
        # messages: list[dict[str, str]] = [
        #     {
        #         "role": "system",
        #         "content": (
        #             "You are a helpful knowledge assistant. "
        #             "Answer questions based on the provided context from a "
        #             "knowledge graph. If the context doesn't contain relevant "
        #             "information, say so.\n\n"
        #             f"Context:\n{context}"
        #         ),
        #     }
        # ]
        # for msg in history:
        #     messages.append({"role": msg["role"], "content": msg["content"]})
        # messages.append({"role": "user", "content": message})

        output = await chat_engine.astream_chat(message)
        # Gradio replaces the displayed message with each yield, so accumulate
        # the token deltas and strip any <think> blocks as they stream.
        response = ""
        async for token in output.async_response_gen():
            response += token
            yield _strip_think_tags(response)

    print("Starting Gradio app...")
    demo = gr.ChatInterface(
        fn=chat,
        title="Knowledge Graph Chat",
        description=(
            "Chat with an LLM powered by a knowledge graph with hybrid "
            "retrieval (original + LightGCN embeddings)."
        ),
        type="messages",
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )


if __name__ == "__main__":
    main()
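
# Assumed invocation (the script filename is hypothetical):
#
#   DATA_DIR=packages/data_prep/generated python app.py
#
# DATA_DIR must contain property_graph_store.json, id_to_int.json,
# lightgcn_embeddings.npy, and the persisted LlamaIndex storage files;
# the Gradio UI then serves on http://0.0.0.0:7860.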