han145 committed on
Commit
59f0dd7
·
verified ·
1 Parent(s): 1091096

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -132
app.py CHANGED
@@ -1,139 +1,236 @@
1
- import streamlit as st
2
- from llama_cpp import Llama
3
- from huggingface_hub import hf_hub_download
4
  import os
5
  import time
6
- import threading
7
-
8
# ===== Configuration (CPU-friendly model) =====
MODEL_REPO = "Qwen/Qwen2.5-1.5B-Instruct-GGUF"
MODEL_FILENAME = "qwen2.5-1_5b-instruct-q4_k_m.gguf"
# Directory the model file is stored in (described by the author as the
# Spaces persistent directory).
MODEL_DIR = "/app/models"

# Module-level model state, shared so the model is not loaded repeatedly.
llm_instance = None    # the loaded Llama object, once available
model_loading = False  # True while a load is in progress
model_error = None     # text of the last load failure, if any
17
-
18
def background_model_load():
    """Download the GGUF model if it is missing, then load it into memory.

    Meant to be run on a background thread so the Streamlit main thread is
    not blocked. Progress is published through
    ``st.session_state.download_status`` ("downloading", "downloaded",
    "loading", "ready" or "error") and the load duration through
    ``st.session_state.load_time``. On failure the exception text is stored
    in the module-level ``model_error``.

    NOTE(review): this writes ``st.session_state`` from a worker thread;
    confirm that Streamlit makes those writes visible to the script run.
    """
    global llm_instance, model_loading, model_error

    # A load is already in progress: do not start a second one.
    if model_loading:
        return
    model_loading = True

    try:
        # Make sure the target directory exists.
        os.makedirs(MODEL_DIR, exist_ok=True)
        local_path = os.path.join(MODEL_DIR, MODEL_FILENAME)

        # Only download when the file is not already on disk.
        if not os.path.exists(local_path):
            st.session_state.download_status = "downloading"
            local_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                local_dir=MODEL_DIR,
                resume_download=True,  # resume partial downloads
                token=None,  # public model, no token required
            )
            st.session_state.download_status = "downloaded"

        # Load the model into memory and time how long it takes.
        st.session_state.download_status = "loading"
        started_at = time.time()
        llm_instance = Llama(
            model_path=local_path,
            n_ctx=2048,
            n_threads=4,     # Spaces CPUs usually have 4 cores
            n_gpu_layers=0,  # CPU only
            verbose=False,
            n_batch=512,     # batch size tuning
        )
        st.session_state.download_status = "ready"
        st.session_state.load_time = time.time() - started_at

    except Exception as e:
        model_error = str(e)
        st.session_state.download_status = "error"
    finally:
        model_loading = False
63
-
64
# ===== Streamlit UI =====
st.set_page_config(page_title="🦙 CPU LLM Demo", page_icon="🦙", layout="wide")

# Seed the status fields the first time this session runs, then start the
# background loader thread.
if "download_status" not in st.session_state:
    st.session_state.download_status = "idle"
    st.session_state.load_time = 0
    threading.Thread(
        target=background_model_load,
        daemon=True,
    ).start()
73
-
74
# Top status bar: a status banner on the left, a short hint on the right.
col1, col2 = st.columns([3, 1])
with col1:
    # One banner text per value of download_status.
    status_map = {
        "idle": "⏳ 准备加载模型...",
        "downloading": "⬇️ 正在下载模型 (1.0GB)...",
        "downloaded": " 模型下载完成,正在加载到内存...",
        "loading": "🧠 正在加载模型到内存(约60-90秒)...",
        "ready": f"✅ 模型就绪!加载耗时 {st.session_state.load_time:.1f} 秒",
        "error": f"❌ 加载失败: {model_error}"
    }
    banner_text = status_map.get(st.session_state.download_status, "❓ 未知状态")
    st.info(banner_text)
with col2:
    st.caption("💡 首次加载需1-2分钟 | 休眠后需重新下载")

# Do not render the chat UI until the model is ready.
if st.session_state.download_status != "ready":
    st.stop()
92
-
93
# Chat interface
st.title("🦙 本地CPU大模型 (Qwen2.5-1.5B)")
st.caption("完全离线运行 · 无外部API调用 · 适合演示用途")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Handle a new user message, if one was submitted.
prompt = st.chat_input("问点什么吧...")
if prompt:
    # Record and display the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate the assistant reply.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Chat messages for Qwen2.5: a system prompt followed by the history.
        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        messages.extend(
            {"role": m["role"], "content": m["content"]}
            for m in st.session_state.messages
        )

        # Streamed generation (slow on CPU).
        try:
            stream = llm_instance.create_chat_completion(
                messages=messages,
                max_tokens=256,  # cap the length to avoid timeouts
                temperature=0.7,
                stream=True
            )
            for chunk in stream:
                delta = chunk["choices"][0]["delta"]
                if "content" in delta:
                    full_response += delta["content"]
                    # Trailing block character acts as a typing cursor.
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"生成失败: {str(e)}")
            message_placeholder.markdown(" 生成超时,请缩短问题长度重试")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import time
