Spaces:
Paused
Paused
| import os | |
| # 设置缓存目录,避免 /.cache 权限问题 | |
| os.environ["HF_HOME"] = "/tmp" | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp" | |
| os.environ["HF_HUB_CACHE"] = "/tmp" | |
| import time | |
| import uuid | |
| from typing import List, Optional, Union, Dict, Any | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig | |
| import json | |
# Create the FastAPI application object that every handler is attached to.
app = FastAPI(title="Qwen Coder API", version="1.0.0")

# Allow cross-origin calls from any site. Credentials are switched off because
# the CORS rules do not permit combining a wildcard origin with credentialed
# requests, and nothing in this service reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide model state; all three are assigned together by load_model().
model = None
tokenizer = None
model_name = None
| # Pydantic模型定义 | |
class Message(BaseModel):
    """One chat turn: who is speaking and the plain-text content."""

    # Speaker label, e.g. "system", "user" or "assistant".
    role: str
    # Text of the turn.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completion handler."""

    # Model identifier supplied by the client; echoed back in the response.
    model: str
    # Conversation history, oldest turn first.
    messages: List[Message]
    # Sampling temperature forwarded to generation.
    temperature: Optional[float] = 0.7
    # Requested upper bound on newly generated tokens.
    max_tokens: Optional[int] = 2048
    # Accepted for client compatibility; the handler in this file never reads it.
    stream: Optional[bool] = False
    # Nucleus-sampling threshold forwarded to generation.
    top_p: Optional[float] = 0.9
class ChatCompletionChoice(BaseModel):
    """A single generated alternative inside a chat-completion response."""

    # Position of this choice in the response's ``choices`` list.
    index: int
    # The generated assistant turn.
    message: Message
    # Why generation ended (the handler in this file always reports "stop").
    finish_reason: str
class Usage(BaseModel):
    """Token accounting attached to a chat-completion response."""

    # Tokens counted for the rendered prompt.
    prompt_tokens: int
    # Tokens counted for the generated text.
    completion_tokens: int
    # Sum of the two counts above.
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """Response body returned by the chat-completion handler."""

    # Identifier generated per response ("chatcmpl-" plus a short hex suffix).
    id: str
    # Fixed type tag for this payload.
    object: str = "chat.completion"
    # Creation time as integer seconds since the epoch.
    created: int
    # Model name, copied from the request.
    model: str
    # Generated alternatives; the handler in this file always emits exactly one.
    choices: List[ChatCompletionChoice]
    # Token counts for prompt and completion.
    usage: Usage
class Model(BaseModel):
    """One entry of the model listing."""

    # Short model identifier shown to clients.
    id: str
    # Fixed type tag for this payload.
    object: str = "model"
    # Timestamp as integer seconds since the epoch.
    created: int
    # Owner label reported for the model.
    owned_by: str = "qwen"
class ModelListResponse(BaseModel):
    """Envelope for the model listing."""

    # Fixed type tag for this payload.
    object: str = "list"
    # The listed models.
    data: List[Model]
