metalmind / kg_sys /llm.py
IELTS8's picture
Upload folder using huggingface_hub
ada3f28 verified
import logging
from langchain.docstore.document import Document
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_google_vertexai import ChatVertexAI
from langchain_groq import ChatGroq
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from langchain_experimental.graph_transformers import LLMGraphTransformer
# from langchain_anthropic import ChatAnthropic
# from langchain_fireworks import ChatFireworks
# from langchain_aws import ChatBedrock
# from langchain_community.chat_models import ChatOllama
import boto3
import google.auth
from kg_sys.shared.constants import MODEL_VERSIONS
def _require_model_config(env_key, env_value):
    """Return the comma-separated fields of a model config, or raise if unset."""
    if env_value is None:
        raise ValueError(
            f"Environment variable {env_key} is not set; it must hold the "
            "comma-separated configuration for this model"
        )
    return env_value.split(",")


def _disabled_provider_error(provider, client):
    """Build the error raised for a provider whose client import is disabled."""
    return NotImplementedError(
        f"LLM provider '{provider}' is not available: {client} is not "
        "imported in this module"
    )


def get_llm(model: str):
    """Build the chat model (or Diffbot transformer) selected by ``model``.

    The provider is chosen by substring match on ``model``, tested in this
    order: gemini, openai, azure, anthropic, fireworks, groq, bedrock,
    ollama, diffbot. Any other name is treated as an OpenAI-compatible
    endpoint.

    Configuration sources:
        * gemini / openai: the concrete model name is looked up in
          ``MODEL_VERSIONS[model]``; openai reads ``OPENAI_API_KEY``.
        * diffbot: reads ``DIFFBOT_API_KEY``.
        * everything else: the environment variable
          ``LLM_MODEL_CONFIG_<model>``, a comma-separated string:
            - azure:  ``model_name,api_endpoint,api_key,api_version``
            - groq:   ``model_name,base_url,api_key``
            - other:  ``model_name,api_endpoint,api_key``

    Args:
        model: Model identifier; also the suffix of the config variable name.

    Returns:
        A ``(llm, model_name)`` tuple. For diffbot, ``llm`` is a
        ``DiffbotGraphTransformer`` and ``model_name`` is ``"diffbot"``.

    Raises:
        ValueError: The required ``LLM_MODEL_CONFIG_<model>`` variable is
            unset, or it does not contain the expected number of fields.
        NotImplementedError: The provider (anthropic, fireworks, bedrock,
            ollama) has its client import commented out at file level.
    """
    env_key = "LLM_MODEL_CONFIG_" + model
    env_value = os.environ.get(env_key)
    logging.info("Model: {}".format(env_key))

    if "gemini" in model:
        credentials, project_id = google.auth.default()
        model_name = MODEL_VERSIONS[model]
        llm = ChatVertexAI(
            model_name=model_name,
            convert_system_message_to_human=True,
            credentials=credentials,
            project=project_id,
            temperature=0,
            safety_settings={
                HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            },
        )
    elif "openai" in model:
        model_name = MODEL_VERSIONS[model]
        llm = ChatOpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            model=model_name,
            temperature=0,
        )
    elif "azure" in model:
        model_name, api_endpoint, api_key, api_version = _require_model_config(
            env_key, env_value
        )
        llm = AzureChatOpenAI(
            api_key=api_key,
            azure_endpoint=api_endpoint,
            azure_deployment=model_name,  # takes precedence over model parameter
            api_version=api_version,
            temperature=0,
            max_tokens=None,
            timeout=None,
        )
    elif "anthropic" in model:
        # Previously fell through with `llm` unassigned. To re-enable, restore
        # the langchain_anthropic import and build
        # ChatAnthropic(api_key=..., model=..., temperature=0, timeout=None)
        # from a "model_name,api_key" config.
        raise _disabled_provider_error("anthropic", "langchain_anthropic.ChatAnthropic")
    elif "fireworks" in model:
        # To re-enable: ChatFireworks(api_key=..., model=...) from a
        # "model_name,api_key" config.
        raise _disabled_provider_error("fireworks", "langchain_fireworks.ChatFireworks")
    elif "groq" in model:
        # The base_url field is part of the config format but is not passed
        # to ChatGroq.
        model_name, _base_url, api_key = _require_model_config(env_key, env_value)
        llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
    elif "bedrock" in model:
        # To re-enable: create a boto3 "bedrock-runtime" client from a
        # "model_name,aws_access_key,aws_secret_key,region_name" config and
        # build ChatBedrock(client=..., model_id=...,
        # model_kwargs=dict(temperature=0)).
        raise _disabled_provider_error("bedrock", "langchain_aws.ChatBedrock")
    elif "ollama" in model:
        # To re-enable: ChatOllama(base_url=..., model=...) from a
        # "model_name,base_url" config.
        raise _disabled_provider_error(
            "ollama", "langchain_community.chat_models.ChatOllama"
        )
    elif "diffbot" in model:
        model_name = "diffbot"
        llm = DiffbotGraphTransformer(
            diffbot_api_key=os.environ.get("DIFFBOT_API_KEY"),
            extract_types=["entities", "facts"],
        )
    else:
        # Fallback: any OpenAI-compatible endpoint.
        model_name, api_endpoint, api_key = _require_model_config(env_key, env_value)
        llm = ChatOpenAI(
            api_key=api_key,
            base_url=api_endpoint,
            model=model_name,
            temperature=0,
        )

    logging.info(f"Model created - Model Version: {model}")
    return llm, model_name
