badanwang committed
Commit 72967c5 · verified · 1 Parent(s): ed1d652

Update app.py

Files changed (1):
  app.py +65 -50
app.py CHANGED
@@ -1,51 +1,80 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import os
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware  # import CORS support
+from pydantic import BaseModel, Field
+from typing import List, Optional

 # --- 1. Configuration and model loading ---

+# Initialize the FastAPI application
+app = FastAPI(
+    title="Qwen Model API",
+    description="A simple API for interacting with the fine-tuned Qwen model, callable from any web page.",
+    version="1.0.0"
+)
+
+# --- New: add the CORS middleware ---
+# This is the key change that allows browser JavaScript to call the API.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],       # allow all origins (websites)
+    allow_credentials=True,
+    allow_methods=["*"],       # allow all methods (GET, POST, etc.)
+    allow_headers=["*"],       # allow all request headers
+)
+
+# (The rest of the file is the same as before)
+
 # Load the model ID from an environment variable or a default
 MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
 print(f"Loading model: {MODEL_ID}")

-# Load the tokenizer and the model
-# trust_remote_code=True is required for loading Qwen-family models
-# device_map="auto" automatically places the model on the available hardware (e.g. a GPU)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype="auto",
-    device_map="auto",
-    trust_remote_code=True
-)
-print("Model loaded successfully!")
+# Hold the model in a global dict to avoid loading it more than once
+model_objects = {}
+
+@app.on_event("startup")
+async def load_model():
+    """Load the model and the tokenizer at application startup."""
+    print("Application starting... loading the model...")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True
+    )
+    model_objects['tokenizer'] = tokenizer
+    model_objects['model'] = model
+    print("Model loaded successfully!")
+

+# --- 2. Define the API request/response data structures ---

-# --- 2. Core inference function (API) ---
+class APIRequest(BaseModel):
+    prompt: str = Field(..., description="The user's current input question.")
+    history: Optional[List[List[str]]] = Field(None, description="Conversation history in the form [[user_msg_1, bot_msg_1], ...].")

-def get_response(prompt: str, history: list[list[str]] = None):
-    """
-    A simple function for a single-turn conversation with the model.
+class APIResponse(BaseModel):
+    response: str = Field(..., description="The model's generated reply.")

-    Args:
-        prompt (str): The user's current input question.
-        history (list[list[str]], optional): Conversation history in the form [[user_msg_1, bot_msg_1], ...]. Defaults to None.

-    Returns:
-        str: The model's generated reply.
-    """
+# --- 3. Core inference function ---
+
+def get_response(prompt: str, history: Optional[List[List[str]]] = None) -> str:
+    tokenizer = model_objects['tokenizer']
+    model = model_objects['model']
+
     if history is None:
         history = []

-    # 1. Build the message list
     messages = []
     for user_message, bot_message in history:
         messages.append({"role": "user", "content": user_message})
         messages.append({"role": "assistant", "content": bot_message})
     messages.append({"role": "user", "content": prompt})

-    # 2. Apply the chat template and tokenize
-    # This is the key step for interacting with a chat model correctly
     input_ids = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
@@ -53,8 +82,6 @@ def get_response(prompt: str, history: list[list[str]] = None):
         return_tensors="pt"
     ).to(model.device)

-    # 3. Generate the reply
-    # This is a blocking call that waits for the model to finish generating
     outputs = model.generate(
         input_ids,
         max_new_tokens=1024,
@@ -63,32 +90,20 @@
         top_p=0.9
     )

-    # 4. Decode the generated text
-    # `outputs[0]` contains both the input tokens and the newly generated ones; slice off the input part
     response_ids = outputs[0][input_ids.shape[-1]:]
     response_text = tokenizer.decode(response_ids, skip_special_tokens=True)

     return response_text

-# --- 3. Usage examples ---
-
-if __name__ == "__main__":
-    # Example 1: single-turn conversation
-    print("\n--- Example 1: single-turn conversation ---")
-    question1 = "Hello, who are you?"
-    print(f"User: {question1}")
-    answer1 = get_response(question1)
-    print(f"Model: {answer1}")
-
-    # Example 2: multi-turn conversation
-    print("\n--- Example 2: multi-turn conversation ---")
-    # First, define a conversation history
-    chat_history = [
-        ["Write a quicksort in Python", "Sure, here is a Python implementation of quicksort:\n```python\ndef quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    return quick_sort(left) + middle + quick_sort(right)\n\nprint(quick_sort([3, 6, 8, 10, 1, 2, 1]))\n```"]
-    ]
-    question2 = "Great, can you explain how it works?"
-    print(f"History: {chat_history}")
-    print(f"User: {question2}")
-    # Pass the history in when calling
-    answer2 = get_response(question2, history=chat_history)
-    print(f"Model: {answer2}")
+
+# --- 4. Create the API endpoints ---
+
+@app.post("/generate", response_model=APIResponse)
+async def generate(request: APIRequest):
+    """Receive the user's input and return the model's generated result."""
+    response_text = get_response(request.prompt, request.history)
+    return APIResponse(response=response_text)
+
+@app.get("/")
+def read_root():
+    return {"message": "Welcome to the Qwen model API. Send a POST request to the /generate endpoint."}
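For reference, a minimal client for the new /generate endpoint might look like the sketch below. It is not part of the commit: it assumes the app is served locally (e.g. with `uvicorn app:app --port 7860`) and that the `requests` package is installed; the base URL is illustrative.

import requests

BASE_URL = "http://localhost:7860"  # illustrative; substitute your Space's URL

# The JSON body must match the APIRequest schema defined in app.py.
payload = {
    "prompt": "Great, can you explain how it works?",
    "history": [
        ["Write a quicksort in Python", "Sure, here is a Python implementation: ..."],
    ],
}

# POST to the /generate endpoint; the reply follows the APIResponse schema.
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])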
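The endpoints can also be smoke-tested in-process with FastAPI's TestClient, without running a server. Again a sketch rather than part of the commit; note that entering the `with` block fires the startup event, so the real model is downloaded and loaded.

from fastapi.testclient import TestClient

from app import app  # assumes the file above is saved as app.py

# Using TestClient as a context manager runs the @app.on_event("startup")
# handler, so load_model() populates model_objects before any request.
with TestClient(app) as client:
    assert client.get("/").status_code == 200  # root welcome message

    r = client.post("/generate", json={"prompt": "Hello, who are you?"})
    r.raise_for_status()
    print(r.json()["response"])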