def load_model():
    """Load the first Qwen2.5-Coder checkpoint that initialises successfully.

    Candidates are tried in order from the largest to the smallest model. For
    each candidate the tokenizer is loaded before the weights, so a broken
    tokenizer download fails fast. The module-level ``model``, ``tokenizer``
    and ``model_name`` are only assigned once both parts of a candidate have
    loaded, so they never describe a half-loaded model.

    Raises:
        RuntimeError: If no candidate could be loaded. The message lists the
            failure reason recorded for each candidate.
    """
    global model, tokenizer, model_name

    # Preference order: largest model first, smaller ones as fallbacks.
    model_candidates = [
        "Qwen/Qwen2.5-Coder-7B-Instruct",
        "Qwen/Qwen2.5-Coder-3B-Instruct",
        "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    ]
    failures = []

    for candidate_model in model_candidates:
        try:
            print(f"Attempting to load model: {candidate_model}")

            print("Loading tokenizer...")
            # NOTE(review): trust_remote_code=True lets the hub repository run
            # its own Python code at load time; confirm it is actually needed.
            test_tokenizer = AutoTokenizer.from_pretrained(
                candidate_model,
                trust_remote_code=True,
                use_fast=False,
                revision="main",
            )

            print("Loading model...")
            test_model = AutoModelForCausalLM.from_pretrained(
                candidate_model,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                revision="main",
            )
        except Exception as e:
            # Any failure (download, memory, configuration) just moves on to
            # the next, smaller candidate; the reason is kept for the final
            # error message.
            print(f"Failed to load {candidate_model}: {str(e)}")
            failures.append(f"{candidate_model}: {e}")
            continue

        # Publish tokenizer and weights together.
        tokenizer = test_tokenizer
        model = test_model
        model_name = candidate_model
        print(f"Successfully loaded model: {candidate_model}")
        return

    raise RuntimeError(
        "Failed to load any Qwen model. Please check your configuration. "
        "Errors: " + "; ".join(failures)
    )
def format_messages_simple(messages: List["Message"]) -> str:
    """Render a conversation as a plain ``Role: text`` transcript.

    This is the fallback prompt format used when the tokenizer's chat
    template cannot be applied.

    Args:
        messages: Conversation turns; each item exposes ``role`` and
            ``content`` attributes.

    Returns:
        One ``"<Label>: <content>"`` paragraph per message, each followed by a
        blank line, ending with a trailing ``"Assistant: "`` cue for the model
        to continue from. An empty conversation yields just that cue.
    """
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}

    parts = []
    for msg in messages:
        # Roles outside the three known ones used to be dropped silently,
        # which removed their content from the prompt. Keep them, labelled
        # with the capitalised role name.
        label = role_labels.get(msg.role) or str(msg.role).capitalize()
        parts.append(f"{label}: {msg.content}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)
def format_messages(messages: List[Message]) -> str:
    """Turn the conversation into the prompt string fed to the model.

    The tokenizer's own chat template is preferred. If the tokenizer has no
    ``apply_chat_template`` attribute, or anything raises while building the
    templated prompt, the plain transcript from ``format_messages_simple`` is
    returned instead.
    """
    try:
        chat = [{"role": turn.role, "content": turn.content} for turn in messages]

        if not hasattr(tokenizer, 'apply_chat_template'):
            # No chat template support: use the plain transcript.
            return format_messages_simple(messages)

        # add_generation_prompt=True asks the template to end with the cue
        # for the assistant's next turn.
        return tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Error in format_messages, using simple format: {str(e)}")
        return format_messages_simple(messages)
def generate_response(prompt: str, temperature: float, max_tokens: int, top_p: float) -> str:
    """Generate a completion for ``prompt`` with the loaded model.

    Args:
        prompt: Fully rendered prompt text. It is truncated to 4096 tokens
            when tokenized.
        temperature: Sampling temperature. ``None`` falls back to 0.7; a
            value of 0 or below selects greedy decoding instead of sampling.
        max_tokens: Requested number of new tokens. ``None`` falls back to
            2048; the value is clamped to the range 1..2048.
        top_p: Nucleus-sampling threshold, only used when sampling. ``None``
            falls back to 0.9.

    Returns:
        The newly generated text with special tokens removed and surrounding
        whitespace stripped. If anything fails, an error message string is
        returned instead of raising (existing best-effort contract).
    """
    try:
        # The request schemas declare these fields Optional, so an explicit
        # null can reach this function; fall back to the schema defaults.
        if temperature is None:
            temperature = 0.7
        if top_p is None:
            top_p = 0.9
        if max_tokens is None:
            max_tokens = 2048
        max_new_tokens = max(1, min(int(max_tokens), 2048))

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)

        # Move the encoded prompt to the device the model reports.
        if hasattr(model, 'device'):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "repetition_penalty": 1.1,
        }
        if temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # Sampling requires a strictly positive temperature; a request for
            # temperature 0 means deterministic output, i.e. greedy decoding.
            generation_kwargs["do_sample"] = False

        generation_config = GenerationConfig(**generation_kwargs)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config
            )

        # The output sequence starts with the prompt tokens; decode only what
        # was generated after them.
        prompt_length = inputs['input_ids'].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )
        return response.strip()
    except Exception as e:
        print(f"Error in generate_response: {str(e)}")
        return f"抱歉,生成响应时出现错误: {str(e)}"
@app.on_event("startup")
async def startup_event():
    """Load the model when the application starts.

    Registered as a startup hook; without the registration this coroutine is
    never invoked and the model would only be loaded lazily by the first
    request.
    """
    try:
        load_model()
    except Exception as e:
        # Deliberately do not abort startup: the request handlers retry the
        # load on demand and report an error if it still fails.
        print(f"Failed to load model during startup: {str(e)}")
@app.get("/")
async def root():
    """Landing endpoint: report that the server is up and which model is loaded."""
    return {
        "message": "Qwen Coder API Server is running!",
        "model_loaded": model is not None,
        "current_model": model_name,
    }
# NOTE(review): route path follows the OpenAI-style layout the schemas in this
# file imitate -- confirm against the clients in use.
@app.get("/v1/models")
async def list_models():
    """List the model served by this process.

    Returns a single-entry listing. The id is the lower-cased last path
    component of the loaded checkpoint name, or a fixed 7B placeholder when
    no model has been loaded yet.
    """
    if model_name is None:
        model_id = "qwen2.5-coder-7b-instruct"
    else:
        model_id = model_name.split("/")[-1].lower()

    entry = Model(
        id=model_id,
        created=int(time.time()),
        owned_by="qwen",
    )
    return ModelListResponse(data=[entry])
