import asyncio
import json
import os
import time
import uuid
from types import SimpleNamespace

import pandas as pd
import tiktoken
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
# Base directory containing one subdirectory per indexed dataset
BASE_DATA_DIR = "/app/graphrag-data/data"
# Map each dataset subdirectory to its search configuration
DATA_CONFIGS = {}
# Collect all subdirectories under the base data directory
data_dirs = [d for d in os.listdir(BASE_DATA_DIR) if os.path.isdir(os.path.join(BASE_DATA_DIR, d))]
# Create a config entry for each subdirectory
for dir_name in data_dirs:
    DATA_CONFIGS[dir_name] = {
        "input_dir": os.path.join(BASE_DATA_DIR, dir_name),
        "community_level": 2,  # default community level for local search
    }
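# For illustration: assuming a hypothetical dataset directory named "ghost",
# the mapping built above would contain:
#   DATA_CONFIGS == {"ghost": {"input_dir": "/app/graphrag-data/data/ghost",
#                              "community_level": 2}}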
# API credentials and model names are supplied via environment variables
api_key = os.environ['api_key']
llm_model = os.environ['llm_model']
embedding_model = os.environ['embedding_model']
api_base = os.environ['api_base']
llm = ChatOpenAI(
api_key=api_key,
api_base=api_base,
model=llm_model,
api_type=OpenaiApiType.OpenAI, # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
max_retries=10,
)
token_encoder = tiktoken.get_encoding("cl100k_base")
text_embedder = OpenAIEmbedding(
api_key=api_key,
api_base=api_base,
api_type=OpenaiApiType.OpenAI,
model=embedding_model,
deployment_name=embedding_model,
max_retries=7,
)
# Load all GraphRAG index artifacts for one dataset directory
def load_data(input_dir, community_level):
lancedb_uri = f"{input_dir}/lancedb"
    # Parquet tables produced by the GraphRAG indexing pipeline
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
TEXT_UNIT_TABLE = "create_final_text_units"
    # Read entity nodes together with their precomputed embeddings
entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet")
entities = read_indexer_entities(entity_df, entity_embedding_df, community_level)
    # Connect to the LanceDB store holding entity description embeddings
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=lancedb_uri)
relationship_df = pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, community_level)
text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
return entities, description_embedding_store, relationships, reports, text_units
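# Each dataset directory is expected to hold the standard GraphRAG index
# output referenced above: the five parquet files plus a lancedb/ folder
# containing the entity description embeddings, e.g.
#   /app/graphrag-data/data/<dataset>/create_final_nodes.parquet
#   /app/graphrag-data/data/<dataset>/lancedb/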
# Cache of search-engine instances, keyed by model (dataset) name
search_engines = {}
# Build and return a LocalSearch engine for the given model name
def initialize_search_engine(model_name):
if model_name not in DATA_CONFIGS:
raise ValueError(f"Unknown model: {model_name}")
config = DATA_CONFIGS[model_name]
entities, description_embedding_store, relationships, reports, text_units = load_data(
config["input_dir"],
config["community_level"]
)
context_builder = LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
entities=entities,
relationships=relationships,
covariates=None,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID,
text_embedder=text_embedder,
token_encoder=token_encoder,
)
local_context_params = {
"text_unit_prop": 0.5,
"community_prop": 0.1,
"conversation_history_max_turns": 5,
"conversation_history_user_turns_only": True,
"top_k_mapped_entities": 10,
"top_k_relationships": 10,
"include_entity_rank": True,
"include_relationship_weight": True,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
"max_tokens": 36_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
}
llm_params = get_llm_params()
return create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params)
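# For example, initialize_search_engine("ghost") would load the index under
# /app/graphrag-data/data/ghost (assuming that dataset directory exists) and
# return a LocalSearch instance ready to answer queries against it.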
app = FastAPI()
# Build LLM call parameters from per-request settings
def get_llm_params(max_tokens=2000, temperature=0.0):
return {
"max_tokens": max_tokens,
"temperature": temperature,
}
def create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params):
return LocalSearch(
llm=llm,
context_builder=context_builder,
token_encoder=token_encoder,
llm_params=llm_params,
context_builder_params=local_context_params,
response_type="multiple paragraphs",
)
@app.post("/v1/completions")
async def completions(request: Request):
body = await request.json()
prompt = body.get("prompt", "hi")
max_tokens = body.get("max_tokens", 2000)
temperature = body.get("temperature", 0.0)
    model = body.get("model", "ghost")  # default model name
    # Lazily initialize the search engine for this model on first request
if model not in search_engines:
try:
search_engines[model] = initialize_search_engine(model)
except ValueError as e:
return JSONResponse(
content={"error": str(e)},
status_code=400
)
search_engine = search_engines[model]
llm_params = get_llm_params(max_tokens, temperature)
    search_engine.llm_params = llm_params  # apply per-request LLM parameters
    if prompt == "hi" or prompt == "":
        result_text = f"Model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = SimpleNamespace(response=result_text)
else:
result = await search_engine.asearch(prompt)
    # Estimate token usage with the shared tiktoken encoder
    prompt_tokens = len(token_encoder.encode(prompt))
    completion_tokens = len(token_encoder.encode(result.response))
    total_tokens = prompt_tokens + completion_tokens
    # Build an OpenAI-compatible text completion response
response = {
"id": f"cmpl-{str(uuid.uuid4())[:8]}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"system_fingerprint": f"fp_{str(uuid.uuid4())[:8]}",
"choices": [
{
"text": result.response,
"index": 0,
"logprobs": None,
"finish_reason": "length" if len(result.response.split()) >= max_tokens else "stop"
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
}
return JSONResponse(content=response)
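# Example non-streaming request (hypothetical prompt; assumes the server is
# running locally on port 8080 as configured below):
#   curl -X POST http://localhost:8080/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "ghost", "prompt": "What are the main themes?", "max_tokens": 500}'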
@app.post("/api/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
# Extracting parameters from request body
model = body.get("model", "ghost") # Default model
messages = body.get("messages", [])
temperature = body.get("temperature", 0.0)
max_tokens = body.get("max_tokens", 2000)
    stream = body.get("stream", False)  # whether to stream the response as SSE chunks
# Extracting user's prompt from messages
user_message = next((msg["content"] for msg in messages if msg["role"] == "user"), "")
# Check if the model exists in initialized search engines
if model not in search_engines:
try:
search_engines[model] = initialize_search_engine(model)
except ValueError as e:
return JSONResponse(
content={"error": str(e)},
status_code=400
)
# Initialize search engine and LLM parameters
search_engine = search_engines[model]
llm_params = get_llm_params(max_tokens, temperature)
search_engine.llm_params = llm_params
# Handle 'empty' prompts to list available models
    if user_message == "" or user_message == "hi":
        result_text = f"Model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = SimpleNamespace(response=result_text)
else:
# Fetch completions from search engine
result = await search_engine.asearch(user_message)
    if not stream:
        # Non-streaming: return the complete response in a single payload
        # Estimate token usage with the shared tiktoken encoder
        prompt_tokens = len(token_encoder.encode(user_message))
        completion_tokens = len(token_encoder.encode(result.response))
        total_tokens = prompt_tokens + completion_tokens
completion_tokens_details = {
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
response = {
"id": f"chatcmpl-{str(uuid.uuid4())[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"completion_tokens_details": completion_tokens_details
},
"choices": [
{
"message": {
"role": "assistant",
"content": result.response
},
"logprobs": None,
"finish_reason": "length" if len(result.response.split()) >= max_tokens else "stop",
"index": 0
}
]
}
return JSONResponse(content=response)
async def stream_response():
chat_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
system_fingerprint = f"fp_{str(uuid.uuid4())[:8]}"
timestamp = int(time.time())
        # First chunk carries only the assistant role
first_chunk = {
'id': chat_id,
'object': 'chat.completion.chunk',
'created': timestamp,
'model': model,
'system_fingerprint': system_fingerprint,
'choices': [{
'index': 0,
'delta': {'role': 'assistant'},
'logprobs': None,
'finish_reason': None
}]
}
yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n"
        # Split the response text into fixed-size chunks (50 characters each)
text = result.response
chunk_size = 50
chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
for chunk in chunks:
data = {
'id': chat_id,
'object': 'chat.completion.chunk',
'created': timestamp,
'model': model,
'system_fingerprint': system_fingerprint,
'choices': [{
'index': 0,
'delta': {'content': chunk},
'logprobs': None,
'finish_reason': None
}]
}
            # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable in the stream
json_str = json.dumps(data, ensure_ascii=False)
yield f"data: {json_str}\n\n"
            await asyncio.sleep(0.1)  # throttle the streaming rate
        # Final chunk signals completion
final_chunk = {
'id': chat_id,
'object': 'chat.completion.chunk',
'created': timestamp,
'model': model,
'system_fingerprint': system_fingerprint,
'choices': [{
'index': 0,
'delta': {},
'logprobs': None,
'finish_reason': 'stop'
}]
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
yield 'data: [DONE]\n\n'
return StreamingResponse(
stream_response(),
media_type='text/event-stream'
)
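# Example streaming request (hypothetical prompt; curl's -N flag disables
# output buffering so the SSE chunks print as they arrive):
#   curl -N -X POST http://localhost:8080/api/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "ghost", "messages": [{"role": "user", "content": "Summarize the story"}], "stream": true}'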
@app.get("/")
async def root():
return "Hello from Docker!"
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8080)