nameliu committed
Commit e1ec593 · verified · 1 Parent(s): f8cc3fc

Upload 4 files

Files changed (4)
  1. Dockerfile +25 -0
  2. search.py +227 -0
  3. server.py +378 -0
  4. start.sh +10 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
FROM python:3.10-slim

WORKDIR /app

# Install git and curl (needed to clone the data repository at container start)
RUN apt-get update && \
    apt-get install -y git curl && \
    rm -rf /var/lib/apt/lists/*

# Install git-lfs so the large files in the data repo can be pulled with "git lfs pull"
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
    apt-get update && \
    apt-get install -y git-lfs && \
    git lfs install && \
    rm -rf /var/lib/apt/lists/*

COPY . /app

RUN pip install graphrag==1.0.0 fastapi uvicorn

# Overwrite the local search module bundled with graphrag 1.0.0 with the patched search.py from this repo
COPY search.py /usr/local/lib/python3.10/site-packages/graphrag/query/structured_search/local_search/search.py

RUN chmod +x /app/start.sh

EXPOSE 8080

CMD ["/app/start.sh"]
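For local testing, the image can be built and run roughly as follows. This is a sketch: the graphrag-api tag, the model names, and the API base URL are placeholders, but the four environment variables are exactly the ones server.py reads at startup.

# Build the image from the repository root (the tag is arbitrary)
docker build -t graphrag-api .

# Run it, providing the variables server.py expects: api_key, llm_model, embedding_model, api_base
docker run -p 8080:8080 \
  -e api_key=sk-... \
  -e llm_model=gpt-4o-mini \
  -e embedding_model=text-embedding-3-small \
  -e api_base=https://api.openai.com/v1 \
  graphrag-api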
search.py ADDED
@@ -0,0 +1,227 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LocalSearch implementation."""

import logging
import time
from collections.abc import AsyncGenerator
from typing import Any

import tiktoken

from graphrag.prompts.query.local_search_system_prompt import (
    LOCAL_SEARCH_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
    ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult

DEFAULT_LLM_PARAMS = {
    "max_tokens": 1500,
    "temperature": 0.0,
}

log = logging.getLogger(__name__)


class LocalSearch(BaseSearch[LocalContextBuilder]):
    """Search orchestration for local search mode."""

    def __init__(
        self,
        llm: BaseLLM,
        context_builder: LocalContextBuilder,
        token_encoder: tiktoken.Encoding | None = None,
        system_prompt: str | None = None,
        response_type: str = "multiple paragraphs",
        callbacks: list[BaseLLMCallback] | None = None,
        llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
        context_builder_params: dict | None = None,
    ):
        super().__init__(
            llm=llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            llm_params=llm_params,
            context_builder_params=context_builder_params or {},
        )
        self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
        self.callbacks = callbacks
        self.response_type = response_type

    async def asearch(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> SearchResult:
        """Build local search context that fits a single context window and generate answer for the user query."""
        start_time = time.time()
        search_prompt = ""
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **kwargs,
            **self.context_builder_params,
        )
        llm_calls["build_context"] = context_result.llm_calls
        prompt_tokens["build_context"] = context_result.prompt_tokens
        output_tokens["build_context"] = context_result.output_tokens

        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        try:
            if "drift_query" in kwargs:
                drift_query = kwargs["drift_query"]
                search_prompt = self.system_prompt.format(
                    context_data=context_result.context_chunks,
                    response_type=self.response_type,
                    global_query=drift_query,
                )
            else:
                search_prompt = self.system_prompt.format(
                    context_data=context_result.context_chunks,
                    response_type=self.response_type,
                )
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
            ]

            response = await self.llm.agenerate(
                messages=search_messages,
                streaming=False,
                callbacks=self.callbacks,
                **self.llm_params,
            )
            llm_calls["response"] = 1
            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=sum(llm_calls.values()),
                prompt_tokens=sum(prompt_tokens.values()),
                output_tokens=sum(output_tokens.values()),
                llm_calls_categories=llm_calls,
                prompt_tokens_categories=prompt_tokens,
                output_tokens_categories=output_tokens,
            )

        except Exception:
            log.exception("Exception in _asearch")
            return SearchResult(
                response="",
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )

    async def astream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator:
        """Build local search context that fits a single context window and generate answer for the user query."""
        start_time = time.time()

        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )
        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        search_prompt = self.system_prompt.format(
            context_data=context_result.context_chunks, response_type=self.response_type
        )
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        # send context records first before sending the reduce response
        yield context_result.context_records
        async for response in self.llm.astream_generate(  # type: ignore
            messages=search_messages,
            callbacks=self.callbacks,
            **self.llm_params,
        ):
            yield response

    def search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> SearchResult:
        """Build local search context that fits a single context window and generate answer for the user question."""
        start_time = time.time()
        search_prompt = ""
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **kwargs,
            **self.context_builder_params,
        )
        llm_calls["build_context"] = context_result.llm_calls
        prompt_tokens["build_context"] = context_result.prompt_tokens
        output_tokens["build_context"] = context_result.output_tokens

        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
        try:
            search_prompt = self.system_prompt.format(
                context_data=context_result.context_chunks,
                response_type=self.response_type,
            )
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
            ]

            response = self.llm.generate(
                messages=search_messages,
                streaming=True,
                callbacks=self.callbacks,
                **self.llm_params,
            )
            llm_calls["response"] = 1
            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=sum(llm_calls.values()),
                prompt_tokens=sum(prompt_tokens.values()),
                output_tokens=sum(output_tokens.values()),
                llm_calls_categories=llm_calls,
                prompt_tokens_categories=prompt_tokens,
                output_tokens_categories=output_tokens,
            )

        except Exception:
            log.exception("Exception in _map_response_single_batch")
            return SearchResult(
                response="",
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )
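Because the Dockerfile copies this file over the version bundled with graphrag 1.0.0, a quick sanity check (a sketch, reusing the hypothetical graphrag-api image tag from the build example above) is to import the module inside the container and confirm it resolves to the patched path:

# Print where the LocalSearch module is loaded from inside the image
docker run --rm graphrag-api python -c "import graphrag.query.structured_search.local_search.search as s; print(s.__file__)"
# Expected: /usr/local/lib/python3.10/site-packages/graphrag/query/structured_search/local_search/search.py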
server.py ADDED
@@ -0,0 +1,378 @@
import os
import pandas as pd
import tiktoken

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore


# Configuration for the different datasets
DATA_CONFIGS = {
    "ghost": {
        "input_dir": "/app/graphrag-data/data/the_bit_player",
        "community_level": 2
    },
    "zhu_rongji": {
        "input_dir": "/app/graphrag-data/data/the_bit_player",
        "community_level": 2
    }
}

api_key = os.environ['api_key']
llm_model = os.environ['llm_model']
embedding_model = os.environ['embedding_model']
api_base = os.environ['api_base']

llm = ChatOpenAI(
    api_key=api_key,
    api_base=api_base,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
    max_retries=10,
)

token_encoder = tiktoken.get_encoding("cl100k_base")

text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=api_base,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    deployment_name=embedding_model,
    max_retries=7,
)

# Wrap the data-loading logic in a function
def load_data(input_dir, community_level):
    lancedb_uri = f"{input_dir}/lancedb"

    # Table names
    COMMUNITY_REPORT_TABLE = "create_final_community_reports"
    ENTITY_TABLE = "create_final_nodes"
    ENTITY_EMBEDDING_TABLE = "create_final_entities"
    RELATIONSHIP_TABLE = "create_final_relationships"
    TEXT_UNIT_TABLE = "create_final_text_units"

    # Read the indexed data
    entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet")
    entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet")
    entities = read_indexer_entities(entity_df, entity_embedding_df, community_level)

    # Create the vector store
    description_embedding_store = LanceDBVectorStore(
        collection_name="default-entity-description",
    )
    description_embedding_store.connect(db_uri=lancedb_uri)

    relationship_df = pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet")
    relationships = read_indexer_relationships(relationship_df)

    report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet")
    reports = read_indexer_reports(report_df, entity_df, community_level)

    text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet")
    text_units = read_indexer_text_units(text_unit_df)

    return entities, description_embedding_store, relationships, reports, text_units

# Cache dictionary that stores a search-engine instance per model
search_engines = {}

# Initialize a search engine for the given model name
def initialize_search_engine(model_name):
    if model_name not in DATA_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DATA_CONFIGS[model_name]
    # print(config)
    entities, description_embedding_store, relationships, reports, text_units = load_data(
        config["input_dir"],
        config["community_level"]
    )

    context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=None,
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )

    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,  # set this to EntityVectorStoreKey.TITLE if the vector store uses entity titles as ids
        "max_tokens": 36_000,  # change this based on the token limit of your model (for a model with an 8k limit, a good setting could be 5000)
    }

    llm_params = get_llm_params()
    return create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params)


from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
from datetime import datetime
import uuid
import time

app = FastAPI()

# Make llm_params dynamically configurable per request
def get_llm_params(max_tokens=2000, temperature=0.0):
    return {
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

def create_search_engine(llm, context_builder, token_encoder, llm_params, local_context_params):
    return LocalSearch(
        llm=llm,
        context_builder=context_builder,
        token_encoder=token_encoder,
        llm_params=llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",
    )


@app.post("/v1/completions")
async def completions(request: Request):
    body = await request.json()

    prompt = body.get("prompt", "hi")
    max_tokens = body.get("max_tokens", 2000)
    temperature = body.get("temperature", 0.0)
    model = body.get("model", "ghost")  # default to the "ghost" dataset

    # Check whether this model's search engine has already been initialized
    if model not in search_engines:
        try:
            search_engines[model] = initialize_search_engine(model)
        except ValueError as e:
            return JSONResponse(
                content={"error": str(e)},
                status_code=400
            )

    search_engine = search_engines[model]
    llm_params = get_llm_params(max_tokens, temperature)
    search_engine.llm_params = llm_params  # update the LLM parameters

    if prompt == "hi" or prompt == "":
        result_text = f"Current model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = type('obj', (), {'response': result_text})()
    else:
        result = await search_engine.asearch(prompt)

    # Estimate token usage (replace this with your actual token-counting method)
    prompt_tokens = len(prompt.split())  # rough example; a proper tokenizer should be used in practice
    completion_tokens = len(result.response.split())
    total_tokens = prompt_tokens + completion_tokens

    # Build an OpenAI-compatible completion response
    response = {
        "id": f"cmpl-{str(uuid.uuid4())[:8]}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": model,
        "system_fingerprint": f"fp_{str(uuid.uuid4())[:8]}",
        "choices": [
            {
                "text": result.response,
                "index": 0,
                "logprobs": None,
                "finish_reason": "length" if len(result.response.split()) >= max_tokens else "stop"
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens
        }
    }

    return JSONResponse(content=response)


from fastapi.responses import StreamingResponse
import json
import asyncio

@app.post("/api/v1/chat/completions")
async def chat_completions(request: Request):
    body = await request.json()

    # Extract parameters from the request body
    model = body.get("model", "ghost")  # default model
    messages = body.get("messages", [])
    temperature = body.get("temperature", 0.0)
    max_tokens = body.get("max_tokens", 2000)
    stream = body.get("stream", False)  # whether to stream the response

    # Extract the user's prompt from the messages
    user_message = next((msg["content"] for msg in messages if msg["role"] == "user"), "")

    # Check if the model exists in the initialized search engines
    if model not in search_engines:
        try:
            search_engines[model] = initialize_search_engine(model)
        except ValueError as e:
            return JSONResponse(
                content={"error": str(e)},
                status_code=400
            )

    # Look up the search engine and set the LLM parameters
    search_engine = search_engines[model]
    llm_params = get_llm_params(max_tokens, temperature)
    search_engine.llm_params = llm_params

    # Handle empty prompts by listing the available models
    if user_message == "" or user_message == "hi":
        result_text = f"Current model {model} is loaded. Available models: {', '.join(DATA_CONFIGS.keys())}"
        result = type('obj', (), {'response': result_text})()
    else:
        # Fetch the answer from the search engine
        result = await search_engine.asearch(user_message)

    if not stream:
        # Non-streaming: return the complete response at once
        # Token usage calculation
        prompt_tokens = len(user_message.split())
        completion_tokens = len(result.response.split())
        total_tokens = prompt_tokens + completion_tokens

        completion_tokens_details = {
            "reasoning_tokens": 0,
            "accepted_prediction_tokens": 0,
            "rejected_prediction_tokens": 0
        }

        response = {
            "id": f"chatcmpl-{str(uuid.uuid4())[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "completion_tokens_details": completion_tokens_details
            },
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": result.response
                    },
                    "logprobs": None,
                    "finish_reason": "length" if len(result.response.split()) >= max_tokens else "stop",
                    "index": 0
                }
            ]
        }
        return JSONResponse(content=response)

    # Streaming: replay the already-computed answer as server-sent events
    async def stream_response():
        chat_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
        system_fingerprint = f"fp_{str(uuid.uuid4())[:8]}"
        timestamp = int(time.time())

        # Send the initial role chunk
        first_chunk = {
            'id': chat_id,
            'object': 'chat.completion.chunk',
            'created': timestamp,
            'model': model,
            'system_fingerprint': system_fingerprint,
            'choices': [{
                'index': 0,
                'delta': {'role': 'assistant'},
                'logprobs': None,
                'finish_reason': None
            }]
        }
        yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n"

        # Split the text into chunks of about 50 characters each
        text = result.response
        chunk_size = 50
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        for chunk in chunks:
            data = {
                'id': chat_id,
                'object': 'chat.completion.chunk',
                'created': timestamp,
                'model': model,
                'system_fingerprint': system_fingerprint,
                'choices': [{
                    'index': 0,
                    'delta': {'content': chunk},
                    'logprobs': None,
                    'finish_reason': None
                }]
            }
            # ensure_ascii=False so that Chinese text is rendered correctly
            json_str = json.dumps(data, ensure_ascii=False)
            yield f"data: {json_str}\n\n"
            await asyncio.sleep(0.1)  # throttle the output rate

        # Send the final chunk
        final_chunk = {
            'id': chat_id,
            'object': 'chat.completion.chunk',
            'created': timestamp,
            'model': model,
            'system_fingerprint': system_fingerprint,
            'choices': [{
                'index': 0,
                'delta': {},
                'logprobs': None,
                'finish_reason': 'stop'
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield 'data: [DONE]\n\n'

    return StreamingResponse(
        stream_response(),
        media_type='text/event-stream'
    )

@app.get("/")
async def root():
    return "Hello from Docker!"

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
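The two POST routes mirror the OpenAI completions and chat-completions wire formats, so they can be exercised with plain curl once the container is running. A sketch, assuming the server is reachable on localhost:8080 and using the "ghost" dataset key from DATA_CONFIGS:

# Legacy-style completion request
curl -s http://localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "ghost", "prompt": "Who is the bit player?", "max_tokens": 500}'

# Chat-style request (non-streaming)
curl -s http://localhost:8080/api/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "ghost", "messages": [{"role": "user", "content": "Summarize the story."}]}'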
start.sh ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

# Fetch the prebuilt GraphRAG index data from the Hugging Face dataset repo
cd /app
git clone https://huggingface.co/datasets/nameliu/graphrag-data
cd graphrag-data
git checkout master
git lfs pull

# Start the API server
cd /app
python3 server.py
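When stream is set to true, the chat endpoint replays the already-computed answer as server-sent events in roughly 50-character chunks. A minimal sketch of consuming that stream (same placeholder host and port as above):

# -N disables curl's buffering so chunks are printed as they arrive
curl -N http://localhost:8080/api/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "ghost", "stream": true, "messages": [{"role": "user", "content": "Summarize the story."}]}'
# Each event is a line of the form "data: {...chat.completion.chunk...}", terminated by "data: [DONE]"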