def get_combined_chunks(chunkId_chunkDoc_list):
    """Merge consecutive chunks into larger documents for the LLM.

    The input is walked in groups of ``NUMBER_OF_CHUNKS_TO_COMBINE``
    (read from the environment). Each group becomes one ``Document`` whose
    ``page_content`` is the group's page contents concatenated with no
    separator and whose metadata lists the source chunk ids under
    ``"combined_chunk_ids"``. The last group may be shorter.

    Args:
        chunkId_chunkDoc_list: Sequence of mappings with a ``"chunk_id"`` key
            and a ``"chunk_doc"`` key whose value exposes ``page_content``.

    Returns:
        List of combined ``Document`` objects, in input order.

    Raises:
        ValueError: ``NUMBER_OF_CHUNKS_TO_COMBINE`` is unset, is not an
            integer, or is less than 1.
    """
    raw_group_size = os.environ.get("NUMBER_OF_CHUNKS_TO_COMBINE")
    if raw_group_size is None:
        raise ValueError(
            "NUMBER_OF_CHUNKS_TO_COMBINE environment variable is not set"
        )
    chunks_to_combine = int(raw_group_size)
    if chunks_to_combine < 1:
        # A negative step would make range() empty and silently drop every
        # chunk; zero is rejected by range() itself. Fail loudly for both.
        raise ValueError(
            "NUMBER_OF_CHUNKS_TO_COMBINE must be a positive integer, "
            f"got {chunks_to_combine}"
        )
    logging.info(f"Combining {chunks_to_combine} chunks before sending request to LLM")

    combined_chunk_document_list = []
    for start in range(0, len(chunkId_chunkDoc_list), chunks_to_combine):
        group = chunkId_chunkDoc_list[start : start + chunks_to_combine]
        combined_chunk_document_list.append(
            Document(
                page_content="".join(
                    entry["chunk_doc"].page_content for entry in group
                ),
                metadata={
                    "combined_chunk_ids": [entry["chunk_id"] for entry in group]
                },
            )
        )
    return combined_chunk_document_list
def _build_graph_transformer(llm, allowedNodes, allowedRelationship):
    """Wrap a chat model in an LLMGraphTransformer driven by the prompt file."""
    # ChatOllama is run with node properties disabled; every other model is
    # asked for a "description" property on each node.
    if "get_name" in dir(llm) and llm.get_name() == "ChatOllama":
        node_properties = False
    else:
        node_properties = ["description"]

    # The system prompt lives in a file; the path is resolved relative to the
    # process working directory.
    with open("prompt/build_graph_prompt.txt", encoding="utf-8") as prompt_file:
        graph_prompt = prompt_file.read()

    prompt_metal_am = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                graph_prompt,
            ),
            (
                "human",
                (
                    "Tip: Make sure to answer in the correct format and do "
                    "not include any explanations. "
                    "Use the given format to extract information from the "
                    "following input: {input}"
                ),
            ),
        ]
    )
    return LLMGraphTransformer(
        llm=llm,
        prompt=prompt_metal_am,
        node_properties=node_properties,
        allowed_nodes=allowedNodes,
        allowed_relationships=allowedRelationship,
    )


def get_graph_document_list(
    llm, combined_chunk_document_list, allowedNodes, allowedRelationship
):
    """Convert combined chunk documents into graph documents concurrently.

    If ``llm`` exposes a ``diffbot_api_key`` attribute it is used directly as
    the transformer; otherwise it is wrapped in an ``LLMGraphTransformer``
    whose system prompt is read from ``prompt/build_graph_prompt.txt``.
    Each chunk is submitted as its own single-document conversion to a pool
    of up to 10 worker threads.

    Args:
        llm: Chat model, or an object with ``diffbot_api_key`` and
            ``convert_to_graph_documents``.
        combined_chunk_document_list: Documents exposing ``page_content``
            and ``metadata``.
        allowedNodes: Node labels passed as ``allowed_nodes``.
        allowedRelationship: Relationship types passed as
            ``allowed_relationships``.

    Returns:
        One graph document per input chunk, in completion order (not
        necessarily input order). Any exception raised by a conversion is
        re-raised here via ``future.result()``.
    """
    if "diffbot_api_key" in dir(llm):
        llm_transformer = llm
    else:
        llm_transformer = _build_graph_transformer(
            llm, allowedNodes, allowedRelationship
        )

    futures = []
    graph_document_list = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        for chunk in combined_chunk_document_list:
            # NOTE(review): page_content is handed over as UTF-8 bytes, as the
            # original code did; confirm Document coerces bytes back to str.
            chunk_doc = Document(
                page_content=chunk.page_content.encode("utf-8"),
                metadata=chunk.metadata,
            )
            futures.append(
                executor.submit(llm_transformer.convert_to_graph_documents, [chunk_doc])
            )

        for future in concurrent.futures.as_completed(futures):
            # Each submission carried exactly one document, so only the first
            # element of the result is kept.
            graph_document = future.result()
            graph_document_list.append(graph_document[0])

    return graph_document_list
def _split_csv(value):
    """Split a comma-separated string into trimmed, non-empty items."""
    if not value:
        return []
    return [item.strip() for item in value.split(",") if item.strip()]


def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship):
    """Run the full chunk-to-graph pipeline for one model.

    Builds the model with ``get_llm``, merges chunks with
    ``get_combined_chunks``, and converts them with
    ``get_graph_document_list``.

    Args:
        model: Model identifier passed to ``get_llm``.
        chunkId_chunkDoc_list: Chunk records passed to
            ``get_combined_chunks``.
        allowedNodes: Comma-separated node labels, or ``None``/``""`` for no
            restriction.
        allowedRelationship: Comma-separated relationship types, or
            ``None``/``""`` for no restriction.

    Returns:
        The list of graph documents produced by ``get_graph_document_list``.
    """
    llm, _model_name = get_llm(model)
    combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list)

    # Items are trimmed and empties dropped so that input such as
    # "Person, Company," does not yield " Company" or "" as a label.
    allowedNodes = _split_csv(allowedNodes)
    allowedRelationship = _split_csv(allowedRelationship)

    graph_document_list = get_graph_document_list(
        llm, combined_chunk_document_list, allowedNodes, allowedRelationship
    )
    return graph_document_list