# NOTE(review): route path follows the OpenAI-style layout the schemas in this
# file imitate -- confirm against the clients in use.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat-completion request.

    Loads the model on demand if it is not loaded yet, renders the messages
    into a prompt, generates a reply and wraps it in a
    ``ChatCompletionResponse`` with token usage. ``request.stream`` is not
    consulted; the reply is always returned in one piece.

    Raises:
        HTTPException: 503 when the model is not loaded and the on-demand load
            fails; 500 for any other failure while handling the request.
    """
    try:
        if model is None or tokenizer is None:
            # Startup loading may have failed; try once more before giving up.
            try:
                load_model()
            except Exception as e:
                raise HTTPException(
                    status_code=503,
                    detail="Model not loaded and failed to load on demand",
                ) from e

        prompt = format_messages(request.messages)

        response_text = generate_response(
            prompt,
            request.temperature,
            request.max_tokens,
            request.top_p
        )

        completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

        # Token usage (simplified): re-encode both texts and count tokens.
        try:
            prompt_tokens = len(tokenizer.encode(prompt))
            completion_tokens = len(tokenizer.encode(response_text))
        except Exception:
            # Rough estimate when encoding fails: two tokens per
            # whitespace-separated word.
            prompt_tokens = len(prompt.split()) * 2
            completion_tokens = len(response_text.split()) * 2

        return ChatCompletionResponse(
            id=completion_id,
            created=int(time.time()),
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=Message(role="assistant", content=response_text),
                    finish_reason="stop"
                )
            ],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens
            )
        )
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the 503
        # above would be caught below and rewritten into a 500.
        raise
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# NOTE(review): the "/health" path is an assumption; adjust to whatever the
# deployment's probe expects.
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Reports "healthy" only when both the model and the tokenizer are loaded,
    together with the loaded model name and basic torch/CUDA information.
    """
    ready = model is not None and tokenizer is not None
    cuda_ok = torch.cuda.is_available()
    return {
        "status": "healthy" if ready else "unhealthy",
        "model_loaded": ready,
        "current_model": model_name,
        "torch_version": torch.__version__,
        "cuda_available": cuda_ok,
        # Only query the device count when CUDA is usable.
        "device_count": torch.cuda.device_count() if cuda_ok else 0,
    }
| from typing import Any | |
class AnthropicMessage(BaseModel):
    """One conversation turn as sent to the Anthropic-style endpoint."""

    # Speaker label of the turn.
    role: str
    # Deliberately not typed as ``str``: the Claude CLI sends a list of
    # content blocks here, while other clients may send a plain string.
    content: Any
class MessagesRequest(BaseModel):
    """Request body for the Anthropic-style messages endpoint."""

    # Model identifier supplied by the client; echoed back in the response.
    model: str
    # Conversation history, oldest turn first.
    messages: List[AnthropicMessage]
    # Requested upper bound on newly generated tokens.
    max_tokens: Optional[int] = 2048
    # Sampling temperature forwarded to generation.
    temperature: Optional[float] = 0.7
    # Nucleus-sampling threshold forwarded to generation.
    top_p: Optional[float] = 0.9
def _anthropic_content_to_text(content: Any) -> str:
    """Flatten an Anthropic-style ``content`` value into plain text."""
    if isinstance(content, list):
        # Join the "text" of every {"type": "text", "text": "..."} block;
        # blocks without a "text" key are skipped.
        texts = [
            block["text"]
            for block in content
            if isinstance(block, dict) and "text" in block
        ]
        return "\n".join(texts)
    if isinstance(content, str):
        return content
    return str(content)


@app.post("/v1/messages")
async def messages_endpoint(request: MessagesRequest):
    """Anthropic Claude CLI compatible messages endpoint.

    Flattens each message's content into plain text, then reuses the same
    prompt formatting and generation path as the chat-completion handler and
    returns the reply in an Anthropic-style message envelope.

    Raises:
        HTTPException: 503 when the model is not loaded and the on-demand load
            fails; 500 for any other failure while handling the request.
    """
    try:
        if model is None or tokenizer is None:
            try:
                load_model()
            except Exception as e:
                raise HTTPException(status_code=503, detail="Model not loaded") from e

        converted_messages = [
            Message(role=msg.role, content=_anthropic_content_to_text(msg.content))
            for msg in request.messages
        ]

        prompt = format_messages(converted_messages)
        response_text = generate_response(
            prompt,
            request.temperature,
            request.max_tokens,
            request.top_p
        )

        return {
            "id": f"msg-{uuid.uuid4().hex[:8]}",
            "type": "message",
            "role": "assistant",
            "content": [
                {"type": "text", "text": response_text}
            ],
            "model": request.model,
            "stop_reason": "end_turn",
            "stop_sequence": None,
            "usage": {
                "input_tokens": len(tokenizer.encode(prompt)),
                "output_tokens": len(tokenizer.encode(response_text))
            }
        }
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the 503
        # above would be caught below and rewritten into a 500.
        raise
    except Exception as e:
        print(f"Error processing /v1/messages request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    # uvicorn is imported here so that importing this module does not need it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)