import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from types import SimpleNamespace

import pandas as pd
import tiktoken
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Base directory holding one GraphRAG index per model
BASE_DATA_DIR = "/app/graphrag-data/data"

# Build a config entry for every subdirectory found under BASE_DATA_DIR
DATA_CONFIGS = {}
data_dirs = [
    d
    for d in os.listdir(BASE_DATA_DIR)
    if os.path.isdir(os.path.join(BASE_DATA_DIR, d))
]
for dir_name in data_dirs:
    DATA_CONFIGS[dir_name] = {
        "input_dir": os.path.join(BASE_DATA_DIR, dir_name),
        "community_level": 2,  # default community level
    }

api_key = os.environ["api_key"]
llm_model = os.environ["llm_model"]
embedding_model = os.environ["embedding_model"]
api_base = os.environ["api_base"]

llm = ChatOpenAI(
    api_key=api_key,
    api_base=api_base,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
    max_retries=10,
)

token_encoder = tiktoken.get_encoding("cl100k_base")

text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=api_base,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    deployment_name=embedding_model,
    max_retries=7,
)


def load_data(input_dir, community_level):
    """Load one model's GraphRAG index artifacts from disk."""
    lancedb_uri = f"{input_dir}/lancedb"

    # Parquet table names produced by the GraphRAG indexer
    COMMUNITY_REPORT_TABLE = "create_final_community_reports"
    ENTITY_TABLE = "create_final_nodes"
    ENTITY_EMBEDDING_TABLE = "create_final_entities"
    RELATIONSHIP_TABLE = "create_final_relationships"
    TEXT_UNIT_TABLE = "create_final_text_units"

    # Read entities and their embeddings
    entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet")
    entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet")
    entities = read_indexer_entities(entity_df, entity_embedding_df, community_level)

    # Connect to the vector store holding entity description embeddings
    description_embedding_store = LanceDBVectorStore(
        collection_name="default-entity-description",
    )
    description_embedding_store.connect(db_uri=lancedb_uri)

    relationship_df = pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet")
    relationships = read_indexer_relationships(relationship_df)

    report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet")
    reports = read_indexer_reports(report_df, entity_df, community_level)

    text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet")
    text_units = read_indexer_text_units(text_unit_df)

    return entities, description_embedding_store, relationships, reports, text_units
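# For reference, load_data() assumes each model directory under BASE_DATA_DIR
# contains the standard GraphRAG local-search artifacts named above. A sketch
# of the assumed on-disk layout for a hypothetical model directory "ghost":
#
#   /app/graphrag-data/data/ghost/
#       create_final_community_reports.parquet
#       create_final_nodes.parquet
#       create_final_entities.parquet
#       create_final_relationships.parquet
#       create_final_text_units.parquet
#       lancedb/    <- LanceDB store with the "default-entity-description" collection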
# Cache of search engine instances, keyed by model name
search_engines = {}

app = FastAPI()


# LLM parameters are built per request so max_tokens/temperature can vary
def get_llm_params(max_tokens=2000, temperature=0.0):
    return {
        "max_tokens": max_tokens,
        "temperature": temperature,
    }


def create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params):
    return LocalSearch(
        llm=llm,
        context_builder=context_builder,
        token_encoder=token_encoder,
        llm_params=llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",
    )


def initialize_search_engine(model_name):
    if model_name not in DATA_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DATA_CONFIGS[model_name]
    entities, description_embedding_store, relationships, reports, text_units = load_data(
        config["input_dir"], config["community_level"]
    )

    context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=None,
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )

    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        # set this to EntityVectorStoreKey.TITLE if the vector store uses entity titles as ids
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
        # adjust to your model's token limit (e.g. ~5000 for a model with an 8k context window)
        "max_tokens": 36_000,
    }

    llm_params = get_llm_params()
    return create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params)


@app.post("/v1/completions")
async def completions(request: Request):
    body = await request.json()
    prompt = body.get("prompt", "hi")
    max_tokens = body.get("max_tokens", 2000)
    temperature = body.get("temperature", 0.0)
    model = body.get("model", "ghost")  # default model

    # Lazily initialize the search engine for this model
    if model not in search_engines:
        try:
            search_engines[model] = initialize_search_engine(model)
        except ValueError as e:
            return JSONResponse(content={"error": str(e)}, status_code=400)

    search_engine = search_engines[model]
    search_engine.llm_params = get_llm_params(max_tokens, temperature)  # update LLM params

    # Treat empty/greeting prompts as a request to list available models
    if prompt == "hi" or prompt == "":
        result_text = f"Model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = SimpleNamespace(response=result_text)
    else:
        result = await search_engine.asearch(prompt)

    # Token usage, counted with the cl100k_base encoder
    prompt_tokens = len(token_encoder.encode(prompt))
    completion_tokens = len(token_encoder.encode(result.response))
    total_tokens = prompt_tokens + completion_tokens

    response = {
        "id": f"cmpl-{str(uuid.uuid4())[:8]}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": model,
        "system_fingerprint": f"fp_{str(uuid.uuid4())[:8]}",
        "choices": [
            {
                "text": result.response,
                "index": 0,
                "logprobs": None,
                "finish_reason": "length" if completion_tokens >= max_tokens else "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }
    return JSONResponse(content=response)
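# A minimal smoke test for the completions endpoint, assuming the server runs
# on the port configured below (8080) and that a "ghost" index directory
# exists under BASE_DATA_DIR (both hypothetical for your deployment):
#
#   curl -X POST http://localhost:8080/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "ghost", "prompt": "What are the main themes?", "max_tokens": 500}'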
@app.post("/api/v1/chat/completions")
async def chat_completions(request: Request):
    body = await request.json()
    model = body.get("model", "ghost")  # default model
    messages = body.get("messages", [])
    temperature = body.get("temperature", 0.0)
    max_tokens = body.get("max_tokens", 2000)
    stream = body.get("stream", False)

    # Extract the user's prompt from the message list
    user_message = next((msg["content"] for msg in messages if msg["role"] == "user"), "")

    # Lazily initialize the search engine for this model
    if model not in search_engines:
        try:
            search_engines[model] = initialize_search_engine(model)
        except ValueError as e:
            return JSONResponse(content={"error": str(e)}, status_code=400)

    search_engine = search_engines[model]
    search_engine.llm_params = get_llm_params(max_tokens, temperature)  # update LLM params

    # Treat empty/greeting prompts as a request to list available models
    if user_message == "" or user_message == "hi":
        result_text = f"Model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = SimpleNamespace(response=result_text)
    else:
        result = await search_engine.asearch(user_message)

    if not stream:
        # Non-streaming: return the full response in one payload
        prompt_tokens = len(token_encoder.encode(user_message))
        completion_tokens = len(token_encoder.encode(result.response))
        total_tokens = prompt_tokens + completion_tokens
        completion_tokens_details = {
            "reasoning_tokens": 0,
            "accepted_prediction_tokens": 0,
            "rejected_prediction_tokens": 0,
        }
        response = {
            "id": f"chatcmpl-{str(uuid.uuid4())[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "completion_tokens_details": completion_tokens_details,
            },
            "choices": [
                {
                    "message": {"role": "assistant", "content": result.response},
                    "logprobs": None,
                    "finish_reason": "length" if completion_tokens >= max_tokens else "stop",
                    "index": 0,
                }
            ],
        }
        return JSONResponse(content=response)

    async def stream_response():
        chat_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
        system_fingerprint = f"fp_{str(uuid.uuid4())[:8]}"
        timestamp = int(time.time())

        # First chunk carries the assistant role
        first_chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": timestamp,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant"},
                    "logprobs": None,
                    "finish_reason": None,
                }
            ],
        }
        yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n"

        # Slice the response into chunks of ~50 characters each
        text = result.response
        chunk_size = 50
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        for chunk in chunks:
            data = {
                "id": chat_id,
                "object": "chat.completion.chunk",
                "created": timestamp,
                "model": model,
                "system_fingerprint": system_fingerprint,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": chunk},
                        "logprobs": None,
                        "finish_reason": None,
                    }
                ],
            }
            # ensure_ascii=False keeps non-ASCII text (e.g. Chinese) readable
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
            await asyncio.sleep(0.1)  # throttle output speed

        # Final chunk signals completion
        final_chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": timestamp,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_response(), media_type="text/event-stream")


@app.get("/")
async def root():
    return "Hello from Docker!"


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
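# Likewise, the chat endpoint mirrors the OpenAI chat-completions wire format
# and can be exercised with streaming enabled; note the non-standard /api/v1
# prefix used by this server (host/port hypothetical for your deployment):
#
#   curl -N -X POST http://localhost:8080/api/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "ghost", "stream": true,
#          "messages": [{"role": "user", "content": "Summarize the corpus."}]}'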