import os
import secrets
import time
from typing import Dict, List, Literal, Optional, Union

import numpy as np
import tiktoken
import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
from scipy.interpolate import interp1d
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import PolynomialFeatures
| |
|
| |
|
| | |
# API key checked by every endpoint; override via the `sk-key` environment
# variable. NOTE(review): the hard-coded fallback means anyone who reads the
# source can authenticate — consider failing fast when the env var is unset.
sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')

app = FastAPI()

# Allow any origin/method/header so browser clients can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print('本次加载模型的设备为GPU: ', torch.cuda.get_device_name(0))
else:
    print('本次加载模型的设备为CPU.')

# Load the embedding model once at import time; shared by every request.
model = SentenceTransformer('Qwen/Qwen3-Embedding-4B', device=device)

# Bearer-token scheme: clients send "Authorization: Bearer <sk_key>".
security = HTTPBearer()
| |
|
| |
|
| |
|
class ChatMessage(BaseModel):
    """A single message in an OpenAI-style chat conversation."""
    role: Literal["user", "assistant", "system"]  # who produced the message
    content: str  # message text
| |
|
| |
|
class DeltaMessage(BaseModel):
    """Incremental message fragment used in streaming chat responses."""
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
| |
|
class ChatCompletionRequest(BaseModel):
    """OpenAI-style chat completion request body."""
    model: str
    messages: List[ChatMessage]
    # The sampling parameters below are accepted for API compatibility but
    # are not used by this service (it only serves embeddings).
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
| |
|
| |
|
class ChatCompletionResponseChoice(BaseModel):
    """One non-streaming completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]
| |
|
| |
|
class ChatCompletionResponseStreamChoice(BaseModel):
    """One streaming completion chunk choice."""
    index: int
    delta: DeltaMessage
    # Default added: an Optional annotation alone does not make the field
    # optional in pydantic — without `= None` the key is required at
    # construction time.
    finish_reason: Optional[Literal["stop", "length"]] = None
| |
|
class ChatCompletionResponse(BaseModel):
    """Top-level chat completion response envelope."""
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))  # unix timestamp
| |
|
class EmbeddingRequest(BaseModel):
    """OpenAI-style embeddings request: a batch of texts plus a model name."""
    input: List[str]
    model: str
| |
|
class EmbeddingResponse(BaseModel):
    """OpenAI-style embeddings response envelope."""
    data: list   # list of {"embedding", "index", "object"} entries
    model: str
    object: str  # always "list" in this service
    usage: dict  # {"prompt_tokens", "total_tokens"}
| |
|
def num_tokens_from_string(string: str) -> int:
    """Count the tokens in *string* with the cl100k_base tokenizer."""
    tokenizer = tiktoken.get_encoding('cl100k_base')
    return len(tokenizer.encode(string))
| |
|
| | |
def interpolate_vector(vector, target_length):
    """Linearly resample *vector* to exactly *target_length* points.

    Endpoints are preserved; intermediate values come from linear
    interpolation over the original index positions.
    """
    src_positions = np.arange(len(vector))
    dst_positions = np.linspace(0, len(vector) - 1, target_length)
    resample = interp1d(src_positions, vector, kind='linear')
    return resample(dst_positions)
| |
|
def expand_features(embedding, target_length):
    """Expand *embedding* with degree-2 polynomial features, then force the
    result to *target_length* by truncating or zero-padding on the right.
    """
    transformer = PolynomialFeatures(degree=2)
    expanded = transformer.fit_transform(embedding.reshape(1, -1)).ravel()
    if len(expanded) > target_length:
        # Too many features: keep only the leading target_length entries.
        return expanded[:target_length]
    if len(expanded) < target_length:
        # Too few: pad with zeros up to the target size.
        return np.pad(expanded, (0, target_length - len(expanded)))
    return expanded
| |
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """OpenAI-compatible chat endpoint.

    This service only serves embeddings, so every chat request receives a
    fixed assistant reply saying the model cannot chat.

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``.
    """
    # Constant-time comparison so the key cannot be probed via timing.
    if not secrets.compare_digest(credentials.credentials, sk_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content='你说得对,但这个是向量模型不能对话'),
        finish_reason="stop",
    )
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
| |
|
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(http_request: Request, request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return OpenAI-compatible embeddings for each text in ``request.input``.

    Pipeline: authenticate -> encode each text -> pad/expand short vectors to
    1536 dims -> L2-normalize -> wrap in the OpenAI response schema.

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``.
    """
    # Basic per-request logging of the caller.
    client_host = http_request.client.host
    headers = http_request.headers
    print(f"Client IP: {client_host}")
    print(f"Request headers: {headers}")

    # Constant-time comparison so the key cannot be probed via timing.
    if not secrets.compare_digest(credentials.credentials, sk_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    embeddings = [model.encode(text) for text in request.input]

    # Force every vector to 1536 dims (the OpenAI ada-002 size): shorter
    # vectors are expanded with polynomial features and zero-padded.
    # NOTE(review): vectors already >= 1536 dims pass through untouched —
    # confirm the model's native dimension never exceeds 1536.
    embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]

    # L2-normalize so dot products equal cosine similarity.
    # NOTE(review): a zero vector would divide by zero here — presumably the
    # encoder never emits one; verify if arbitrary input is allowed.
    embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
    embeddings = [embedding.tolist() for embedding in embeddings]

    # NOTE(review): prompt_tokens counts whitespace-separated words while
    # total_tokens uses tiktoken, so the two fields use different units.
    # Kept as-is for response compatibility.
    prompt_tokens = sum(len(text.split()) for text in request.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding"
            } for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        }
    }

    return response
| |
|
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with a single worker.
    # NOTE(review): the import string assumes this file is named
    # localembedding.py — confirm the filename, or pass `app` directly
    # since workers=1 does not require an import string.
    uvicorn.run("localembedding:app", host='0.0.0.0', port=7860, workers=1)
| |
|