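# Page-extraction residue (repository commit header), kept as comments so the
# module parses as Python:
# thng292's picture
# Added readme
# 79a7a69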
import json
import os
import re
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.core import (
QueryBundle,
Settings,
StorageContext,
load_index_from_storage,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from transformers import AutoTokenizer
# Directory holding the persisted index and the graph/embedding artifacts.
# Overridable through the DATA_DIR environment variable.
DATA_DIR = os.environ.get("DATA_DIR", "packages/data_prep/generated")
# Hugging Face repo id handed to download_gguf_model() to fetch the GGUF weights.
MODEL_REPO = "Jackrong/Qwen3.5-4B-Neo-GGUF"
# Repo id whose tokenizer (and chat template) is loaded via AutoTokenizer.
TOKENIZER_REPO = "Qwen/Qwen3.5-4B"
# Model name passed to HuggingFaceEmbedding for text/query embeddings.
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
def download_gguf_model(repo_id: str) -> str:
    """Download one GGUF weight file from a Hugging Face repo and return its path.

    The first file whose name contains ``Q4_K_M`` (case-insensitive) is
    preferred; if none matches, the first ``.gguf`` file listed is used.

    Args:
        repo_id: Hugging Face repository id to search for ``.gguf`` files.

    Returns:
        The value returned by ``hf_hub_download`` for the chosen file.

    Raises:
        IndexError: If the repository lists no ``.gguf`` file at all.
    """
    gguf_files = [name for name in list_repo_files(repo_id) if name.endswith(".gguf")]
    if not gguf_files:
        # The fallback `gguf_files[0]` below is evaluated eagerly, so an empty
        # list used to fail with a bare "list index out of range". Keep the
        # exception type but say what actually went wrong.
        raise IndexError(f"No .gguf files found in repository {repo_id!r}")

    target = next(
        (name for name in gguf_files if "Q4_K_M" in name.upper()),
        gguf_files[0],
    )
    print(f"Downloading {target} from {repo_id}...")
    return hf_hub_download(repo_id=repo_id, filename=target)
class HybridGraphRetriever(BaseRetriever):
    """Rank graph nodes by a weighted blend of two cosine similarities.

    For every node the retriever keeps two row-normalised vectors: the
    embedding stored with the node in ``property_graph_store.json`` and the
    node's row in ``lightgcn_embeddings.npy``. A query is scored as
    ``alpha * sim_original + (1 - alpha) * sim_lightgcn`` and the ``top_k``
    best nodes are returned.
    """

    def __init__(
        self,
        data_dir: str,
        embed_model: HuggingFaceEmbedding,
        alpha: float = 0.5,
        top_k: int = 5,
    ):
        """Load node texts and both embedding matrices from ``data_dir``.

        Args:
            data_dir: Directory containing ``property_graph_store.json``,
                ``id_to_int.json`` and ``lightgcn_embeddings.npy``.
            embed_model: Model used to embed incoming queries.
            alpha: Weight of the original-embedding similarity; the LightGCN
                similarity is weighted ``1 - alpha``.
            top_k: Maximum number of nodes returned per query.

        Raises:
            ValueError: If no usable node is found, or if the two embedding
                matrices do not have the same width.
        """
        super().__init__()
        self._embed_model = embed_model
        self._alpha = alpha
        self._top_k = top_k

        # Read the JSON files as UTF-8 explicitly instead of relying on the
        # platform's default encoding.
        with open(
            os.path.join(data_dir, "property_graph_store.json"), encoding="utf-8"
        ) as f:
            pg_data = json.load(f)
        with open(os.path.join(data_dir, "id_to_int.json"), encoding="utf-8") as f:
            id_to_int = json.load(f)
        lightgcn_all = np.load(os.path.join(data_dir, "lightgcn_embeddings.npy"))

        node_ids: list[str] = []
        node_texts: list[str] = []
        node_labels: list[str] = []
        orig_list: list[list[float]] = []
        lgcn_list: list[np.ndarray] = []

        for node_id, node_data in pg_data["nodes"].items():
            # Only nodes that have a LightGCN row index AND a stored embedding
            # can be scored by both similarity terms.
            if node_id not in id_to_int:
                continue
            emb = node_data.get("embedding")
            if not emb:
                continue
            node_ids.append(node_id)
            node_texts.append(node_data.get("text", ""))
            node_labels.append(node_data.get("label", ""))
            orig_list.append(emb)
            lgcn_list.append(lightgcn_all[id_to_int[node_id]])

        if not node_ids:
            # np.stack([]) would raise an opaque "need at least one array to
            # stack" ValueError; fail with an actionable message instead.
            raise ValueError(
                f"No nodes with both a stored embedding and a LightGCN row "
                f"were found under {data_dir!r}"
            )

        self._node_ids = node_ids
        self._node_texts = node_texts
        self._node_labels = node_labels

        orig = np.array(orig_list, dtype=np.float32)
        lgcn = np.stack(lgcn_list).astype(np.float32)
        if orig.shape[1:] != lgcn.shape[1:]:
            # _retrieve multiplies both matrices by the same query vector, so
            # differing widths can never work; surface that at load time.
            raise ValueError(
                f"Embedding width mismatch: original {orig.shape[1:]} vs "
                f"LightGCN {lgcn.shape[1:]}"
            )

        # Row-wise L2 normalisation so a dot product equals cosine similarity;
        # the epsilon guards against division by zero for all-zero rows.
        self._orig_normed = orig / (np.linalg.norm(orig, axis=1, keepdims=True) + 1e-8)
        self._lgcn_normed = lgcn / (np.linalg.norm(lgcn, axis=1, keepdims=True) + 1e-8)

        print(
            f"HybridGraphRetriever ready: {len(node_ids)} nodes, "
            f"alpha={alpha}, top_k={top_k}"
        )

    def _retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        """Return the highest-scoring nodes for the query, best first.

        Args:
            query_bundle: Query whose ``query_str`` is embedded and scored.

        Returns:
            Up to ``top_k`` ``NodeWithScore`` items wrapping a ``TextNode``
            (text, id and a ``label`` metadata entry) and the blended score.
        """
        query_emb = np.array(
            self._embed_model.get_query_embedding(query_bundle.query_str),
            dtype=np.float32,
        )
        query_normed = query_emb / (np.linalg.norm(query_emb) + 1e-8)

        sim_orig = self._orig_normed @ query_normed
        sim_lgcn = self._lgcn_normed @ query_normed

        # Weighted linear combination:
        # score = alpha * sim_original + (1 - alpha) * sim_lightgcn
        scores = self._alpha * sim_orig + (1 - self._alpha) * sim_lgcn

        # Clamp so a negative top_k yields no results rather than slicing off
        # the tail of the ranking.
        limit = max(self._top_k, 0)
        top_indices = np.argsort(scores)[::-1][:limit]

        return [
            NodeWithScore(
                node=TextNode(
                    text=self._node_texts[i],
                    id_=self._node_ids[i],
                    metadata={"label": self._node_labels[i]},
                ),
                score=float(scores[i]),
            )
            for i in top_indices
        ]
def _strip_think_tags(text: str) -> str:
text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
if "<think>" in text:
text = text[: text.index("<think>")]
return text.strip()
def main() -> None:
    """Load the models and the persisted index, then serve a Gradio chat UI.

    In order: build the embedding model, download the GGUF weights and load
    them through llama.cpp, register both on the global ``Settings``, load the
    index persisted under ``DATA_DIR``, create a chat engine from it, and
    launch a ``gr.ChatInterface`` on port 7860.
    """
    print("Loading embedding model...")
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)

    print("Loading LLM...")
    model_path = download_gguf_model(MODEL_REPO)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)

    def _render_prompt(chat_messages):
        # Single place where the tokenizer's chat template is applied, so the
        # two LlamaCPP hooks below cannot drift apart.
        return tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )

    def messages_to_prompt(messages):
        # Convert message objects (role enum + content) to plain dicts.
        return _render_prompt(
            [{"role": m.role.value, "content": m.content} for m in messages]
        )

    def completion_to_prompt(completion):
        # A bare completion string is treated as a single user turn.
        return _render_prompt([{"role": "user", "content": completion}])

    llm = LlamaCPP(
        model_path=model_path,
        max_new_tokens=4000,
        context_window=16384,
        generate_kwargs={},
        model_kwargs={"n_gpu_layers": -1},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        verbose=True,
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    print("Loading property graph index...")
    storage_context = StorageContext.from_defaults(persist_dir=DATA_DIR)
    index = load_index_from_storage(storage_context)
    print(f"Index loaded: {index.index_id}")

    chat_engine = index.as_chat_engine()

    # TODO: HybridGraphRetriever is not wired in yet; answers come from the
    # index's default chat engine. To enable hybrid retrieval, build
    # HybridGraphRetriever(data_dir=DATA_DIR, embed_model=embed_model) here and
    # pass its retrieved node texts to the LLM as context.

    async def chat(message: str, history: list[dict]):
        """Stream the chat engine's answer to ``message`` for Gradio.

        NOTE(review): ``history`` is ignored and one ``chat_engine`` is shared
        by every caller, so any conversation state it keeps is presumably
        shared across browser sessions -- confirm this is intended.
        """
        response = await chat_engine.astream_chat(message)
        answer = ""
        async for delta in response.async_response_gen():
            # The engine streams incremental chunks, while ChatInterface
            # replaces the shown message with each yielded value; yield the
            # text accumulated so far, not the latest chunk.
            answer += delta
            yield answer

    print("Starting Gradio app...")
    demo = gr.ChatInterface(
        fn=chat,
        title="Knowledge Graph Chat",
        description="Chat with an LLM powered by a knowledge graph with hybrid retrieval (original + LightGCN embeddings).",
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )