Upload folder using huggingface_hub
Browse files- inference/.ipynb_checkpoints/app-checkpoint.py +423 -0
- inference/.ipynb_checkpoints/chat-checkpoint.py +68 -0
- inference/.ipynb_checkpoints/deepseek_service-checkpoint.py +384 -0
- inference/.ipynb_checkpoints/demo-checkpoint.py +76 -0
- inference/.ipynb_checkpoints/inference-checkpoint.py +43 -0
- inference/.ipynb_checkpoints/model_utils-checkpoint.py +120 -0
- inference/__pycache__/app.cpython-311.pyc +0 -0
- inference/__pycache__/deepseek_service.cpython-311.pyc +0 -0
- inference/__pycache__/model_utils.cpython-311.pyc +0 -0
- inference/app.py +423 -0
- inference/chat.py +68 -0
- inference/deepseek_service.py +384 -0
- inference/demo.py +79 -0
- inference/inference.py +43 -0
- inference/model_utils.py +124 -0
- inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg +0 -0
- inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg +0 -0
- inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg +0 -0
- inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg +0 -0
inference/.ipynb_checkpoints/app-checkpoint.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import uvicorn
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import uuid
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
from contextlib import asynccontextmanager
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import StreamingResponse
|
| 16 |
+
from fastapi.concurrency import run_in_threadpool
|
| 17 |
+
from model_utils import SkinGPTModel
|
| 18 |
+
from deepseek_service import get_deepseek_service, DeepSeekService
|
| 19 |
+
|
| 20 |
+
# === Configuration ===
|
| 21 |
+
MODEL_PATH = "../checkpoint"
|
| 22 |
+
TEMP_DIR = "./temp_uploads"
|
| 23 |
+
os.makedirs(TEMP_DIR, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
# DeepSeek API Key
|
| 26 |
+
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-b221f29be052460f9e0fe12d88dd343c")
|
| 27 |
+
|
| 28 |
+
# Global DeepSeek service instance
|
| 29 |
+
deepseek_service: Optional[DeepSeekService] = None
|
| 30 |
+
|
| 31 |
+
@asynccontextmanager
|
| 32 |
+
async def lifespan(app: FastAPI):
|
| 33 |
+
"""应用生命周期管理"""
|
| 34 |
+
# 启动时初始化 DeepSeek 服务
|
| 35 |
+
await init_deepseek()
|
| 36 |
+
yield
|
| 37 |
+
print("\nShutting down service...")
|
| 38 |
+
|
| 39 |
+
app = FastAPI(
|
| 40 |
+
title="SkinGPT-R1 皮肤诊断系统",
|
| 41 |
+
description="智能皮肤诊断助手",
|
| 42 |
+
version="1.0.0",
|
| 43 |
+
lifespan=lifespan
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# CORS配置 - 允许前端访问
|
| 47 |
+
app.add_middleware(
|
| 48 |
+
CORSMiddleware,
|
| 49 |
+
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 50 |
+
allow_credentials=True,
|
| 51 |
+
allow_methods=["*"],
|
| 52 |
+
allow_headers=["*"],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# 全局变量存储状态
|
| 56 |
+
# chat_states: 存储对话历史 (List of messages for Qwen)
|
| 57 |
+
# pending_images: 存储已上传但尚未发送给LLM的图片路径 (State ID -> Image Path)
|
| 58 |
+
chat_states = {}
|
| 59 |
+
pending_images = {}
|
| 60 |
+
|
| 61 |
+
def parse_diagnosis_result(raw_text: str) -> dict:
|
| 62 |
+
"""
|
| 63 |
+
解析诊断结果中的think和answer标签
|
| 64 |
+
|
| 65 |
+
参数:
|
| 66 |
+
- raw_text: 原始诊断文本
|
| 67 |
+
|
| 68 |
+
返回:
|
| 69 |
+
- dict: 包含thinking, answer, raw字段的字典
|
| 70 |
+
"""
|
| 71 |
+
import re
|
| 72 |
+
|
| 73 |
+
# 尝试匹配完整的标签
|
| 74 |
+
think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
|
| 75 |
+
answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
|
| 76 |
+
|
| 77 |
+
thinking = None
|
| 78 |
+
answer = None
|
| 79 |
+
|
| 80 |
+
# 处理think标签
|
| 81 |
+
if think_match:
|
| 82 |
+
thinking = think_match.group(1).strip()
|
| 83 |
+
else:
|
| 84 |
+
# 尝试匹配未闭合的think标签(输出被截断的情况)
|
| 85 |
+
unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
|
| 86 |
+
if unclosed_think:
|
| 87 |
+
thinking = unclosed_think.group(1).strip()
|
| 88 |
+
|
| 89 |
+
# 处理answer标签
|
| 90 |
+
if answer_match:
|
| 91 |
+
answer = answer_match.group(1).strip()
|
| 92 |
+
else:
|
| 93 |
+
# 尝试匹配未闭合的answer标签
|
| 94 |
+
unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
|
| 95 |
+
if unclosed_answer:
|
| 96 |
+
answer = unclosed_answer.group(1).strip()
|
| 97 |
+
|
| 98 |
+
# 如果仍然没有找到answer,清理原始文本作为answer
|
| 99 |
+
if not answer:
|
| 100 |
+
# 移除所有标签及其内容
|
| 101 |
+
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
|
| 102 |
+
cleaned = re.sub(r'<think>[\s\S]*', '', cleaned) # 移除未闭合的think
|
| 103 |
+
cleaned = re.sub(r'</?answer>', '', cleaned) # 移除answer标签
|
| 104 |
+
cleaned = cleaned.strip()
|
| 105 |
+
answer = cleaned if cleaned else raw_text
|
| 106 |
+
|
| 107 |
+
# 清理可能残留的标签
|
| 108 |
+
if answer:
|
| 109 |
+
answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
|
| 110 |
+
if thinking:
|
| 111 |
+
thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
|
| 112 |
+
|
| 113 |
+
# 处理 "Final Answer:" 格式,提取其后的内容
|
| 114 |
+
if answer:
|
| 115 |
+
final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
|
| 116 |
+
if final_answer_match:
|
| 117 |
+
answer = final_answer_match.group(1).strip()
|
| 118 |
+
|
| 119 |
+
return {
|
| 120 |
+
"thinking": thinking if thinking else None,
|
| 121 |
+
"answer": answer,
|
| 122 |
+
"raw": raw_text
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
print("Initializing Model Service...")
|
| 126 |
+
# 全局加载模型
|
| 127 |
+
gpt_model = SkinGPTModel(MODEL_PATH)
|
| 128 |
+
print("Service Ready.")
|
| 129 |
+
|
| 130 |
+
# 初始化 DeepSeek 服务(异步)
|
| 131 |
+
async def init_deepseek():
|
| 132 |
+
global deepseek_service
|
| 133 |
+
print("\nInitializing DeepSeek service...")
|
| 134 |
+
deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
|
| 135 |
+
if deepseek_service and deepseek_service.is_loaded:
|
| 136 |
+
print("DeepSeek service is ready!")
|
| 137 |
+
else:
|
| 138 |
+
print("DeepSeek service not available, will return raw results")
|
| 139 |
+
|
| 140 |
+
@app.post("/v1/upload/{state_id}")
|
| 141 |
+
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 142 |
+
"""
|
| 143 |
+
接收图片上传。
|
| 144 |
+
逻辑:将图片保存到本地临时目录,并标记该 state_id 有一张待处理图片。
|
| 145 |
+
"""
|
| 146 |
+
try:
|
| 147 |
+
# 1. 保存图片到本地临时文件
|
| 148 |
+
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 149 |
+
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 150 |
+
file_path = os.path.join(TEMP_DIR, unique_name)
|
| 151 |
+
|
| 152 |
+
with open(file_path, "wb") as buffer:
|
| 153 |
+
shutil.copyfileobj(file.file, buffer)
|
| 154 |
+
|
| 155 |
+
# 2. 记录图片路径等待下一次 predict 调用时使用
|
| 156 |
+
# 如果是多图模式,这里可以改成 list,目前演示单图覆盖或更新
|
| 157 |
+
pending_images[state_id] = file_path
|
| 158 |
+
|
| 159 |
+
# 3. 初始化对话状态(如果是新会话)
|
| 160 |
+
if state_id not in chat_states:
|
| 161 |
+
chat_states[state_id] = []
|
| 162 |
+
|
| 163 |
+
return {"message": "Image uploaded successfully", "path": file_path}
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
|
| 167 |
+
|
| 168 |
+
@app.post("/v1/predict/{state_id}")
|
| 169 |
+
async def v1_predict(request: Request, state_id: str):
|
| 170 |
+
"""
|
| 171 |
+
接收文本并执行推理。
|
| 172 |
+
逻辑:检查是否有待处理图片。如果有,将其与文本组合成 multimodal 消息。
|
| 173 |
+
"""
|
| 174 |
+
try:
|
| 175 |
+
data = await request.json()
|
| 176 |
+
except:
|
| 177 |
+
raise HTTPException(status_code=400, detail="Invalid JSON")
|
| 178 |
+
|
| 179 |
+
user_message = data.get("message", "")
|
| 180 |
+
if not user_message:
|
| 181 |
+
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 182 |
+
|
| 183 |
+
# 获取或初始化历史
|
| 184 |
+
history = chat_states.get(state_id, [])
|
| 185 |
+
|
| 186 |
+
# 构建当前轮次的用户内容
|
| 187 |
+
current_content = []
|
| 188 |
+
|
| 189 |
+
# 1. 检查是否有刚刚上传的图片
|
| 190 |
+
if state_id in pending_images:
|
| 191 |
+
img_path = pending_images.pop(state_id) # 取出并移除
|
| 192 |
+
current_content.append({"type": "image", "image": img_path})
|
| 193 |
+
|
| 194 |
+
# 如果是第一次对话,加上 System Prompt
|
| 195 |
+
if not history:
|
| 196 |
+
system_prompt = "You are a professional AI dermatology assistant. "
|
| 197 |
+
user_message = f"{system_prompt}\n\n{user_message}"
|
| 198 |
+
|
| 199 |
+
# 2. 添加文本
|
| 200 |
+
current_content.append({"type": "text", "text": user_message})
|
| 201 |
+
|
| 202 |
+
# 3. 更新历史
|
| 203 |
+
history.append({"role": "user", "content": current_content})
|
| 204 |
+
chat_states[state_id] = history
|
| 205 |
+
|
| 206 |
+
# 4. 运行推理 (在线程池中运行以防阻塞)
|
| 207 |
+
try:
|
| 208 |
+
response_text = await run_in_threadpool(
|
| 209 |
+
gpt_model.generate_response,
|
| 210 |
+
messages=history
|
| 211 |
+
)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
# 回滚历史(移除刚才出错的用户提问)
|
| 214 |
+
chat_states[state_id].pop()
|
| 215 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
| 216 |
+
|
| 217 |
+
# 5. 将回复加入历史
|
| 218 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 219 |
+
chat_states[state_id] = history
|
| 220 |
+
|
| 221 |
+
return {"message": response_text}
|
| 222 |
+
|
| 223 |
+
@app.post("/v1/reset/{state_id}")
|
| 224 |
+
async def reset_chat(state_id: str):
|
| 225 |
+
"""清除会话状态"""
|
| 226 |
+
if state_id in chat_states:
|
| 227 |
+
del chat_states[state_id]
|
| 228 |
+
if state_id in pending_images:
|
| 229 |
+
# 可选:删除临时文件
|
| 230 |
+
try:
|
| 231 |
+
os.remove(pending_images[state_id])
|
| 232 |
+
except:
|
| 233 |
+
pass
|
| 234 |
+
del pending_images[state_id]
|
| 235 |
+
return {"message": "Chat history reset"}
|
| 236 |
+
|
| 237 |
+
@app.get("/")
|
| 238 |
+
async def root():
|
| 239 |
+
"""根路径"""
|
| 240 |
+
return {
|
| 241 |
+
"name": "SkinGPT-R1 皮肤诊断系统",
|
| 242 |
+
"version": "1.0.0",
|
| 243 |
+
"status": "running",
|
| 244 |
+
"description": "智能皮肤诊断助手"
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
@app.get("/health")
|
| 248 |
+
async def health_check():
|
| 249 |
+
"""健康检查"""
|
| 250 |
+
return {
|
| 251 |
+
"status": "healthy",
|
| 252 |
+
"model_loaded": True
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
@app.post("/diagnose/stream")
|
| 256 |
+
async def diagnose_stream(
|
| 257 |
+
image: Optional[UploadFile] = File(None),
|
| 258 |
+
text: str = Form(...),
|
| 259 |
+
language: str = Form("zh"),
|
| 260 |
+
):
|
| 261 |
+
"""
|
| 262 |
+
SSE流式诊断接口(用于前端)
|
| 263 |
+
支持图片上传和文本输入,返回真正的流式响应
|
| 264 |
+
使用 DeepSeek API 优化输出格式
|
| 265 |
+
"""
|
| 266 |
+
from queue import Queue, Empty
|
| 267 |
+
from threading import Thread
|
| 268 |
+
|
| 269 |
+
language = language if language in ("zh", "en") else "zh"
|
| 270 |
+
|
| 271 |
+
# 处理图片
|
| 272 |
+
pil_image = None
|
| 273 |
+
temp_image_path = None
|
| 274 |
+
|
| 275 |
+
if image:
|
| 276 |
+
contents = await image.read()
|
| 277 |
+
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 278 |
+
|
| 279 |
+
# 创建队列用于线程间通信
|
| 280 |
+
result_queue = Queue()
|
| 281 |
+
# 用于存储完整响应和解析结果
|
| 282 |
+
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 283 |
+
|
| 284 |
+
def run_generation():
|
| 285 |
+
"""在后台线程中运行流式生成"""
|
| 286 |
+
full_response = []
|
| 287 |
+
|
| 288 |
+
try:
|
| 289 |
+
# 构建消息
|
| 290 |
+
messages = []
|
| 291 |
+
current_content = []
|
| 292 |
+
|
| 293 |
+
# 添加系统提示
|
| 294 |
+
system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"
|
| 295 |
+
|
| 296 |
+
# 如果有图片,保存到临时文件
|
| 297 |
+
if pil_image:
|
| 298 |
+
generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
|
| 299 |
+
pil_image.save(generation_result["temp_image_path"])
|
| 300 |
+
current_content.append({"type": "image", "image": generation_result["temp_image_path"]})
|
| 301 |
+
|
| 302 |
+
# 添加文本
|
| 303 |
+
prompt = f"{system_prompt}\n\n{text}"
|
| 304 |
+
current_content.append({"type": "text", "text": prompt})
|
| 305 |
+
messages.append({"role": "user", "content": current_content})
|
| 306 |
+
|
| 307 |
+
# 流式生成 - 每个 chunk 立即放入队列
|
| 308 |
+
for chunk in gpt_model.generate_response_stream(
|
| 309 |
+
messages=messages,
|
| 310 |
+
max_new_tokens=2048,
|
| 311 |
+
temperature=0.7
|
| 312 |
+
):
|
| 313 |
+
full_response.append(chunk)
|
| 314 |
+
result_queue.put(("delta", chunk))
|
| 315 |
+
|
| 316 |
+
# 解析结果
|
| 317 |
+
response_text = "".join(full_response)
|
| 318 |
+
parsed = parse_diagnosis_result(response_text)
|
| 319 |
+
generation_result["full_response"] = full_response
|
| 320 |
+
generation_result["parsed"] = parsed
|
| 321 |
+
|
| 322 |
+
# 标记生成完成
|
| 323 |
+
result_queue.put(("generation_done", None))
|
| 324 |
+
|
| 325 |
+
except Exception as e:
|
| 326 |
+
result_queue.put(("error", str(e)))
|
| 327 |
+
|
| 328 |
+
async def event_generator():
|
| 329 |
+
"""异步生成SSE事件"""
|
| 330 |
+
# 在后台线程启动生成(非阻塞)
|
| 331 |
+
gen_thread = Thread(target=run_generation)
|
| 332 |
+
gen_thread.start()
|
| 333 |
+
|
| 334 |
+
loop = asyncio.get_event_loop()
|
| 335 |
+
|
| 336 |
+
# 从队列中读取并发送流式内容
|
| 337 |
+
while True:
|
| 338 |
+
try:
|
| 339 |
+
# 非阻塞获取
|
| 340 |
+
msg_type, data = await loop.run_in_executor(
|
| 341 |
+
None,
|
| 342 |
+
lambda: result_queue.get(timeout=0.1)
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
if msg_type == "generation_done":
|
| 346 |
+
# 流式生成完成,准备处理最终结果
|
| 347 |
+
break
|
| 348 |
+
elif msg_type == "delta":
|
| 349 |
+
yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
|
| 350 |
+
yield f"data: {yield_chunk}\n\n"
|
| 351 |
+
elif msg_type == "error":
|
| 352 |
+
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 353 |
+
gen_thread.join()
|
| 354 |
+
return
|
| 355 |
+
|
| 356 |
+
except Empty:
|
| 357 |
+
# 队列暂时为空,继续等待
|
| 358 |
+
await asyncio.sleep(0.01)
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
gen_thread.join()
|
| 362 |
+
|
| 363 |
+
# 获取解析结果
|
| 364 |
+
parsed = generation_result["parsed"]
|
| 365 |
+
if not parsed:
|
| 366 |
+
yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
|
| 367 |
+
return
|
| 368 |
+
|
| 369 |
+
raw_thinking = parsed["thinking"]
|
| 370 |
+
raw_answer = parsed["answer"]
|
| 371 |
+
|
| 372 |
+
# 使用 DeepSeek 优化结果
|
| 373 |
+
refined_by_deepseek = False
|
| 374 |
+
description = None
|
| 375 |
+
thinking = raw_thinking
|
| 376 |
+
answer = raw_answer
|
| 377 |
+
|
| 378 |
+
if deepseek_service and deepseek_service.is_loaded:
|
| 379 |
+
try:
|
| 380 |
+
print(f"Calling DeepSeek to refine diagnosis (language={language})...")
|
| 381 |
+
refined = await deepseek_service.refine_diagnosis(
|
| 382 |
+
raw_answer=raw_answer,
|
| 383 |
+
raw_thinking=raw_thinking,
|
| 384 |
+
language=language,
|
| 385 |
+
)
|
| 386 |
+
if refined["success"]:
|
| 387 |
+
description = refined["description"]
|
| 388 |
+
thinking = refined["analysis_process"]
|
| 389 |
+
answer = refined["diagnosis_result"]
|
| 390 |
+
refined_by_deepseek = True
|
| 391 |
+
print(f"DeepSeek refinement completed successfully")
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print(f"DeepSeek refinement failed, using original: {e}")
|
| 394 |
+
else:
|
| 395 |
+
print("DeepSeek service not available, using raw results")
|
| 396 |
+
|
| 397 |
+
success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
|
| 398 |
+
|
| 399 |
+
# 返回格式与参考项目保持一致
|
| 400 |
+
final_payload = {
|
| 401 |
+
"description": description, # 图片描述(从 thinking 中提取)
|
| 402 |
+
"thinking": thinking, # 分析过程(DeepSeek 优化后)
|
| 403 |
+
"answer": answer, # 诊断结果(DeepSeek 优化后)
|
| 404 |
+
"raw": parsed["raw"], # 原始响应
|
| 405 |
+
"refined_by_deepseek": refined_by_deepseek, # 是否被 DeepSeek 优化
|
| 406 |
+
"success": True,
|
| 407 |
+
"message": success_msg
|
| 408 |
+
}
|
| 409 |
+
yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
|
| 410 |
+
yield f"data: {yield_final}\n\n"
|
| 411 |
+
|
| 412 |
+
# 清理临时图片
|
| 413 |
+
temp_path = generation_result.get("temp_image_path")
|
| 414 |
+
if temp_path and os.path.exists(temp_path):
|
| 415 |
+
try:
|
| 416 |
+
os.remove(temp_path)
|
| 417 |
+
except:
|
| 418 |
+
pass
|
| 419 |
+
|
| 420 |
+
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 421 |
+
|
| 422 |
+
if __name__ == '__main__':
|
| 423 |
+
uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
|
inference/.ipynb_checkpoints/chat-checkpoint.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# chat.py
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
from model_utils import SkinGPTModel
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
|
| 8 |
+
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 9 |
+
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
|
| 12 |
+
# 初始化模型
|
| 13 |
+
bot = SkinGPTModel(args.model_path)
|
| 14 |
+
|
| 15 |
+
# 初始化对话历史
|
| 16 |
+
# 系统提示词
|
| 17 |
+
system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."
|
| 18 |
+
|
| 19 |
+
# 构造第一条包含图片的消息
|
| 20 |
+
if not os.path.exists(args.image):
|
| 21 |
+
print(f"Error: Image {args.image} not found.")
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
history = [
|
| 25 |
+
{
|
| 26 |
+
"role": "user",
|
| 27 |
+
"content": [
|
| 28 |
+
{"type": "image", "image": args.image},
|
| 29 |
+
{"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."}
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
|
| 35 |
+
print(f"Image loaded: {args.image}")
|
| 36 |
+
|
| 37 |
+
# 获取第一轮诊断
|
| 38 |
+
print("\nModel is thinking...", end="", flush=True)
|
| 39 |
+
response = bot.generate_response(history)
|
| 40 |
+
print(f"\rAssistant: {response}\n")
|
| 41 |
+
|
| 42 |
+
# 将助手的回复加入历史
|
| 43 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 44 |
+
|
| 45 |
+
# 进入多轮对话循环
|
| 46 |
+
while True:
|
| 47 |
+
try:
|
| 48 |
+
user_input = input("User: ")
|
| 49 |
+
if user_input.lower() in ["exit", "quit"]:
|
| 50 |
+
break
|
| 51 |
+
if not user_input.strip():
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# 加入用户的新问题
|
| 55 |
+
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
| 56 |
+
|
| 57 |
+
print("Model is thinking...", end="", flush=True)
|
| 58 |
+
response = bot.generate_response(history)
|
| 59 |
+
print(f"\rAssistant: {response}\n")
|
| 60 |
+
|
| 61 |
+
# 加入助手的新回复
|
| 62 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 63 |
+
|
| 64 |
+
except KeyboardInterrupt:
|
| 65 |
+
break
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
main()
|
inference/.ipynb_checkpoints/deepseek_service-checkpoint.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeepSeek API Service
|
| 3 |
+
Used to optimize and organize SkinGPT model output results
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from openai import AsyncOpenAI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DeepSeekService:
|
| 13 |
+
"""DeepSeek API Service Class"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, api_key: Optional[str] = None):
|
| 16 |
+
"""
|
| 17 |
+
Initialize DeepSeek service
|
| 18 |
+
|
| 19 |
+
Parameters:
|
| 20 |
+
api_key: DeepSeek API key, reads from environment variable if not provided
|
| 21 |
+
"""
|
| 22 |
+
self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
| 23 |
+
self.base_url = "https://api.deepseek.com"
|
| 24 |
+
self.model = "deepseek-chat" # Using deepseek-chat model
|
| 25 |
+
|
| 26 |
+
self.client = None
|
| 27 |
+
self.is_loaded = False
|
| 28 |
+
|
| 29 |
+
print(f"DeepSeek API service initializing...")
|
| 30 |
+
print(f"API Base URL: {self.base_url}")
|
| 31 |
+
|
| 32 |
+
async def load(self):
|
| 33 |
+
"""Initialize DeepSeek API client"""
|
| 34 |
+
try:
|
| 35 |
+
if not self.api_key:
|
| 36 |
+
print("DeepSeek API key not provided")
|
| 37 |
+
self.is_loaded = False
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
# Initialize OpenAI compatible client
|
| 41 |
+
self.client = AsyncOpenAI(
|
| 42 |
+
api_key=self.api_key,
|
| 43 |
+
base_url=self.base_url
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
self.is_loaded = True
|
| 47 |
+
print("DeepSeek API service is ready!")
|
| 48 |
+
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"DeepSeek API service initialization failed: {e}")
|
| 51 |
+
self.is_loaded = False
|
| 52 |
+
|
| 53 |
+
async def refine_diagnosis(
|
| 54 |
+
self,
|
| 55 |
+
raw_answer: str,
|
| 56 |
+
raw_thinking: Optional[str] = None,
|
| 57 |
+
language: str = "zh"
|
| 58 |
+
) -> dict:
|
| 59 |
+
"""
|
| 60 |
+
Use DeepSeek API to optimize and organize diagnosis results
|
| 61 |
+
|
| 62 |
+
Parameters:
|
| 63 |
+
raw_answer: Original diagnosis result
|
| 64 |
+
raw_thinking: AI thinking process
|
| 65 |
+
language: Language option
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Dictionary containing "description", "analysis_process" and "diagnosis_result"
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
if not self.is_loaded or self.client is None:
|
| 72 |
+
error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
|
| 73 |
+
print("DeepSeek API not initialized, returning original result")
|
| 74 |
+
return {
|
| 75 |
+
"success": False,
|
| 76 |
+
"description": "",
|
| 77 |
+
"analysis_process": raw_thinking or error_msg,
|
| 78 |
+
"diagnosis_result": raw_answer,
|
| 79 |
+
"original_diagnosis": raw_answer,
|
| 80 |
+
"error": "DeepSeek API not initialized"
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
# Build prompt
|
| 85 |
+
prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
|
| 86 |
+
|
| 87 |
+
# Select system prompt based on language
|
| 88 |
+
if language == "en":
|
| 89 |
+
system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
|
| 90 |
+
else:
|
| 91 |
+
system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
|
| 92 |
+
|
| 93 |
+
# Call DeepSeek API
|
| 94 |
+
response = await self.client.chat.completions.create(
|
| 95 |
+
model=self.model,
|
| 96 |
+
messages=[
|
| 97 |
+
{"role": "system", "content": system_content},
|
| 98 |
+
{"role": "user", "content": prompt}
|
| 99 |
+
],
|
| 100 |
+
temperature=0.1,
|
| 101 |
+
max_tokens=2048,
|
| 102 |
+
top_p=0.8,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Extract generated text
|
| 106 |
+
generated_text = response.choices[0].message.content
|
| 107 |
+
|
| 108 |
+
# Parse output
|
| 109 |
+
parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
|
| 110 |
+
|
| 111 |
+
return {
|
| 112 |
+
"success": True,
|
| 113 |
+
"description": parsed["description"],
|
| 114 |
+
"analysis_process": parsed["analysis_process"],
|
| 115 |
+
"diagnosis_result": parsed["diagnosis_result"],
|
| 116 |
+
"original_diagnosis": raw_answer,
|
| 117 |
+
"raw_refined": generated_text
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"DeepSeek API call failed: {e}")
|
| 122 |
+
error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
|
| 123 |
+
return {
|
| 124 |
+
"success": False,
|
| 125 |
+
"description": "",
|
| 126 |
+
"analysis_process": raw_thinking or error_msg,
|
| 127 |
+
"diagnosis_result": raw_answer,
|
| 128 |
+
"original_diagnosis": raw_answer,
|
| 129 |
+
"error": str(e)
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
|
| 133 |
+
"""
|
| 134 |
+
Build optimization prompt
|
| 135 |
+
|
| 136 |
+
Parameters:
|
| 137 |
+
raw_answer: Original diagnosis result
|
| 138 |
+
raw_thinking: AI thinking process
|
| 139 |
+
language: Language option, "zh" for Chinese, "en" for English
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Built prompt
|
| 143 |
+
"""
|
| 144 |
+
if language == "en":
|
| 145 |
+
# English prompt - organize and polish while preserving meaning
|
| 146 |
+
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 147 |
+
prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
|
| 148 |
+
|
| 149 |
+
【Requirements】
|
| 150 |
+
- Preserve the original tone and expression style
|
| 151 |
+
- Text 1 contains the thinking process, Text 2 contains the diagnosis result
|
| 152 |
+
- Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
|
| 153 |
+
- For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
|
| 154 |
+
- If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
|
| 155 |
+
- If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
|
| 156 |
+
- DO NOT infer or add new medical information, DO NOT output any meta-commentary
|
| 157 |
+
- You may adjust unreasonable statements or remove redundant content to improve clarity
|
| 158 |
+
|
| 159 |
+
[Text 1]
|
| 160 |
+
{thinking_text}
|
| 161 |
+
|
| 162 |
+
[Text 2]
|
| 163 |
+
{raw_answer}
|
| 164 |
+
|
| 165 |
+
【Output】Only output three sections, do not output anything else:
|
| 166 |
+
## Description
|
| 167 |
+
(Extract all image observation content from the thinking process - include all factual descriptions of what was seen)
|
| 168 |
+
|
| 169 |
+
## Analysis Process
|
| 170 |
+
(Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)
|
| 171 |
+
|
| 172 |
+
## Diagnosis Result
|
| 173 |
+
(The organized diagnosis result from Text 2)
|
| 174 |
+
|
| 175 |
+
【Example】:
|
| 176 |
+
## Description
|
| 177 |
+
The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
|
| 178 |
+
|
| 179 |
+
## Analysis Process
|
| 180 |
+
These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
|
| 181 |
+
|
| 182 |
+
## Diagnosis Result
|
| 183 |
+
Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
|
| 184 |
+
"""
|
| 185 |
+
else:
|
| 186 |
+
# Chinese prompt - translate to Simplified Chinese AND organize/polish
|
| 187 |
+
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 188 |
+
prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
|
| 189 |
+
|
| 190 |
+
【要求】
|
| 191 |
+
- 保留原文的语气和表达方式
|
| 192 |
+
- 文本1是思考过程,文本2是诊断结果
|
| 193 |
+
- 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
|
| 194 |
+
- 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不���包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
|
| 195 |
+
- 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
|
| 196 |
+
- 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
|
| 197 |
+
- 【文本1】和【文本2】需要翻译成简体中文
|
| 198 |
+
- 禁止推断或添加新的医学信息,禁止输出任何元评论
|
| 199 |
+
- 可以调整不合理的语句或去除冗余内容以提高清晰度
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
【文本1】
|
| 203 |
+
{thinking_text}
|
| 204 |
+
|
| 205 |
+
【文本2】
|
| 206 |
+
{raw_answer}
|
| 207 |
+
|
| 208 |
+
【输出】只输出三个部分,不要输出其他任何内容:
|
| 209 |
+
## 图像描述
|
| 210 |
+
(从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)
|
| 211 |
+
|
| 212 |
+
## 分析过程
|
| 213 |
+
(提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)
|
| 214 |
+
|
| 215 |
+
## 诊断结果
|
| 216 |
+
(整理后的诊断结果)
|
| 217 |
+
|
| 218 |
+
【样例】:
|
| 219 |
+
## 图像描述
|
| 220 |
+
图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
|
| 221 |
+
|
| 222 |
+
## 分析过程
|
| 223 |
+
这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
|
| 224 |
+
|
| 225 |
+
## 诊断结果
|
| 226 |
+
可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
return prompt
|
| 230 |
+
|
| 231 |
+
def _parse_refined_output(
|
| 232 |
+
self,
|
| 233 |
+
generated_text: str,
|
| 234 |
+
raw_answer: str,
|
| 235 |
+
raw_thinking: Optional[str] = None,
|
| 236 |
+
language: str = "zh"
|
| 237 |
+
) -> dict:
|
| 238 |
+
"""
|
| 239 |
+
Parse DeepSeek generated output
|
| 240 |
+
|
| 241 |
+
Parameters:
|
| 242 |
+
generated_text: DeepSeek generated text
|
| 243 |
+
raw_answer: Original diagnosis (as fallback)
|
| 244 |
+
raw_thinking: Original thinking process (as fallback)
|
| 245 |
+
language: Language option
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Dictionary containing description, analysis_process and diagnosis_result
|
| 249 |
+
"""
|
| 250 |
+
description = ""
|
| 251 |
+
analysis_process = None
|
| 252 |
+
diagnosis_result = None
|
| 253 |
+
|
| 254 |
+
if language == "en":
|
| 255 |
+
# English patterns
|
| 256 |
+
desc_match = re.search(
|
| 257 |
+
r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
|
| 258 |
+
generated_text,
|
| 259 |
+
re.IGNORECASE
|
| 260 |
+
)
|
| 261 |
+
analysis_match = re.search(
|
| 262 |
+
r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
|
| 263 |
+
generated_text,
|
| 264 |
+
re.IGNORECASE
|
| 265 |
+
)
|
| 266 |
+
result_match = re.search(
|
| 267 |
+
r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
|
| 268 |
+
generated_text,
|
| 269 |
+
re.IGNORECASE
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
desc_header = "## Description"
|
| 273 |
+
analysis_header = "## Analysis Process"
|
| 274 |
+
result_header = "## Diagnosis Result"
|
| 275 |
+
else:
|
| 276 |
+
# Chinese patterns
|
| 277 |
+
desc_match = re.search(
|
| 278 |
+
r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
|
| 279 |
+
generated_text
|
| 280 |
+
)
|
| 281 |
+
analysis_match = re.search(
|
| 282 |
+
r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
|
| 283 |
+
generated_text
|
| 284 |
+
)
|
| 285 |
+
result_match = re.search(
|
| 286 |
+
r'##\s*诊断结果\s*\n([\s\S]*?)$',
|
| 287 |
+
generated_text
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
desc_header = "## 图像描述"
|
| 291 |
+
analysis_header = "## 分析过程"
|
| 292 |
+
result_header = "## 诊断结果"
|
| 293 |
+
|
| 294 |
+
# Extract description
|
| 295 |
+
if desc_match:
|
| 296 |
+
description = desc_match.group(1).strip()
|
| 297 |
+
print(f"Successfully parsed description")
|
| 298 |
+
else:
|
| 299 |
+
print(f"Description parsing failed")
|
| 300 |
+
description = ""
|
| 301 |
+
|
| 302 |
+
# Extract analysis process
|
| 303 |
+
if analysis_match:
|
| 304 |
+
analysis_process = analysis_match.group(1).strip()
|
| 305 |
+
print(f"Successfully parsed analysis process")
|
| 306 |
+
else:
|
| 307 |
+
print(f"Analysis process parsing failed, trying other methods")
|
| 308 |
+
# Try to extract from generated text
|
| 309 |
+
result_pos = generated_text.find(result_header)
|
| 310 |
+
if result_pos > 0:
|
| 311 |
+
# Get content before diagnosis result
|
| 312 |
+
analysis_process = generated_text[:result_pos].strip()
|
| 313 |
+
# Remove possible headers
|
| 314 |
+
for header in [desc_header, analysis_header]:
|
| 315 |
+
header_escaped = re.escape(header)
|
| 316 |
+
analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
|
| 317 |
+
else:
|
| 318 |
+
# If no format at all, try to get first half
|
| 319 |
+
mid_point = len(generated_text) // 2
|
| 320 |
+
analysis_process = generated_text[:mid_point].strip()
|
| 321 |
+
|
| 322 |
+
# If still empty, use original content (final fallback)
|
| 323 |
+
if not analysis_process and raw_thinking:
|
| 324 |
+
print(f"Using original raw_thinking as fallback")
|
| 325 |
+
analysis_process = raw_thinking
|
| 326 |
+
|
| 327 |
+
# Extract diagnosis result
|
| 328 |
+
if result_match:
|
| 329 |
+
diagnosis_result = result_match.group(1).strip()
|
| 330 |
+
print(f"Successfully parsed diagnosis result")
|
| 331 |
+
else:
|
| 332 |
+
print(f"Diagnosis result parsing failed, trying other methods")
|
| 333 |
+
# Try to extract from generated text
|
| 334 |
+
result_pos = generated_text.find(result_header)
|
| 335 |
+
if result_pos > 0:
|
| 336 |
+
diagnosis_result = generated_text[result_pos:].strip()
|
| 337 |
+
# Remove possible header
|
| 338 |
+
result_header_escaped = re.escape(result_header)
|
| 339 |
+
diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
|
| 340 |
+
else:
|
| 341 |
+
# If no format at all, get second half
|
| 342 |
+
mid_point = len(generated_text) // 2
|
| 343 |
+
diagnosis_result = generated_text[mid_point:].strip()
|
| 344 |
+
|
| 345 |
+
# If still empty, use original content (final fallback)
|
| 346 |
+
if not diagnosis_result:
|
| 347 |
+
print(f"Using original raw_answer as fallback")
|
| 348 |
+
diagnosis_result = raw_answer
|
| 349 |
+
|
| 350 |
+
return {
|
| 351 |
+
"description": description,
|
| 352 |
+
"analysis_process": analysis_process,
|
| 353 |
+
"diagnosis_result": diagnosis_result
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# Global DeepSeek service instance (lazy loading)
|
| 358 |
+
_deepseek_service: Optional[DeepSeekService] = None
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
|
| 362 |
+
"""
|
| 363 |
+
Get DeepSeek service instance (singleton pattern)
|
| 364 |
+
|
| 365 |
+
Parameters:
|
| 366 |
+
api_key: Optional API key to use
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
DeepSeekService instance, or None if API initialization fails
|
| 370 |
+
"""
|
| 371 |
+
global _deepseek_service
|
| 372 |
+
|
| 373 |
+
if _deepseek_service is None:
|
| 374 |
+
try:
|
| 375 |
+
_deepseek_service = DeepSeekService(api_key=api_key)
|
| 376 |
+
await _deepseek_service.load()
|
| 377 |
+
if not _deepseek_service.is_loaded:
|
| 378 |
+
print("DeepSeek API service initialization failed, will use fallback mode")
|
| 379 |
+
return _deepseek_service # Return instance but marked as not loaded
|
| 380 |
+
except Exception as e:
|
| 381 |
+
print(f"DeepSeek service initialization failed: {e}")
|
| 382 |
+
return None
|
| 383 |
+
|
| 384 |
+
return _deepseek_service
|
inference/.ipynb_checkpoints/demo-checkpoint.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 3 |
+
from qwen_vl_utils import process_vision_info
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
# === Configuration ===
|
| 7 |
+
MODEL_PATH = "../checkpoint"
|
| 8 |
+
IMAGE_PATH = "test_image.jpg" # Please replace with your actual image path
|
| 9 |
+
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
print(f"Loading model from {MODEL_PATH}...")
|
| 13 |
+
|
| 14 |
+
# 1. Load Model
|
| 15 |
+
try:
|
| 16 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 17 |
+
MODEL_PATH,
|
| 18 |
+
torch_dtype=torch.bfloat16,
|
| 19 |
+
device_map="auto",
|
| 20 |
+
trust_remote_code=True
|
| 21 |
+
)
|
| 22 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error loading model: {e}")
|
| 25 |
+
return
|
| 26 |
+
|
| 27 |
+
# 2. Check Image
|
| 28 |
+
import os
|
| 29 |
+
if not os.path.exists(IMAGE_PATH):
|
| 30 |
+
print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
|
| 31 |
+
# Create a dummy image for code demonstration purposes if needed, or just return
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
# 3. Prepare Inputs
|
| 35 |
+
messages = [
|
| 36 |
+
{
|
| 37 |
+
"role": "user",
|
| 38 |
+
"content": [
|
| 39 |
+
{"type": "image", "image": IMAGE_PATH},
|
| 40 |
+
{"type": "text", "text": PROMPT},
|
| 41 |
+
],
|
| 42 |
+
}
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
print("Processing...")
|
| 46 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 47 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
+
|
| 49 |
+
inputs = processor(
|
| 50 |
+
text=[text],
|
| 51 |
+
images=image_inputs,
|
| 52 |
+
videos=video_inputs,
|
| 53 |
+
padding=True,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
).to(model.device)
|
| 56 |
+
|
| 57 |
+
# 4. Generate
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
generated_ids = model.generate(
|
| 60 |
+
**inputs,
|
| 61 |
+
max_new_tokens=1024,
|
| 62 |
+
temperature=0.7,
|
| 63 |
+
top_p=0.9
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# 5. Decode
|
| 67 |
+
output_text = processor.batch_decode(
|
| 68 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
print("\n=== Diagnosis Result ===")
|
| 72 |
+
print(output_text[0])
|
| 73 |
+
print("========================")
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
main()
|
inference/.ipynb_checkpoints/inference-checkpoint.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
from model_utils import SkinGPTModel
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
|
| 8 |
+
parser.add_argument("--image", type=str, required=True, help="Path to the image")
|
| 9 |
+
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 10 |
+
parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
|
| 11 |
+
args = parser.parse_args()
|
| 12 |
+
|
| 13 |
+
if not os.path.exists(args.image):
|
| 14 |
+
print(f"Error: Image not found at {args.image}")
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
# 1. 加载模型 (复用 model_utils)
|
| 18 |
+
# 这样你就不用在这里重复写 transformers 的加载代码了
|
| 19 |
+
bot = SkinGPTModel(args.model_path)
|
| 20 |
+
|
| 21 |
+
# 2. 构造单轮消息
|
| 22 |
+
system_prompt = "You are a professional AI dermatology assistant."
|
| 23 |
+
messages = [
|
| 24 |
+
{
|
| 25 |
+
"role": "user",
|
| 26 |
+
"content": [
|
| 27 |
+
{"type": "image", "image": args.image},
|
| 28 |
+
{"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
|
| 29 |
+
]
|
| 30 |
+
}
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
# 3. 推理
|
| 34 |
+
print(f"\nAnalyzing {args.image}...")
|
| 35 |
+
response = bot.generate_response(messages)
|
| 36 |
+
|
| 37 |
+
print("-" * 40)
|
| 38 |
+
print("Result:")
|
| 39 |
+
print(response)
|
| 40 |
+
print("-" * 40)
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
main()
|
inference/.ipynb_checkpoints/model_utils-checkpoint.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model_utils.py
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
|
| 4 |
+
from qwen_vl_utils import process_vision_info
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import os
|
| 7 |
+
from threading import Thread
|
| 8 |
+
|
| 9 |
+
class SkinGPTModel:
|
| 10 |
+
def __init__(self, model_path, device=None):
|
| 11 |
+
self.model_path = model_path
|
| 12 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
print(f"Loading model from {model_path} on {self.device}...")
|
| 14 |
+
|
| 15 |
+
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 16 |
+
model_path,
|
| 17 |
+
torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
|
| 18 |
+
attn_implementation="flash_attention_2" if self.device == "cuda" else None,
|
| 19 |
+
device_map="auto" if self.device != "mps" else None,
|
| 20 |
+
trust_remote_code=True
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
if self.device == "mps":
|
| 24 |
+
self.model = self.model.to(self.device)
|
| 25 |
+
|
| 26 |
+
self.processor = AutoProcessor.from_pretrained(
|
| 27 |
+
model_path,
|
| 28 |
+
trust_remote_code=True,
|
| 29 |
+
min_pixels=256*28*28,
|
| 30 |
+
max_pixels=1280*28*28
|
| 31 |
+
)
|
| 32 |
+
print("Model loaded successfully.")
|
| 33 |
+
|
| 34 |
+
def generate_response(self, messages, max_new_tokens=1024, temperature=0.7):
|
| 35 |
+
"""
|
| 36 |
+
处理多轮对话的历史消息列表并生成回复
|
| 37 |
+
messages format:
|
| 38 |
+
[
|
| 39 |
+
{'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
|
| 40 |
+
{'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
|
| 41 |
+
]
|
| 42 |
+
"""
|
| 43 |
+
# 预处理文本模板
|
| 44 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 45 |
+
|
| 46 |
+
# 预处理视觉信息
|
| 47 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
+
|
| 49 |
+
inputs = self.processor(
|
| 50 |
+
text=[text],
|
| 51 |
+
images=image_inputs,
|
| 52 |
+
videos=video_inputs,
|
| 53 |
+
padding=True,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
).to(self.model.device)
|
| 56 |
+
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
generated_ids = self.model.generate(
|
| 59 |
+
**inputs,
|
| 60 |
+
max_new_tokens=max_new_tokens,
|
| 61 |
+
temperature=temperature,
|
| 62 |
+
top_p=0.9,
|
| 63 |
+
do_sample=True
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# 解码输出 (去除输入的token)
|
| 67 |
+
generated_ids_trimmed = [
|
| 68 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 69 |
+
]
|
| 70 |
+
output_text = self.processor.batch_decode(
|
| 71 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
return output_text[0]
|
| 75 |
+
|
| 76 |
+
def generate_response_stream(self, messages, max_new_tokens=2048, temperature=0.7):
|
| 77 |
+
"""
|
| 78 |
+
流式生成响应
|
| 79 |
+
返回一个生成器,逐个yield生成的文本chunk
|
| 80 |
+
"""
|
| 81 |
+
# 预处理文本模板
|
| 82 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 83 |
+
|
| 84 |
+
# 预处理视觉信息
|
| 85 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 86 |
+
|
| 87 |
+
inputs = self.processor(
|
| 88 |
+
text=[text],
|
| 89 |
+
images=image_inputs,
|
| 90 |
+
videos=video_inputs,
|
| 91 |
+
padding=True,
|
| 92 |
+
return_tensors="pt",
|
| 93 |
+
).to(self.model.device)
|
| 94 |
+
|
| 95 |
+
# 创建 TextIteratorStreamer 用于流式输出
|
| 96 |
+
streamer = TextIteratorStreamer(
|
| 97 |
+
self.processor.tokenizer,
|
| 98 |
+
skip_prompt=True,
|
| 99 |
+
skip_special_tokens=True
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# 准备生成参数
|
| 103 |
+
generation_kwargs = {
|
| 104 |
+
**inputs,
|
| 105 |
+
"max_new_tokens": max_new_tokens,
|
| 106 |
+
"temperature": temperature,
|
| 107 |
+
"top_p": 0.9,
|
| 108 |
+
"do_sample": True,
|
| 109 |
+
"streamer": streamer,
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# 在单独的线程中运行生成
|
| 113 |
+
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
| 114 |
+
thread.start()
|
| 115 |
+
|
| 116 |
+
# 逐个yield生成的文本
|
| 117 |
+
for text_chunk in streamer:
|
| 118 |
+
yield text_chunk
|
| 119 |
+
|
| 120 |
+
thread.join()
|
inference/__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
inference/__pycache__/deepseek_service.cpython-311.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
inference/__pycache__/model_utils.cpython-311.pyc
ADDED
|
Binary file (5.39 kB). View file
|
|
|
inference/app.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import uvicorn
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import uuid
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
from contextlib import asynccontextmanager
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import StreamingResponse
|
| 16 |
+
from fastapi.concurrency import run_in_threadpool
|
| 17 |
+
from model_utils import SkinGPTModel
|
| 18 |
+
from deepseek_service import get_deepseek_service, DeepSeekService
|
| 19 |
+
|
| 20 |
+
# === Configuration ===
|
| 21 |
+
MODEL_PATH = "../checkpoint"
|
| 22 |
+
TEMP_DIR = "./temp_uploads"
|
| 23 |
+
os.makedirs(TEMP_DIR, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
# DeepSeek API Key
|
| 26 |
+
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-b221f29be052460f9e0fe12d88dd343c")
|
| 27 |
+
|
| 28 |
+
# Global DeepSeek service instance
|
| 29 |
+
deepseek_service: Optional[DeepSeekService] = None
|
| 30 |
+
|
| 31 |
+
@asynccontextmanager
|
| 32 |
+
async def lifespan(app: FastAPI):
|
| 33 |
+
"""应用生命周期管理"""
|
| 34 |
+
# 启动时初始化 DeepSeek 服务
|
| 35 |
+
await init_deepseek()
|
| 36 |
+
yield
|
| 37 |
+
print("\nShutting down service...")
|
| 38 |
+
|
| 39 |
+
app = FastAPI(
|
| 40 |
+
title="SkinGPT-R1 皮肤诊断系统",
|
| 41 |
+
description="智能皮肤诊断助手",
|
| 42 |
+
version="1.0.0",
|
| 43 |
+
lifespan=lifespan
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# CORS配置 - 允许前端访问
|
| 47 |
+
app.add_middleware(
|
| 48 |
+
CORSMiddleware,
|
| 49 |
+
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 50 |
+
allow_credentials=True,
|
| 51 |
+
allow_methods=["*"],
|
| 52 |
+
allow_headers=["*"],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# 全局变量存储状态
|
| 56 |
+
# chat_states: 存储对话历史 (List of messages for Qwen)
|
| 57 |
+
# pending_images: 存储已上传但尚未发送给LLM的图片路径 (State ID -> Image Path)
|
| 58 |
+
chat_states = {}
|
| 59 |
+
pending_images = {}
|
| 60 |
+
|
| 61 |
+
def parse_diagnosis_result(raw_text: str) -> dict:
|
| 62 |
+
"""
|
| 63 |
+
解析诊断结果中的think和answer标签
|
| 64 |
+
|
| 65 |
+
参数:
|
| 66 |
+
- raw_text: 原始诊断文本
|
| 67 |
+
|
| 68 |
+
返回:
|
| 69 |
+
- dict: 包含thinking, answer, raw字段的字典
|
| 70 |
+
"""
|
| 71 |
+
import re
|
| 72 |
+
|
| 73 |
+
# 尝试匹配完整的标签
|
| 74 |
+
think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
|
| 75 |
+
answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
|
| 76 |
+
|
| 77 |
+
thinking = None
|
| 78 |
+
answer = None
|
| 79 |
+
|
| 80 |
+
# 处理think标签
|
| 81 |
+
if think_match:
|
| 82 |
+
thinking = think_match.group(1).strip()
|
| 83 |
+
else:
|
| 84 |
+
# 尝试匹配未闭合的think标签(输出被截断的情况)
|
| 85 |
+
unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
|
| 86 |
+
if unclosed_think:
|
| 87 |
+
thinking = unclosed_think.group(1).strip()
|
| 88 |
+
|
| 89 |
+
# 处理answer标签
|
| 90 |
+
if answer_match:
|
| 91 |
+
answer = answer_match.group(1).strip()
|
| 92 |
+
else:
|
| 93 |
+
# 尝试匹配未闭合的answer标签
|
| 94 |
+
unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
|
| 95 |
+
if unclosed_answer:
|
| 96 |
+
answer = unclosed_answer.group(1).strip()
|
| 97 |
+
|
| 98 |
+
# 如果仍然没有找到answer,清理原始文本作为answer
|
| 99 |
+
if not answer:
|
| 100 |
+
# 移除所有标签及其内容
|
| 101 |
+
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
|
| 102 |
+
cleaned = re.sub(r'<think>[\s\S]*', '', cleaned) # 移除未闭合的think
|
| 103 |
+
cleaned = re.sub(r'</?answer>', '', cleaned) # 移除answer标签
|
| 104 |
+
cleaned = cleaned.strip()
|
| 105 |
+
answer = cleaned if cleaned else raw_text
|
| 106 |
+
|
| 107 |
+
# 清理可能残留的标签
|
| 108 |
+
if answer:
|
| 109 |
+
answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
|
| 110 |
+
if thinking:
|
| 111 |
+
thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
|
| 112 |
+
|
| 113 |
+
# 处理 "Final Answer:" 格式,提取其后的内容
|
| 114 |
+
if answer:
|
| 115 |
+
final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
|
| 116 |
+
if final_answer_match:
|
| 117 |
+
answer = final_answer_match.group(1).strip()
|
| 118 |
+
|
| 119 |
+
return {
|
| 120 |
+
"thinking": thinking if thinking else None,
|
| 121 |
+
"answer": answer,
|
| 122 |
+
"raw": raw_text
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
print("Initializing Model Service...")
|
| 126 |
+
# 全局加载模型
|
| 127 |
+
gpt_model = SkinGPTModel(MODEL_PATH)
|
| 128 |
+
print("Service Ready.")
|
| 129 |
+
|
| 130 |
+
# 初始化 DeepSeek 服务(异步)
|
| 131 |
+
async def init_deepseek():
|
| 132 |
+
global deepseek_service
|
| 133 |
+
print("\nInitializing DeepSeek service...")
|
| 134 |
+
deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
|
| 135 |
+
if deepseek_service and deepseek_service.is_loaded:
|
| 136 |
+
print("DeepSeek service is ready!")
|
| 137 |
+
else:
|
| 138 |
+
print("DeepSeek service not available, will return raw results")
|
| 139 |
+
|
| 140 |
+
@app.post("/v1/upload/{state_id}")
|
| 141 |
+
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 142 |
+
"""
|
| 143 |
+
接收图片上传。
|
| 144 |
+
逻辑:将图片保存到本地临时目录,并标记该 state_id 有一张待处理图片。
|
| 145 |
+
"""
|
| 146 |
+
try:
|
| 147 |
+
# 1. 保存图片到本地临时文件
|
| 148 |
+
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 149 |
+
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 150 |
+
file_path = os.path.join(TEMP_DIR, unique_name)
|
| 151 |
+
|
| 152 |
+
with open(file_path, "wb") as buffer:
|
| 153 |
+
shutil.copyfileobj(file.file, buffer)
|
| 154 |
+
|
| 155 |
+
# 2. 记录图片路径等待下一次 predict 调用时使用
|
| 156 |
+
# 如果是多图模式,这里可以改成 list,目前演示单图覆盖或更新
|
| 157 |
+
pending_images[state_id] = file_path
|
| 158 |
+
|
| 159 |
+
# 3. 初始化对话状态(如果是新会话)
|
| 160 |
+
if state_id not in chat_states:
|
| 161 |
+
chat_states[state_id] = []
|
| 162 |
+
|
| 163 |
+
return {"message": "Image uploaded successfully", "path": file_path}
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
|
| 167 |
+
|
| 168 |
+
@app.post("/v1/predict/{state_id}")
|
| 169 |
+
async def v1_predict(request: Request, state_id: str):
|
| 170 |
+
"""
|
| 171 |
+
接收文本并执行推理。
|
| 172 |
+
逻辑:检查是否有待处理图片。如果有,将其与文本组合成 multimodal 消息。
|
| 173 |
+
"""
|
| 174 |
+
try:
|
| 175 |
+
data = await request.json()
|
| 176 |
+
except:
|
| 177 |
+
raise HTTPException(status_code=400, detail="Invalid JSON")
|
| 178 |
+
|
| 179 |
+
user_message = data.get("message", "")
|
| 180 |
+
if not user_message:
|
| 181 |
+
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 182 |
+
|
| 183 |
+
# 获取或初始化历史
|
| 184 |
+
history = chat_states.get(state_id, [])
|
| 185 |
+
|
| 186 |
+
# 构建当前轮次的用户内容
|
| 187 |
+
current_content = []
|
| 188 |
+
|
| 189 |
+
# 1. 检查是否有刚刚上传的图片
|
| 190 |
+
if state_id in pending_images:
|
| 191 |
+
img_path = pending_images.pop(state_id) # 取出并移除
|
| 192 |
+
current_content.append({"type": "image", "image": img_path})
|
| 193 |
+
|
| 194 |
+
# 如果是第一次对话,加上 System Prompt
|
| 195 |
+
if not history:
|
| 196 |
+
system_prompt = "You are a professional AI dermatology assistant. "
|
| 197 |
+
user_message = f"{system_prompt}\n\n{user_message}"
|
| 198 |
+
|
| 199 |
+
# 2. 添加文本
|
| 200 |
+
current_content.append({"type": "text", "text": user_message})
|
| 201 |
+
|
| 202 |
+
# 3. 更新历史
|
| 203 |
+
history.append({"role": "user", "content": current_content})
|
| 204 |
+
chat_states[state_id] = history
|
| 205 |
+
|
| 206 |
+
# 4. 运行推理 (在线程池中运行以防阻塞)
|
| 207 |
+
try:
|
| 208 |
+
response_text = await run_in_threadpool(
|
| 209 |
+
gpt_model.generate_response,
|
| 210 |
+
messages=history
|
| 211 |
+
)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
# 回滚历史(移除刚才出错的用户提问)
|
| 214 |
+
chat_states[state_id].pop()
|
| 215 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
| 216 |
+
|
| 217 |
+
# 5. 将回复加入历史
|
| 218 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 219 |
+
chat_states[state_id] = history
|
| 220 |
+
|
| 221 |
+
return {"message": response_text}
|
| 222 |
+
|
| 223 |
+
@app.post("/v1/reset/{state_id}")
|
| 224 |
+
async def reset_chat(state_id: str):
|
| 225 |
+
"""清除会话状态"""
|
| 226 |
+
if state_id in chat_states:
|
| 227 |
+
del chat_states[state_id]
|
| 228 |
+
if state_id in pending_images:
|
| 229 |
+
# 可选:删除临时文件
|
| 230 |
+
try:
|
| 231 |
+
os.remove(pending_images[state_id])
|
| 232 |
+
except:
|
| 233 |
+
pass
|
| 234 |
+
del pending_images[state_id]
|
| 235 |
+
return {"message": "Chat history reset"}
|
| 236 |
+
|
| 237 |
+
@app.get("/")
|
| 238 |
+
async def root():
|
| 239 |
+
"""根路径"""
|
| 240 |
+
return {
|
| 241 |
+
"name": "SkinGPT-R1 皮肤诊断系统",
|
| 242 |
+
"version": "1.0.0",
|
| 243 |
+
"status": "running",
|
| 244 |
+
"description": "智能皮肤诊断助手"
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
@app.get("/health")
|
| 248 |
+
async def health_check():
|
| 249 |
+
"""健康检查"""
|
| 250 |
+
return {
|
| 251 |
+
"status": "healthy",
|
| 252 |
+
"model_loaded": True
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
@app.post("/diagnose/stream")
|
| 256 |
+
async def diagnose_stream(
|
| 257 |
+
image: Optional[UploadFile] = File(None),
|
| 258 |
+
text: str = Form(...),
|
| 259 |
+
language: str = Form("zh"),
|
| 260 |
+
):
|
| 261 |
+
"""
|
| 262 |
+
SSE流式诊断接口(用于前端)
|
| 263 |
+
支持图片上传和文本输入,返回真正的流式响应
|
| 264 |
+
使用 DeepSeek API 优化输出格式
|
| 265 |
+
"""
|
| 266 |
+
from queue import Queue, Empty
|
| 267 |
+
from threading import Thread
|
| 268 |
+
|
| 269 |
+
language = language if language in ("zh", "en") else "zh"
|
| 270 |
+
|
| 271 |
+
# 处理图片
|
| 272 |
+
pil_image = None
|
| 273 |
+
temp_image_path = None
|
| 274 |
+
|
| 275 |
+
if image:
|
| 276 |
+
contents = await image.read()
|
| 277 |
+
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 278 |
+
|
| 279 |
+
# 创建队列用于线程间通信
|
| 280 |
+
result_queue = Queue()
|
| 281 |
+
# 用于存储完整响应和解析结果
|
| 282 |
+
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 283 |
+
|
| 284 |
+
def run_generation():
|
| 285 |
+
"""在后台线程中运行流式生成"""
|
| 286 |
+
full_response = []
|
| 287 |
+
|
| 288 |
+
try:
|
| 289 |
+
# 构建消息
|
| 290 |
+
messages = []
|
| 291 |
+
current_content = []
|
| 292 |
+
|
| 293 |
+
# 添加系统提示
|
| 294 |
+
system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"
|
| 295 |
+
|
| 296 |
+
# 如果有图片,保存到临时文件
|
| 297 |
+
if pil_image:
|
| 298 |
+
generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
|
| 299 |
+
pil_image.save(generation_result["temp_image_path"])
|
| 300 |
+
current_content.append({"type": "image", "image": generation_result["temp_image_path"]})
|
| 301 |
+
|
| 302 |
+
# 添加文本
|
| 303 |
+
prompt = f"{system_prompt}\n\n{text}"
|
| 304 |
+
current_content.append({"type": "text", "text": prompt})
|
| 305 |
+
messages.append({"role": "user", "content": current_content})
|
| 306 |
+
|
| 307 |
+
# 流式生成 - 每个 chunk 立即放入队列
|
| 308 |
+
for chunk in gpt_model.generate_response_stream(
|
| 309 |
+
messages=messages,
|
| 310 |
+
max_new_tokens=2048,
|
| 311 |
+
temperature=0.7
|
| 312 |
+
):
|
| 313 |
+
full_response.append(chunk)
|
| 314 |
+
result_queue.put(("delta", chunk))
|
| 315 |
+
|
| 316 |
+
# 解析结果
|
| 317 |
+
response_text = "".join(full_response)
|
| 318 |
+
parsed = parse_diagnosis_result(response_text)
|
| 319 |
+
generation_result["full_response"] = full_response
|
| 320 |
+
generation_result["parsed"] = parsed
|
| 321 |
+
|
| 322 |
+
# 标记生成完成
|
| 323 |
+
result_queue.put(("generation_done", None))
|
| 324 |
+
|
| 325 |
+
except Exception as e:
|
| 326 |
+
result_queue.put(("error", str(e)))
|
| 327 |
+
|
| 328 |
+
async def event_generator():
|
| 329 |
+
"""异步生成SSE事件"""
|
| 330 |
+
# 在后台线程启动生成(非阻塞)
|
| 331 |
+
gen_thread = Thread(target=run_generation)
|
| 332 |
+
gen_thread.start()
|
| 333 |
+
|
| 334 |
+
loop = asyncio.get_event_loop()
|
| 335 |
+
|
| 336 |
+
# 从队列中读取并发送流式内容
|
| 337 |
+
while True:
|
| 338 |
+
try:
|
| 339 |
+
# 非阻塞获取
|
| 340 |
+
msg_type, data = await loop.run_in_executor(
|
| 341 |
+
None,
|
| 342 |
+
lambda: result_queue.get(timeout=0.1)
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
if msg_type == "generation_done":
|
| 346 |
+
# 流式生成完成,准备处理最终结果
|
| 347 |
+
break
|
| 348 |
+
elif msg_type == "delta":
|
| 349 |
+
yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
|
| 350 |
+
yield f"data: {yield_chunk}\n\n"
|
| 351 |
+
elif msg_type == "error":
|
| 352 |
+
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 353 |
+
gen_thread.join()
|
| 354 |
+
return
|
| 355 |
+
|
| 356 |
+
except Empty:
|
| 357 |
+
# 队列暂时为空,继续等待
|
| 358 |
+
await asyncio.sleep(0.01)
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
gen_thread.join()
|
| 362 |
+
|
| 363 |
+
# 获取解析结果
|
| 364 |
+
parsed = generation_result["parsed"]
|
| 365 |
+
if not parsed:
|
| 366 |
+
yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
|
| 367 |
+
return
|
| 368 |
+
|
| 369 |
+
raw_thinking = parsed["thinking"]
|
| 370 |
+
raw_answer = parsed["answer"]
|
| 371 |
+
|
| 372 |
+
# 使用 DeepSeek 优化结果
|
| 373 |
+
refined_by_deepseek = False
|
| 374 |
+
description = None
|
| 375 |
+
thinking = raw_thinking
|
| 376 |
+
answer = raw_answer
|
| 377 |
+
|
| 378 |
+
if deepseek_service and deepseek_service.is_loaded:
|
| 379 |
+
try:
|
| 380 |
+
print(f"Calling DeepSeek to refine diagnosis (language={language})...")
|
| 381 |
+
refined = await deepseek_service.refine_diagnosis(
|
| 382 |
+
raw_answer=raw_answer,
|
| 383 |
+
raw_thinking=raw_thinking,
|
| 384 |
+
language=language,
|
| 385 |
+
)
|
| 386 |
+
if refined["success"]:
|
| 387 |
+
description = refined["description"]
|
| 388 |
+
thinking = refined["analysis_process"]
|
| 389 |
+
answer = refined["diagnosis_result"]
|
| 390 |
+
refined_by_deepseek = True
|
| 391 |
+
print(f"DeepSeek refinement completed successfully")
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print(f"DeepSeek refinement failed, using original: {e}")
|
| 394 |
+
else:
|
| 395 |
+
print("DeepSeek service not available, using raw results")
|
| 396 |
+
|
| 397 |
+
success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
|
| 398 |
+
|
| 399 |
+
# 返回格式与参考项目保持一致
|
| 400 |
+
final_payload = {
|
| 401 |
+
"description": description, # 图片描述(从 thinking 中提取)
|
| 402 |
+
"thinking": thinking, # 分析过程(DeepSeek 优化后)
|
| 403 |
+
"answer": answer, # 诊断结果(DeepSeek 优化后)
|
| 404 |
+
"raw": parsed["raw"], # 原始响应
|
| 405 |
+
"refined_by_deepseek": refined_by_deepseek, # 是否被 DeepSeek 优化
|
| 406 |
+
"success": True,
|
| 407 |
+
"message": success_msg
|
| 408 |
+
}
|
| 409 |
+
yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
|
| 410 |
+
yield f"data: {yield_final}\n\n"
|
| 411 |
+
|
| 412 |
+
# 清理临时图片
|
| 413 |
+
temp_path = generation_result.get("temp_image_path")
|
| 414 |
+
if temp_path and os.path.exists(temp_path):
|
| 415 |
+
try:
|
| 416 |
+
os.remove(temp_path)
|
| 417 |
+
except:
|
| 418 |
+
pass
|
| 419 |
+
|
| 420 |
+
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 421 |
+
|
| 422 |
+
if __name__ == '__main__':
|
| 423 |
+
uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
|
inference/chat.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# chat.py
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
from model_utils import SkinGPTModel
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
|
| 8 |
+
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 9 |
+
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
|
| 12 |
+
# 初始化模型
|
| 13 |
+
bot = SkinGPTModel(args.model_path)
|
| 14 |
+
|
| 15 |
+
# 初始化对话历史
|
| 16 |
+
# 系统提示词
|
| 17 |
+
system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."
|
| 18 |
+
|
| 19 |
+
# 构造第一条包含图片的消息
|
| 20 |
+
if not os.path.exists(args.image):
|
| 21 |
+
print(f"Error: Image {args.image} not found.")
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
history = [
|
| 25 |
+
{
|
| 26 |
+
"role": "user",
|
| 27 |
+
"content": [
|
| 28 |
+
{"type": "image", "image": args.image},
|
| 29 |
+
{"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."}
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
|
| 35 |
+
print(f"Image loaded: {args.image}")
|
| 36 |
+
|
| 37 |
+
# 获取第一轮诊断
|
| 38 |
+
print("\nModel is thinking...", end="", flush=True)
|
| 39 |
+
response = bot.generate_response(history)
|
| 40 |
+
print(f"\rAssistant: {response}\n")
|
| 41 |
+
|
| 42 |
+
# 将助手的回复加入历史
|
| 43 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 44 |
+
|
| 45 |
+
# 进入多轮对话循环
|
| 46 |
+
while True:
|
| 47 |
+
try:
|
| 48 |
+
user_input = input("User: ")
|
| 49 |
+
if user_input.lower() in ["exit", "quit"]:
|
| 50 |
+
break
|
| 51 |
+
if not user_input.strip():
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# 加入用户的新问题
|
| 55 |
+
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
| 56 |
+
|
| 57 |
+
print("Model is thinking...", end="", flush=True)
|
| 58 |
+
response = bot.generate_response(history)
|
| 59 |
+
print(f"\rAssistant: {response}\n")
|
| 60 |
+
|
| 61 |
+
# 加入助手的新回复
|
| 62 |
+
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 63 |
+
|
| 64 |
+
except KeyboardInterrupt:
|
| 65 |
+
break
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
main()
|
inference/deepseek_service.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeepSeek API Service
|
| 3 |
+
Used to optimize and organize SkinGPT model output results
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from openai import AsyncOpenAI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DeepSeekService:
|
| 13 |
+
"""DeepSeek API Service Class"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, api_key: Optional[str] = None):
|
| 16 |
+
"""
|
| 17 |
+
Initialize DeepSeek service
|
| 18 |
+
|
| 19 |
+
Parameters:
|
| 20 |
+
api_key: DeepSeek API key, reads from environment variable if not provided
|
| 21 |
+
"""
|
| 22 |
+
self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
| 23 |
+
self.base_url = "https://api.deepseek.com"
|
| 24 |
+
self.model = "deepseek-chat" # Using deepseek-chat model
|
| 25 |
+
|
| 26 |
+
self.client = None
|
| 27 |
+
self.is_loaded = False
|
| 28 |
+
|
| 29 |
+
print(f"DeepSeek API service initializing...")
|
| 30 |
+
print(f"API Base URL: {self.base_url}")
|
| 31 |
+
|
| 32 |
+
async def load(self):
|
| 33 |
+
"""Initialize DeepSeek API client"""
|
| 34 |
+
try:
|
| 35 |
+
if not self.api_key:
|
| 36 |
+
print("DeepSeek API key not provided")
|
| 37 |
+
self.is_loaded = False
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
# Initialize OpenAI compatible client
|
| 41 |
+
self.client = AsyncOpenAI(
|
| 42 |
+
api_key=self.api_key,
|
| 43 |
+
base_url=self.base_url
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
self.is_loaded = True
|
| 47 |
+
print("DeepSeek API service is ready!")
|
| 48 |
+
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"DeepSeek API service initialization failed: {e}")
|
| 51 |
+
self.is_loaded = False
|
| 52 |
+
|
| 53 |
+
async def refine_diagnosis(
|
| 54 |
+
self,
|
| 55 |
+
raw_answer: str,
|
| 56 |
+
raw_thinking: Optional[str] = None,
|
| 57 |
+
language: str = "zh"
|
| 58 |
+
) -> dict:
|
| 59 |
+
"""
|
| 60 |
+
Use DeepSeek API to optimize and organize diagnosis results
|
| 61 |
+
|
| 62 |
+
Parameters:
|
| 63 |
+
raw_answer: Original diagnosis result
|
| 64 |
+
raw_thinking: AI thinking process
|
| 65 |
+
language: Language option
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Dictionary containing "description", "analysis_process" and "diagnosis_result"
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
if not self.is_loaded or self.client is None:
|
| 72 |
+
error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
|
| 73 |
+
print("DeepSeek API not initialized, returning original result")
|
| 74 |
+
return {
|
| 75 |
+
"success": False,
|
| 76 |
+
"description": "",
|
| 77 |
+
"analysis_process": raw_thinking or error_msg,
|
| 78 |
+
"diagnosis_result": raw_answer,
|
| 79 |
+
"original_diagnosis": raw_answer,
|
| 80 |
+
"error": "DeepSeek API not initialized"
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
# Build prompt
|
| 85 |
+
prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
|
| 86 |
+
|
| 87 |
+
# Select system prompt based on language
|
| 88 |
+
if language == "en":
|
| 89 |
+
system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
|
| 90 |
+
else:
|
| 91 |
+
system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
|
| 92 |
+
|
| 93 |
+
# Call DeepSeek API
|
| 94 |
+
response = await self.client.chat.completions.create(
|
| 95 |
+
model=self.model,
|
| 96 |
+
messages=[
|
| 97 |
+
{"role": "system", "content": system_content},
|
| 98 |
+
{"role": "user", "content": prompt}
|
| 99 |
+
],
|
| 100 |
+
temperature=0.1,
|
| 101 |
+
max_tokens=2048,
|
| 102 |
+
top_p=0.8,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Extract generated text
|
| 106 |
+
generated_text = response.choices[0].message.content
|
| 107 |
+
|
| 108 |
+
# Parse output
|
| 109 |
+
parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
|
| 110 |
+
|
| 111 |
+
return {
|
| 112 |
+
"success": True,
|
| 113 |
+
"description": parsed["description"],
|
| 114 |
+
"analysis_process": parsed["analysis_process"],
|
| 115 |
+
"diagnosis_result": parsed["diagnosis_result"],
|
| 116 |
+
"original_diagnosis": raw_answer,
|
| 117 |
+
"raw_refined": generated_text
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"DeepSeek API call failed: {e}")
|
| 122 |
+
error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
|
| 123 |
+
return {
|
| 124 |
+
"success": False,
|
| 125 |
+
"description": "",
|
| 126 |
+
"analysis_process": raw_thinking or error_msg,
|
| 127 |
+
"diagnosis_result": raw_answer,
|
| 128 |
+
"original_diagnosis": raw_answer,
|
| 129 |
+
"error": str(e)
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
|
| 133 |
+
"""
|
| 134 |
+
Build optimization prompt
|
| 135 |
+
|
| 136 |
+
Parameters:
|
| 137 |
+
raw_answer: Original diagnosis result
|
| 138 |
+
raw_thinking: AI thinking process
|
| 139 |
+
language: Language option, "zh" for Chinese, "en" for English
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Built prompt
|
| 143 |
+
"""
|
| 144 |
+
if language == "en":
|
| 145 |
+
# English prompt - organize and polish while preserving meaning
|
| 146 |
+
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 147 |
+
prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
|
| 148 |
+
|
| 149 |
+
【Requirements】
|
| 150 |
+
- Preserve the original tone and expression style
|
| 151 |
+
- Text 1 contains the thinking process, Text 2 contains the diagnosis result
|
| 152 |
+
- Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
|
| 153 |
+
- For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
|
| 154 |
+
- If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
|
| 155 |
+
- If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
|
| 156 |
+
- DO NOT infer or add new medical information, DO NOT output any meta-commentary
|
| 157 |
+
- You may adjust unreasonable statements or remove redundant content to improve clarity
|
| 158 |
+
|
| 159 |
+
[Text 1]
|
| 160 |
+
{thinking_text}
|
| 161 |
+
|
| 162 |
+
[Text 2]
|
| 163 |
+
{raw_answer}
|
| 164 |
+
|
| 165 |
+
【Output】Only output three sections, do not output anything else:
|
| 166 |
+
## Description
|
| 167 |
+
(Extract all image observation content from the thinking process - include all factual descriptions of what was seen)
|
| 168 |
+
|
| 169 |
+
## Analysis Process
|
| 170 |
+
(Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)
|
| 171 |
+
|
| 172 |
+
## Diagnosis Result
|
| 173 |
+
(The organized diagnosis result from Text 2)
|
| 174 |
+
|
| 175 |
+
【Example】:
|
| 176 |
+
## Description
|
| 177 |
+
The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
|
| 178 |
+
|
| 179 |
+
## Analysis Process
|
| 180 |
+
These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
|
| 181 |
+
|
| 182 |
+
## Diagnosis Result
|
| 183 |
+
Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
|
| 184 |
+
"""
|
| 185 |
+
else:
|
| 186 |
+
# Chinese prompt - translate to Simplified Chinese AND organize/polish
|
| 187 |
+
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 188 |
+
prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
|
| 189 |
+
|
| 190 |
+
【要求】
|
| 191 |
+
- 保留原文的语气和表达方式
|
| 192 |
+
- 文本1是思考过程,文本2是诊断结果
|
| 193 |
+
- 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
|
| 194 |
+
- 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不���包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
|
| 195 |
+
- 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
|
| 196 |
+
- 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
|
| 197 |
+
- 【文本1】和【文本2】需要翻译成简体中文
|
| 198 |
+
- 禁止推断或添加新的医学信息,禁止输出任何元评论
|
| 199 |
+
- 可以调整不合理的语句或去除冗余内容以提高清晰度
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
【文本1】
|
| 203 |
+
{thinking_text}
|
| 204 |
+
|
| 205 |
+
【文本2】
|
| 206 |
+
{raw_answer}
|
| 207 |
+
|
| 208 |
+
【输出】只输出三个部分,不要输出其他任何内容:
|
| 209 |
+
## 图像描述
|
| 210 |
+
(从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)
|
| 211 |
+
|
| 212 |
+
## 分析过程
|
| 213 |
+
(提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)
|
| 214 |
+
|
| 215 |
+
## 诊断结果
|
| 216 |
+
(整理后的诊断结果)
|
| 217 |
+
|
| 218 |
+
【样例】:
|
| 219 |
+
## 图像描述
|
| 220 |
+
图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
|
| 221 |
+
|
| 222 |
+
## 分析过程
|
| 223 |
+
这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
|
| 224 |
+
|
| 225 |
+
## 诊断结果
|
| 226 |
+
可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
return prompt
|
| 230 |
+
|
| 231 |
+
def _parse_refined_output(
|
| 232 |
+
self,
|
| 233 |
+
generated_text: str,
|
| 234 |
+
raw_answer: str,
|
| 235 |
+
raw_thinking: Optional[str] = None,
|
| 236 |
+
language: str = "zh"
|
| 237 |
+
) -> dict:
|
| 238 |
+
"""
|
| 239 |
+
Parse DeepSeek generated output
|
| 240 |
+
|
| 241 |
+
Parameters:
|
| 242 |
+
generated_text: DeepSeek generated text
|
| 243 |
+
raw_answer: Original diagnosis (as fallback)
|
| 244 |
+
raw_thinking: Original thinking process (as fallback)
|
| 245 |
+
language: Language option
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Dictionary containing description, analysis_process and diagnosis_result
|
| 249 |
+
"""
|
| 250 |
+
description = ""
|
| 251 |
+
analysis_process = None
|
| 252 |
+
diagnosis_result = None
|
| 253 |
+
|
| 254 |
+
if language == "en":
|
| 255 |
+
# English patterns
|
| 256 |
+
desc_match = re.search(
|
| 257 |
+
r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
|
| 258 |
+
generated_text,
|
| 259 |
+
re.IGNORECASE
|
| 260 |
+
)
|
| 261 |
+
analysis_match = re.search(
|
| 262 |
+
r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
|
| 263 |
+
generated_text,
|
| 264 |
+
re.IGNORECASE
|
| 265 |
+
)
|
| 266 |
+
result_match = re.search(
|
| 267 |
+
r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
|
| 268 |
+
generated_text,
|
| 269 |
+
re.IGNORECASE
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
desc_header = "## Description"
|
| 273 |
+
analysis_header = "## Analysis Process"
|
| 274 |
+
result_header = "## Diagnosis Result"
|
| 275 |
+
else:
|
| 276 |
+
# Chinese patterns
|
| 277 |
+
desc_match = re.search(
|
| 278 |
+
r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
|
| 279 |
+
generated_text
|
| 280 |
+
)
|
| 281 |
+
analysis_match = re.search(
|
| 282 |
+
r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
|
| 283 |
+
generated_text
|
| 284 |
+
)
|
| 285 |
+
result_match = re.search(
|
| 286 |
+
r'##\s*诊断结果\s*\n([\s\S]*?)$',
|
| 287 |
+
generated_text
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
desc_header = "## 图像描述"
|
| 291 |
+
analysis_header = "## 分析过程"
|
| 292 |
+
result_header = "## 诊断结果"
|
| 293 |
+
|
| 294 |
+
# Extract description
|
| 295 |
+
if desc_match:
|
| 296 |
+
description = desc_match.group(1).strip()
|
| 297 |
+
print(f"Successfully parsed description")
|
| 298 |
+
else:
|
| 299 |
+
print(f"Description parsing failed")
|
| 300 |
+
description = ""
|
| 301 |
+
|
| 302 |
+
# Extract analysis process
|
| 303 |
+
if analysis_match:
|
| 304 |
+
analysis_process = analysis_match.group(1).strip()
|
| 305 |
+
print(f"Successfully parsed analysis process")
|
| 306 |
+
else:
|
| 307 |
+
print(f"Analysis process parsing failed, trying other methods")
|
| 308 |
+
# Try to extract from generated text
|
| 309 |
+
result_pos = generated_text.find(result_header)
|
| 310 |
+
if result_pos > 0:
|
| 311 |
+
# Get content before diagnosis result
|
| 312 |
+
analysis_process = generated_text[:result_pos].strip()
|
| 313 |
+
# Remove possible headers
|
| 314 |
+
for header in [desc_header, analysis_header]:
|
| 315 |
+
header_escaped = re.escape(header)
|
| 316 |
+
analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
|
| 317 |
+
else:
|
| 318 |
+
# If no format at all, try to get first half
|
| 319 |
+
mid_point = len(generated_text) // 2
|
| 320 |
+
analysis_process = generated_text[:mid_point].strip()
|
| 321 |
+
|
| 322 |
+
# If still empty, use original content (final fallback)
|
| 323 |
+
if not analysis_process and raw_thinking:
|
| 324 |
+
print(f"Using original raw_thinking as fallback")
|
| 325 |
+
analysis_process = raw_thinking
|
| 326 |
+
|
| 327 |
+
# Extract diagnosis result
|
| 328 |
+
if result_match:
|
| 329 |
+
diagnosis_result = result_match.group(1).strip()
|
| 330 |
+
print(f"Successfully parsed diagnosis result")
|
| 331 |
+
else:
|
| 332 |
+
print(f"Diagnosis result parsing failed, trying other methods")
|
| 333 |
+
# Try to extract from generated text
|
| 334 |
+
result_pos = generated_text.find(result_header)
|
| 335 |
+
if result_pos > 0:
|
| 336 |
+
diagnosis_result = generated_text[result_pos:].strip()
|
| 337 |
+
# Remove possible header
|
| 338 |
+
result_header_escaped = re.escape(result_header)
|
| 339 |
+
diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
|
| 340 |
+
else:
|
| 341 |
+
# If no format at all, get second half
|
| 342 |
+
mid_point = len(generated_text) // 2
|
| 343 |
+
diagnosis_result = generated_text[mid_point:].strip()
|
| 344 |
+
|
| 345 |
+
# If still empty, use original content (final fallback)
|
| 346 |
+
if not diagnosis_result:
|
| 347 |
+
print(f"Using original raw_answer as fallback")
|
| 348 |
+
diagnosis_result = raw_answer
|
| 349 |
+
|
| 350 |
+
return {
|
| 351 |
+
"description": description,
|
| 352 |
+
"analysis_process": analysis_process,
|
| 353 |
+
"diagnosis_result": diagnosis_result
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# Global DeepSeek service instance (lazy loading)
|
| 358 |
+
_deepseek_service: Optional[DeepSeekService] = None
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
|
| 362 |
+
"""
|
| 363 |
+
Get DeepSeek service instance (singleton pattern)
|
| 364 |
+
|
| 365 |
+
Parameters:
|
| 366 |
+
api_key: Optional API key to use
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
DeepSeekService instance, or None if API initialization fails
|
| 370 |
+
"""
|
| 371 |
+
global _deepseek_service
|
| 372 |
+
|
| 373 |
+
if _deepseek_service is None:
|
| 374 |
+
try:
|
| 375 |
+
_deepseek_service = DeepSeekService(api_key=api_key)
|
| 376 |
+
await _deepseek_service.load()
|
| 377 |
+
if not _deepseek_service.is_loaded:
|
| 378 |
+
print("DeepSeek API service initialization failed, will use fallback mode")
|
| 379 |
+
return _deepseek_service # Return instance but marked as not loaded
|
| 380 |
+
except Exception as e:
|
| 381 |
+
print(f"DeepSeek service initialization failed: {e}")
|
| 382 |
+
return None
|
| 383 |
+
|
| 384 |
+
return _deepseek_service
|
inference/demo.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 3 |
+
from qwen_vl_utils import process_vision_info
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
# === Configuration ===
|
| 7 |
+
MODEL_PATH = "../checkpoint"
|
| 8 |
+
IMAGE_PATH = "test_image.jpg" # Please replace with your actual image path
|
| 9 |
+
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
print(f"Loading model from {MODEL_PATH}...")
|
| 13 |
+
|
| 14 |
+
# 1. Load Model
|
| 15 |
+
try:
|
| 16 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 17 |
+
MODEL_PATH,
|
| 18 |
+
torch_dtype=torch.bfloat16,
|
| 19 |
+
device_map="auto",
|
| 20 |
+
trust_remote_code=True
|
| 21 |
+
)
|
| 22 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error loading model: {e}")
|
| 25 |
+
return
|
| 26 |
+
|
| 27 |
+
# 2. Check Image
|
| 28 |
+
import os
|
| 29 |
+
if not os.path.exists(IMAGE_PATH):
|
| 30 |
+
print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
|
| 31 |
+
# Create a dummy image for code demonstration purposes if needed, or just return
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
# 3. Prepare Inputs
|
| 35 |
+
messages = [
|
| 36 |
+
{
|
| 37 |
+
"role": "user",
|
| 38 |
+
"content": [
|
| 39 |
+
{"type": "image", "image": IMAGE_PATH},
|
| 40 |
+
{"type": "text", "text": PROMPT},
|
| 41 |
+
],
|
| 42 |
+
}
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
print("Processing...")
|
| 46 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 47 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
+
|
| 49 |
+
inputs = processor(
|
| 50 |
+
text=[text],
|
| 51 |
+
images=image_inputs,
|
| 52 |
+
videos=video_inputs,
|
| 53 |
+
padding=True,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
).to(model.device)
|
| 56 |
+
|
| 57 |
+
# 4. Generate
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
generated_ids = model.generate(
|
| 60 |
+
**inputs,
|
| 61 |
+
max_new_tokens=1024,
|
| 62 |
+
temperature=0.7,
|
| 63 |
+
repetition_penalty=1.2,
|
| 64 |
+
no_repeat_ngram_size=3,
|
| 65 |
+
top_p=0.9,
|
| 66 |
+
do_sample=True
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# 5. Decode
|
| 70 |
+
output_text = processor.batch_decode(
|
| 71 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
print("\n=== Diagnosis Result ===")
|
| 75 |
+
print(output_text[0])
|
| 76 |
+
print("========================")
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
main()
|
inference/inference.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
from model_utils import SkinGPTModel
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
|
| 8 |
+
parser.add_argument("--image", type=str, required=True, help="Path to the image")
|
| 9 |
+
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 10 |
+
parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
|
| 11 |
+
args = parser.parse_args()
|
| 12 |
+
|
| 13 |
+
if not os.path.exists(args.image):
|
| 14 |
+
print(f"Error: Image not found at {args.image}")
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
# 1. 加载模型 (复用 model_utils)
|
| 18 |
+
# 这样你就不用在这里重复写 transformers 的加载代码了
|
| 19 |
+
bot = SkinGPTModel(args.model_path)
|
| 20 |
+
|
| 21 |
+
# 2. 构造单轮消息
|
| 22 |
+
system_prompt = "You are a professional AI dermatology assistant."
|
| 23 |
+
messages = [
|
| 24 |
+
{
|
| 25 |
+
"role": "user",
|
| 26 |
+
"content": [
|
| 27 |
+
{"type": "image", "image": args.image},
|
| 28 |
+
{"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
|
| 29 |
+
]
|
| 30 |
+
}
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
# 3. 推理
|
| 34 |
+
print(f"\nAnalyzing {args.image}...")
|
| 35 |
+
response = bot.generate_response(messages)
|
| 36 |
+
|
| 37 |
+
print("-" * 40)
|
| 38 |
+
print("Result:")
|
| 39 |
+
print(response)
|
| 40 |
+
print("-" * 40)
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
main()
|
inference/model_utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model_utils.py
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
|
| 4 |
+
from qwen_vl_utils import process_vision_info
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import os
|
| 7 |
+
from threading import Thread
|
| 8 |
+
|
| 9 |
+
class SkinGPTModel:
|
| 10 |
+
def __init__(self, model_path, device=None):
|
| 11 |
+
self.model_path = model_path
|
| 12 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
print(f"Loading model from {model_path} on {self.device}...")
|
| 14 |
+
|
| 15 |
+
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 16 |
+
model_path,
|
| 17 |
+
torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
|
| 18 |
+
attn_implementation="flash_attention_2" if self.device == "cuda" else None,
|
| 19 |
+
device_map="auto" if self.device != "mps" else None,
|
| 20 |
+
trust_remote_code=True
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
if self.device == "mps":
|
| 24 |
+
self.model = self.model.to(self.device)
|
| 25 |
+
|
| 26 |
+
self.processor = AutoProcessor.from_pretrained(
|
| 27 |
+
model_path,
|
| 28 |
+
trust_remote_code=True,
|
| 29 |
+
min_pixels=256*28*28,
|
| 30 |
+
max_pixels=1280*28*28
|
| 31 |
+
)
|
| 32 |
+
print("Model loaded successfully.")
|
| 33 |
+
|
| 34 |
+
def generate_response(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
|
| 35 |
+
"""
|
| 36 |
+
处理多轮对话的历史消息列表并生成回复
|
| 37 |
+
messages format:
|
| 38 |
+
[
|
| 39 |
+
{'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
|
| 40 |
+
{'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
|
| 41 |
+
]
|
| 42 |
+
"""
|
| 43 |
+
# 预处理文本模板
|
| 44 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 45 |
+
|
| 46 |
+
# 预处理视觉信息
|
| 47 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
+
|
| 49 |
+
inputs = self.processor(
|
| 50 |
+
text=[text],
|
| 51 |
+
images=image_inputs,
|
| 52 |
+
videos=video_inputs,
|
| 53 |
+
padding=True,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
).to(self.model.device)
|
| 56 |
+
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
generated_ids = self.model.generate(
|
| 59 |
+
**inputs,
|
| 60 |
+
max_new_tokens=max_new_tokens,
|
| 61 |
+
temperature=temperature,
|
| 62 |
+
repetition_penalty=repetition_penalty,
|
| 63 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 64 |
+
top_p=0.9,
|
| 65 |
+
do_sample=True
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# 解码输出 (去除输入的token)
|
| 69 |
+
generated_ids_trimmed = [
|
| 70 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 71 |
+
]
|
| 72 |
+
output_text = self.processor.batch_decode(
|
| 73 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return output_text[0]
|
| 77 |
+
|
| 78 |
+
def generate_response_stream(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
|
| 79 |
+
"""
|
| 80 |
+
流式生成响应
|
| 81 |
+
返回一个生成器,逐个yield生成的文本chunk
|
| 82 |
+
"""
|
| 83 |
+
# 预处理文本模板
|
| 84 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 85 |
+
|
| 86 |
+
# 预处理视觉信息
|
| 87 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 88 |
+
|
| 89 |
+
inputs = self.processor(
|
| 90 |
+
text=[text],
|
| 91 |
+
images=image_inputs,
|
| 92 |
+
videos=video_inputs,
|
| 93 |
+
padding=True,
|
| 94 |
+
return_tensors="pt",
|
| 95 |
+
).to(self.model.device)
|
| 96 |
+
|
| 97 |
+
# 创建 TextIteratorStreamer 用于流式输出
|
| 98 |
+
streamer = TextIteratorStreamer(
|
| 99 |
+
self.processor.tokenizer,
|
| 100 |
+
skip_prompt=True,
|
| 101 |
+
skip_special_tokens=True
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# 准备生成参数
|
| 105 |
+
generation_kwargs = {
|
| 106 |
+
**inputs,
|
| 107 |
+
"max_new_tokens": max_new_tokens,
|
| 108 |
+
"temperature": temperature,
|
| 109 |
+
"repetition_penalty": repetition_penalty,
|
| 110 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
| 111 |
+
"top_p": 0.9,
|
| 112 |
+
"do_sample": True,
|
| 113 |
+
"streamer": streamer,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
# 在单独的线程中运行生成
|
| 117 |
+
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
| 118 |
+
thread.start()
|
| 119 |
+
|
| 120 |
+
# 逐个yield生成的文本
|
| 121 |
+
for text_chunk in streamer:
|
| 122 |
+
yield text_chunk
|
| 123 |
+
|
| 124 |
+
thread.join()
|
inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg
ADDED
|
inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg
ADDED
|
inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg
ADDED
|
inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg
ADDED
|