import asyncio
import json
import os
import time
import uuid

import pandas as pd
import tiktoken
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
|
# Root directory containing one indexed GraphRAG dataset per subdirectory.
# Each subdirectory name doubles as the "model" name accepted by the API
# endpoints below.
BASE_DATA_DIR = "/app/graphrag-data/data"

DATA_CONFIGS = {}

data_dirs = [d for d in os.listdir(BASE_DATA_DIR) if os.path.isdir(os.path.join(BASE_DATA_DIR, d))]

for dir_name in data_dirs:
    DATA_CONFIGS[dir_name] = {
        "input_dir": os.path.join(BASE_DATA_DIR, dir_name),
        "community_level": 2,  # level of the Leiden community hierarchy to load
    }
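
# A minimal sketch of the on-disk layout this discovery loop expects; the
# dataset name "mydata" is hypothetical, and the file names match the tables
# read in load_data() below:
#
#   /app/graphrag-data/data/
#       mydata/
#           create_final_nodes.parquet
#           create_final_entities.parquet
#           create_final_community_reports.parquet
#           create_final_relationships.parquet
#           create_final_text_units.parquet
#           lancedb/
#
# This would register DATA_CONFIGS["mydata"] and expose "mydata" as a model
# name on the endpoints below.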
|
|
# Connection settings come from the environment; a missing variable fails
# fast with a KeyError at startup.
api_key = os.environ['api_key']
llm_model = os.environ['llm_model']
embedding_model = os.environ['embedding_model']
api_base = os.environ['api_base']
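
# Example launch, with placeholder values (note the lowercase variable names
# expected above); the model names and the "main.py" file name are
# illustrative only:
#
#   export api_key=sk-...
#   export llm_model=gpt-4o-mini
#   export embedding_model=text-embedding-3-small
#   export api_base=https://api.openai.com/v1
#   python main.py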
|
|
llm = ChatOpenAI(
    api_key=api_key,
    api_base=api_base,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,
    max_retries=10,
)
|
|
# cl100k_base is the encoding used by the GPT-3.5/GPT-4 family; it is used
# here for context-window budgeting inside GraphRAG, not for usage accounting.
token_encoder = tiktoken.get_encoding("cl100k_base")
|
|
text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=api_base,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    deployment_name=embedding_model,  # only relevant for Azure deployments
    max_retries=7,
)
|
|
def load_data(input_dir, community_level):
    """Load the parquet outputs of a GraphRAG index run from input_dir."""
    lancedb_uri = f"{input_dir}/lancedb"

    # Default table names produced by the GraphRAG indexing pipeline.
    COMMUNITY_REPORT_TABLE = "create_final_community_reports"
    ENTITY_TABLE = "create_final_nodes"
    ENTITY_EMBEDDING_TABLE = "create_final_entities"
    RELATIONSHIP_TABLE = "create_final_relationships"
    TEXT_UNIT_TABLE = "create_final_text_units"

    entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet")
    entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet")
    entities = read_indexer_entities(entity_df, entity_embedding_df, community_level)

    # Entity description embeddings live in the LanceDB collection created at
    # index time.
    description_embedding_store = LanceDBVectorStore(
        collection_name="default-entity-description",
    )
    description_embedding_store.connect(db_uri=lancedb_uri)

    relationship_df = pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet")
    relationships = read_indexer_relationships(relationship_df)

    report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet")
    reports = read_indexer_reports(report_df, entity_df, community_level)

    text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet")
    text_units = read_indexer_text_units(text_unit_df)

    return entities, description_embedding_store, relationships, reports, text_units
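
# load_data() can also be exercised standalone, e.g. from a REPL; the dataset
# path here is hypothetical:
#
#   entities, store, relationships, reports, text_units = load_data(
#       "/app/graphrag-data/data/mydata", community_level=2
#   )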
|
|
# Cache of initialized search engines, keyed by dataset/model name.
search_engines = {}
|
|
def initialize_search_engine(model_name):
    """Build a LocalSearch engine for the dataset registered as model_name."""
    if model_name not in DATA_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DATA_CONFIGS[model_name]

    entities, description_embedding_store, relationships, reports, text_units = load_data(
        config["input_dir"],
        config["community_level"],
    )

    context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=None,  # claim/covariate extraction is not used in this deployment
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )

    # Knobs controlling how the local-search context window is assembled:
    # proportions of text units vs. community reports, top-k graph lookups,
    # and the overall token budget.
    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
        "max_tokens": 36_000,
    }

    # Default generation parameters; the endpoints override these per request.
    llm_params = get_llm_params()
    return create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params)
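
# Engines are created lazily by the endpoints below, but a manual warm-up at
# startup is also possible (dataset name hypothetical):
#
#   search_engines["mydata"] = initialize_search_engine("mydata")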
|
|
app = FastAPI()
|
|
def get_llm_params(max_tokens=2000, temperature=0.0):
    """Build the generation-parameter dict passed through to the LLM."""
    return {
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
|
|
def create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params):
    return LocalSearch(
        llm=llm,
        context_builder=context_builder,
        token_encoder=token_encoder,
        llm_params=llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",  # free-form answer format hint
    )
|
|
@app.post("/v1/completions") |
|
|
async def completions(request: Request): |
|
|
body = await request.json() |
|
|
|
|
|
prompt = body.get("prompt", "hi") |
|
|
max_tokens = body.get("max_tokens", 2000) |
|
|
temperature = body.get("temperature", 0.0) |
|
|
model = body.get("model", "ghost") |
|
|
|
|
|
|
|
|
if model not in search_engines: |
|
|
try: |
|
|
search_engines[model] = initialize_search_engine(model) |
|
|
except ValueError as e: |
|
|
return JSONResponse( |
|
|
content={"error": str(e)}, |
|
|
status_code=400 |
|
|
) |
|
|
|
|
|
search_engine = search_engines[model] |
|
|
llm_params = get_llm_params(max_tokens, temperature) |
|
|
search_engine.llm_params = llm_params |
|
|
|
|
|
if prompt == "hi" or prompt == "": |
|
|
result_text = f"当前模型 {model} 已加载。可用模型: {', '.join(DATA_CONFIGS.keys())}" |
|
|
result = type('obj', (), {'response': result_text})() |
|
|
else: |
|
|
result = await search_engine.asearch(prompt) |
|
|
|
|
|
|
|
|
prompt_tokens = len(prompt.split()) |
|
|
completion_tokens = len(result.response.split()) |
|
|
total_tokens = prompt_tokens + completion_tokens |
|
|
|
|
|
|
|
|
response = { |
|
|
"id": f"cmpl-{str(uuid.uuid4())[:8]}", |
|
|
"object": "text_completion", |
|
|
"created": int(time.time()), |
|
|
"model": model, |
|
|
"system_fingerprint": f"fp_{str(uuid.uuid4())[:8]}", |
|
|
"choices": [ |
|
|
{ |
|
|
"text": result.response, |
|
|
"index": 0, |
|
|
"logprobs": None, |
|
|
"finish_reason": "length" if len(result.response.split()) >= max_tokens else "stop" |
|
|
} |
|
|
], |
|
|
"usage": { |
|
|
"prompt_tokens": prompt_tokens, |
|
|
"completion_tokens": completion_tokens, |
|
|
"total_tokens": total_tokens |
|
|
} |
|
|
} |
|
|
|
|
|
return JSONResponse(content=response) |
|
|
|
|
|
|
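
# Example request against the endpoint above, assuming the `requests` package
# and the uvicorn defaults at the bottom of this file; "mydata" is a
# hypothetical dataset name:
#
#   import requests
#   r = requests.post(
#       "http://localhost:8080/v1/completions",
#       json={"model": "mydata", "prompt": "What are the top themes?"},
#   )
#   print(r.json()["choices"][0]["text"])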
|
|
@app.post("/api/v1/chat/completions") |
|
|
async def chat_completions(request: Request): |
|
|
body = await request.json() |
|
|
|
|
|
|
|
|
model = body.get("model", "ghost") |
|
|
messages = body.get("messages", []) |
|
|
temperature = body.get("temperature", 0.0) |
|
|
max_tokens = body.get("max_tokens", 2000) |
|
|
stream = body.get("stream", False) |
|
|
|
|
|
|
|
|
user_message = next((msg["content"] for msg in messages if msg["role"] == "user"), "") |
|
|
|
|
|
|
|
|
    # Initialize the search engine for this dataset on first use and cache it.
    if model not in search_engines:
        try:
            search_engines[model] = initialize_search_engine(model)
        except ValueError as e:
            return JSONResponse(
                content={"error": str(e)},
                status_code=400,
            )

    search_engine = search_engines[model]
    # Apply the per-request generation parameters to the cached engine.
    search_engine.llm_params = get_llm_params(max_tokens, temperature)

    if user_message == "" or user_message == "hi":
        # Health-check style response listing the available datasets.
        result_text = f"Model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        # Ad-hoc object mimicking the .response attribute of a search result.
        result = type('obj', (), {'response': result_text})()
    else:
        result = await search_engine.asearch(user_message)

    if not stream:
        # Whitespace word counts approximate token usage, as in /v1/completions.
        prompt_tokens = len(user_message.split())
        completion_tokens = len(result.response.split())
        total_tokens = prompt_tokens + completion_tokens

        completion_tokens_details = {
            "reasoning_tokens": 0,
            "accepted_prediction_tokens": 0,
            "rejected_prediction_tokens": 0,
        }

        response = {
            "id": f"chatcmpl-{str(uuid.uuid4())[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "completion_tokens_details": completion_tokens_details,
            },
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": result.response,
                    },
                    "logprobs": None,
                    "finish_reason": "length" if completion_tokens >= max_tokens else "stop",
                    "index": 0,
                }
            ],
        }
        return JSONResponse(content=response)
|
    async def stream_response():
        # Emit OpenAI-style server-sent events: a role chunk first, then
        # content chunks, a final chunk carrying finish_reason, and the
        # [DONE] sentinel.
        chat_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
        system_fingerprint = f"fp_{str(uuid.uuid4())[:8]}"
        timestamp = int(time.time())

        first_chunk = {
            'id': chat_id,
            'object': 'chat.completion.chunk',
            'created': timestamp,
            'model': model,
            'system_fingerprint': system_fingerprint,
            'choices': [{
                'index': 0,
                'delta': {'role': 'assistant'},
                'logprobs': None,
                'finish_reason': None
            }]
        }
        yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n"

        # The full answer is already computed; slice it into fixed-size pieces
        # and pace them out to simulate incremental generation.
        text = result.response
        chunk_size = 50
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        for chunk in chunks:
            data = {
                'id': chat_id,
                'object': 'chat.completion.chunk',
                'created': timestamp,
                'model': model,
                'system_fingerprint': system_fingerprint,
                'choices': [{
                    'index': 0,
                    'delta': {'content': chunk},
                    'logprobs': None,
                    'finish_reason': None
                }]
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
            await asyncio.sleep(0.1)

        final_chunk = {
            'id': chat_id,
            'object': 'chat.completion.chunk',
            'created': timestamp,
            'model': model,
            'system_fingerprint': system_fingerprint,
            'choices': [{
                'index': 0,
                'delta': {},
                'logprobs': None,
                'finish_reason': 'stop'
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield 'data: [DONE]\n\n'

    return StreamingResponse(
        stream_response(),
        media_type='text/event-stream'
    )
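
# Example streaming client for the endpoint above, again assuming `requests`
# and a hypothetical "mydata" dataset. Each event arrives as a "data: {...}"
# line and the stream ends with "data: [DONE]":
#
#   import requests
#   with requests.post(
#       "http://localhost:8080/api/v1/chat/completions",
#       json={
#           "model": "mydata",
#           "messages": [{"role": "user", "content": "Summarize the corpus"}],
#           "stream": True,
#       },
#       stream=True,
#   ) as r:
#       for line in r.iter_lines(decode_unicode=True):
#           if line:
#               print(line)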
|
|
@app.get("/")
async def root():
    return "Hello from Docker!"


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)