3
+ import logging
4
+ from fastapi import FastAPI, Request, HTTPException, Depends, status
5
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
6
+ from fastapi.responses import JSONResponse
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ import torch
9
+ import gc
10
+
11
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Module-level model handles; both start out unset.
model = None
tokenizer = None

# Model configuration
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
MAX_TOKENS = 512
DEVICE = "cpu"  # force CPU


# API key configuration
def _parse_api_keys(raw):
    """Split a comma-separated key list, trimming blanks and dropping empty items.

    Without trimming, ``"a, b"`` would yield the key ``" b"``, which no
    client-supplied token can ever match.
    """
    return [key.strip() for key in raw.split(",") if key.strip()]


API_KEYS = _parse_api_keys(os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2"))
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").strip().lower() == "true"

# The fallback keys are public placeholders. They are kept so existing
# deployments keep working, but an operator relying on them must be told.
if API_AUTH_ENABLED and "API_KEYS" not in os.environ:
    logger.warning(
        "API_KEYS is not set; falling back to the built-in placeholder keys. "
        "Set API_KEYS to real secrets before exposing this service."
    )
27
+
28
# Bearer authentication.
# auto_error=False hands the "no usable Authorization header" case to
# verify_api_key. With the default (auto_error=True) the security scheme itself
# rejects such requests before the dependency body runs, so turning
# API_AUTH_ENABLED off could never actually disable authentication.
security = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token of the current request.

    Args:
        credentials: Parsed ``Authorization`` header, or ``None`` when the
            header is missing or does not use the Bearer scheme (the scheme
            is matched case-insensitively by ``HTTPBearer``).

    Returns:
        ``True`` when authentication is disabled or the token is accepted.

    Raises:
        HTTPException: 401 when authentication is enabled and the header is
            missing/malformed or the token is not one of ``API_KEYS``.
    """
    import secrets  # local import: constant-time comparison, used only here

    if not API_AUTH_ENABLED:
        return True

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    presented = credentials.credentials.encode("utf-8")
    key_matches = any(
        secrets.compare_digest(presented, known.encode("utf-8"))
        for known in API_KEYS
    )
    if not key_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
49
+
50
def load_model():
    """Load the tokenizer and model named by ``MODEL_NAME`` into the module globals.

    The globals ``model`` and ``tokenizer`` are only assigned once both have
    loaded successfully, so a failure never leaves them half-initialised.

    Returns:
        ``True`` on success, ``False`` if anything raised (the error is logged
        with its traceback).
    """
    global model, tokenizer

    # Half-precision kernels on CPU are slow and not implemented for every
    # op; use float32 there and keep float16 for any other device.
    dtype = torch.float32 if DEVICE == "cpu" else torch.float16

    try:
        logger.info(f"开始加载模型: {MODEL_NAME}")

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; confirm it is still required for this model.
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        # generate() needs a pad token; fall back to EOS when none is defined.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        loaded_model = loaded_model.to(DEVICE)
        loaded_model.eval()
    except Exception as e:
        logger.error(f"模型加载失败: {e}", exc_info=True)
        return False

    tokenizer = loaded_tokenizer
    model = loaded_model
    logger.info("模型加载成功")
    return True
76
+
77
# Roles that are rendered into the prompt; anything else is skipped.
_CHAT_ROLES = ("system", "user", "assistant")


def _flatten_content(content):
    """Reduce a message ``content`` value to a single stripped string.

    Accepts a plain value or a multimodal-style list whose items are either
    strings or ``{"type": "text", "text": ...}`` dicts; other list items are
    ignored. ``None`` yields an empty string rather than the literal "None".
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content).strip()

    parts = []
    for item in content:
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, dict) and item.get("type") == "text":
            text = item.get("text")
            if text is not None:
                parts.append(str(text))
    return " ".join(part for part in parts if part).strip()


