fiewolf1000 committed on
Commit 91bd68b · verified · 1 Parent(s): 75911cc

Rename app.py to inference_node.py

Files changed (2)
  1. app.py +0 -160
  2. inference_node.py +89 -0
app.py DELETED
@@ -1,160 +0,0 @@
- from fastapi import FastAPI, HTTPException, Depends, Request, Header
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- import os
- import numpy as np
- from sentence_transformers import SentenceTransformer
- from typing import List, Optional, Union  # import Union
- import logging
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
-     handlers=[
-         logging.FileHandler("embedding_service.log"),
-         logging.StreamHandler()
-     ]
- )
- logger = logging.getLogger("embedding_service")
-
- app = FastAPI()
-
- # Allow cross-origin requests
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Model mapping: OpenAI model name → open-source model name
- MODEL_MAPPING = {
-     "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
-     "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
-     "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
-     "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
- }
-
- # Loaded models (lazy loading)
- models = {}
-
- def get_model(model_name: str):
-     logger.info(f"Attempting to get model: {model_name}")
-     # 1. Define all supported models (mapped names + directly supported model names)
-     supported_models = set(MODEL_MAPPING.keys())  # includes text-embedding-3-* and bge-*
-     model_to_load = MODEL_MAPPING.get(model_name, model_name)
-
-     # 2. Reject invalid models early: if not in the supported list and not under a known org prefix, return 400
-     known_prefixes = ("BAAI/", "sentence-transformers/")  # allow models from trusted orgs
-     if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
-         error_msg = f"Unsupported model: {model_name}"
-         logger.error(error_msg)
-         raise HTTPException(status_code=400, detail=error_msg)
-
-     # 3. Load supported models (including those under trusted org prefixes)
-     if model_name not in models:
-         try:
-             hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
-             models[model_name] = SentenceTransformer(
-                 model_to_load,
-                 use_auth_token=hf_token
-             )
-             logger.info(f"Model {model_name} loaded successfully")
-         except Exception as e:
-             # If a legitimate model fails to load (e.g. network issues), return 500; invalid models were rejected above
-             error_msg = f"Failed to load model {model_name}: {str(e)}"
-             logger.error(error_msg)
-             raise HTTPException(status_code=500, detail=error_msg)
-     return models[model_name]
-
-
-
- # Verify the API key
- def verify_api_key(authorization: Optional[str] = Header(None)):
-     logger.info(f"Authorization header: {authorization}")
-     if not authorization or not authorization.startswith("Bearer "):
-         raise HTTPException(status_code=401, detail="No valid API key provided")
-     api_key = authorization[len("Bearer "):]
-     if api_key != os.getenv("API_KEY"):
-         raise HTTPException(status_code=401, detail="Invalid API key")
-     logger.info("API key verified")
-     return True
-
- # Request body model
- class EmbeddingRequest(BaseModel):
-     input: Union[str, List[str]]  # accepts str or List[str]
-     model: str
-     encoding_format: Optional[str] = "float"
-
- # Response body models
- class EmbeddingData(BaseModel):
-     object: str = "embedding"
-     embedding: List[float]
-     index: int
-
- class EmbeddingResponse(BaseModel):
-     object: str = "list"
-     data: List[EmbeddingData]
-     model: str
-     usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
-
- @app.post("/v1/embeddings", response_model=EmbeddingResponse)
- async def create_embedding(
-     request: Request,
-     req: EmbeddingRequest,
-     _: bool = Depends(verify_api_key)
- ):
-     # Log the full request details
-     logger.info("\n===== Full incoming request =====")
-     logger.info(f"Method: {request.method}")
-     logger.info(f"URL: {request.url}")
-     logger.info("Headers:")
-     for name, value in request.headers.items():
-         logger.info(f"  {name}: {value}")
-     logger.info(f"Body: {await request.body()}")
-     logger.info("===============================\n")
-
-     # Embedding generation logic
-     logger.info(f"Embedding request received, model: {req.model}, input type: {type(req.input)}")
-     try:
-         model = get_model(req.model)
-         inputs = [req.input] if isinstance(req.input, str) else req.input
-         logger.info(f"Processing input, number of texts: {len(inputs)}")
-
-         logger.info("Computing embeddings")
-         embeddings = model.encode(inputs, normalize_embeddings=True)
-         logger.info(f"Embeddings computed, shape: {embeddings.shape}")
-
-         data = [
-             EmbeddingData(embedding=embedding.tolist(), index=i)
-             for i, embedding in enumerate(embeddings)
-         ]
-
-         prompt_tokens = sum(len(text.split()) for text in inputs)
-         logger.info(f"Estimated token count: {prompt_tokens}")
-
-         return EmbeddingResponse(
-             data=data,
-             model=req.model,
-             usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
-         )
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error while processing embedding request: {str(e)}")
-
- @app.get("/health")
- async def health_check(request: Request):
-     logger.info("\n===== Health check request =====")
-     logger.info(f"Method: {request.method}")
-     logger.info(f"URL: {request.url}")
-     logger.info("Headers:")
-     for name, value in request.headers.items():
-         logger.info(f"  {name}: {value}")
-     logger.info("===============================\n")
-     return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}
-
- if __name__ == "__main__":
-     import uvicorn
-     logger.info("Starting service")
-     uvicorn.run(app, host="0.0.0.0", port=7860)
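For reference, the deleted app.py exposed an OpenAI-compatible /v1/embeddings endpoint guarded by a Bearer token read from the service's API_KEY environment variable. A minimal client sketch (the base URL and key below are placeholders, not values from this commit):

```python
import requests

BASE_URL = "http://localhost:7860"  # hypothetical address of the old service
API_KEY = "your-api-key"            # must match the service's API_KEY env var

resp = requests.post(
    f"{BASE_URL}/v1/embeddings",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"input": ["hello world"], "model": "text-embedding-3-small"},
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # embedding dimensionality, e.g. 384 for bge-small-en-v1.5
```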
inference_node.py ADDED
@@ -0,0 +1,89 @@
+ from fastapi import FastAPI, HTTPException, Request
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ import json
+ import os
+ import logging
+ import torch
+ from transformers import (
+     AutoModelForCausalLM, AutoTokenizer,
+     BitsAndBytesConfig, TextStreamer
+ )
+
+ # 1. Basic configuration
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
+ logger = logging.getLogger("inference_node")
+ app = FastAPI(title="Inference node service (single model)")
+
+ # 2. Model configuration (each node loads exactly one model, chosen via environment variable)
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2-0.5B-Instruct")  # model specified at node startup
+ hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
+
+ # 3. 4-bit quantization (fits within 16 GB of memory)
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ # 4. Load the model (once at startup, single model)
+ logger.info(f"Loading model: {MODEL_NAME} (4-bit quantized)")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token, padding_side="right")
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     quantization_config=bnb_config,
+     device_map="auto",
+     token=hf_token
+ )
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ logger.info(f"Model loaded: {MODEL_NAME}")
+
+ # 5. Request model (format agreed with the controller)
+ class NodeInferenceRequest(BaseModel):
+     prompt: str  # full prompt assembled by the controller (including user context)
+     max_tokens: int = 1024
+
+ # 6. Streaming inference endpoint (inference only; no context is stored)
+ @app.post("/node/stream-infer")
+ async def stream_infer(req: NodeInferenceRequest, request: Request):
+     try:
+         # Model generation (the TextStreamer also echoes output to the server console)
+         inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+         outputs = model.generate(
+             **inputs,
+             streamer=streamer,
+             max_new_tokens=req.max_tokens,
+             do_sample=True,
+             temperature=0.7,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+         # Emit the result chunk by chunk (in the format agreed with the controller)
+         async def generate_chunks():  # async, so the disconnect check below can be awaited
+             for token in outputs[0][len(inputs["input_ids"][0]):]:
+                 # Stop if the client has disconnected (avoid useless work)
+                 if await request.is_disconnected():
+                     logger.info("Client disconnected, stopping generation")
+                     break
+                 token_text = tokenizer.decode(token, skip_special_tokens=True)
+                 # JSON format agreed with the controller (easy for it to pass through)
+                 yield json.dumps({"chunk": token_text, "finish": False}, ensure_ascii=False) + "\n"
+             # End-of-stream marker
+             yield json.dumps({"chunk": "", "finish": True}) + "\n"
+
+         return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
+
+     except Exception as e:
+         logger.error(f"Inference failed: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Node inference failed: {str(e)}")
+
+ # 7. Health check endpoint (used by the controller for node status detection)
+ @app.get("/node/health")
+ async def node_health():
+     return {"status": "healthy", "model": MODEL_NAME}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
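The new node returns newline-delimited JSON, so a controller can consume the stream line by line. A minimal sketch of such a consumer, assuming the node runs at a placeholder address:

```python
import json
import requests

NODE_URL = "http://localhost:7860"  # hypothetical node address

with requests.post(
    f"{NODE_URL}/node/stream-infer",
    json={"prompt": "Hello, who are you?", "max_tokens": 128},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # {"chunk": "...", "finish": bool}
        if chunk["finish"]:
            break
        print(chunk["chunk"], end="", flush=True)
```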