def apply_chat_template(messages):
    """Render chat ``messages`` into the ``<|im_start|>`` / ``<|im_end|>`` prompt format.

    Args:
        messages: Iterable of dicts with ``role`` and ``content`` keys.
            ``content`` may be a string or a list of text parts.

    Returns:
        The prompt string, always ending with an opened assistant turn.
        Entries that are not dicts, have a role other than
        system/user/assistant, or have empty content are skipped.
    """
    segments = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue

        # `or ""` also covers an explicit "role": None.
        role = str(msg.get("role") or "").lower()
        if role not in _CHAT_ROLES:
            continue

        body = _flatten_content(msg.get("content", ""))
        if not body:
            continue

        segments.append(f"<|im_start|>{role}\n{body}<|im_end|>\n")

    # Open the assistant turn so generation continues from here.
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)
109
+
110
def generate_chat_response(messages, max_tokens=512, temperature=0.7):
    """Generate an assistant reply for ``messages``.

    Args:
        messages: Chat messages, passed to ``apply_chat_template``.
        max_tokens: Requested number of new tokens. ``None`` means
            ``MAX_TOKENS``; the value is clamped to ``1..MAX_TOKENS``.
        temperature: Sampling temperature. ``None`` means 0.7; a value of 0
            or below switches to greedy decoding, because sampling requires
            a strictly positive temperature.

    Returns:
        ``{"text": reply}`` on success, or ``{"error": message}`` when the
        model is not loaded or generation fails.
    """
    if model is None or tokenizer is None:
        return {"error": "模型未加载"}

    try:
        prompt = apply_chat_template(messages)
        logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")

        # NOTE(review): truncation uses the tokenizer's configured side; if
        # that is the right side, an over-long prompt loses its newest turns
        # and the opened assistant tag. Confirm.
        inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=2048,  # keep the context short
            padding=True
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        # Honour the caller's budget instead of a hard-coded value.
        if max_tokens is None:
            new_token_budget = MAX_TOKENS
        else:
            new_token_budget = max(1, min(int(max_tokens), MAX_TOKENS))

        generation_args = {
            "max_new_tokens": new_token_budget,
            "repetition_penalty": 1.05,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        effective_temperature = 0.7 if temperature is None else float(temperature)
        if effective_temperature > 0:
            generation_args.update(
                do_sample=True,
                temperature=effective_temperature,
                top_p=0.85,
            )
        else:
            generation_args["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_args)

        # generate() returns prompt + completion; decode only the new tokens.
        completion_ids = outputs[0][prompt_length:]
        response = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

        return {"text": response}

    except Exception as e:
        logger.error(f"生成失败: {str(e)}", exc_info=True)
        return {"error": str(e)}
146
+
147
# FastAPI application
app = FastAPI(
    title="Qwen OpenAI-compatible API",
    version="1.0",
    description="仅提供 /v1/chat/completions 端点",
)


@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts and log the outcome.

    A failed load is only logged; the application still starts.
    """
    if not load_model():
        logger.error("模型加载失败,服务可能无法正常工作")
        return
    logger.info("服务启动完成")
160
+
161
# Health check
@app.get("/health")
async def health_check():
    """Report whether the model is loaded, plus the current Unix timestamp."""
    loaded = model is not None
    status_text = "healthy" if loaded else "model loading failed"
    return {
        "status": status_text,
        "model_loaded": loaded,
        "timestamp": int(time.time()),
    }
169
+
170
# Root path
@app.get("/")
async def root():
    """Return a short banner naming the only supported endpoint."""
    banner = {"message": "Qwen API 服务运行中,仅支持 /v1/chat/completions"}
    return banner
174
+
175
def _error_response(status_code, message, error_type):
    """Build the JSON error envelope used by the chat completions endpoint."""
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type
            }
        }
    )


# Core endpoint
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    auth_valid: bool = Depends(verify_api_key)
):
    """Handle an OpenAI-style chat completion request.

    Reads ``messages``, ``max_tokens`` and ``temperature`` from the JSON
    body, runs ``generate_chat_response`` and wraps the reply in a
    ``chat.completion`` object.

    Returns:
        The completion dict on success. A JSON error envelope with status
        400 for a malformed request (invalid JSON, non-object body, missing
        or empty ``messages``), or status 500 when generation fails.

    NOTE(review): generation runs synchronously inside this coroutine, so
    other requests wait while the model is busy.
    """
    # Client errors are reported as 400 instead of being folded into 500.
    try:
        data = await request.json()
    except ValueError:
        return _error_response(400, "Request body must be valid JSON", "invalid_request_error")
    if not isinstance(data, dict):
        return _error_response(400, "Request body must be a JSON object", "invalid_request_error")

    # Validate before len() is used, so a wrong type cannot raise.
    messages = data.get("messages", [])
    if not isinstance(messages, list) or not messages:
        return _error_response(400, "messages 必须是非空列表", "invalid_request_error")

    max_tokens = data.get("max_tokens", MAX_TOKENS)
    temperature = data.get("temperature", 0.7)

    logger.info(f"收到请求: messages_count={len(messages)}")

    try:
        result = generate_chat_response(messages, max_tokens, temperature)
        if "error" in result:
            raise RuntimeError(result["error"])
    except Exception as e:
        logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
        return _error_response(500, str(e), "internal_server_error")

    response_data = {
        "id": f"chatcmpl-{int(time.time()*1000)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["text"]
                },
                "finish_reason": "stop"
            }
        ]
    }

    return response_data
227
+
228
if __name__ == "__main__":
    import uvicorn

    # Server settings: listen on all interfaces, port 7860, one worker.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)