yuhos16 commited on
Commit
9fb76f8
·
verified ·
1 Parent(s): dae35d7

Upload folder using huggingface_hub

Browse files
inference/.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import uvicorn
3
+ import os
4
+ import shutil
5
+ import uuid
6
+ import json
7
+ import re
8
+ import asyncio
9
+ from typing import Optional
10
+ from io import BytesIO
11
+ from contextlib import asynccontextmanager
12
+ from PIL import Image
13
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import StreamingResponse
16
+ from fastapi.concurrency import run_in_threadpool
17
+ from model_utils import SkinGPTModel
18
+ from deepseek_service import get_deepseek_service, DeepSeekService
19
+
20
# === Configuration ===
MODEL_PATH = "../checkpoint"
TEMP_DIR = "./temp_uploads"
os.makedirs(TEMP_DIR, exist_ok=True)

# DeepSeek API key.
# SECURITY FIX: the previous revision hardcoded a live API key as the
# fallback value for this environment lookup. Credentials must never be
# committed to source control; the key is now read exclusively from the
# environment (the leaked key should be rotated). When unset, the DeepSeek
# service reports "API key not provided" at load time and the app falls
# back to returning raw model results.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")

# Global DeepSeek service instance (created in init_deepseek() at startup).
deepseek_service: Optional[DeepSeekService] = None
30
+
31
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: startup/shutdown hooks for FastAPI."""
    # Startup: initialize the DeepSeek refinement service.
    await init_deepseek()
    yield
    # Shutdown: code after `yield` runs when the server stops serving.
    print("\nShutting down service...")
38
+
39
# FastAPI application; `lifespan` wires the startup/shutdown hooks above.
app = FastAPI(
    title="SkinGPT-R1 皮肤诊断系统",
    description="智能皮肤诊断助手",
    version="1.0.0",
    lifespan=lifespan
)

# CORS configuration - allow frontend access.
# NOTE(review): combining "*" with allow_credentials=True is not honored by
# Starlette's CORSMiddleware (the wildcard is not echoed back when
# credentials are allowed), and the explicit localhost origins make the
# wildcard redundant anyway - confirm whether "*" is actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global in-memory session state (single-process only; lost on restart):
# chat_states: conversation history per session (list of Qwen-style messages)
# pending_images: uploaded-but-not-yet-consumed image paths (state id -> path)
chat_states = {}
pending_images = {}
60
+
61
def parse_diagnosis_result(raw_text: str) -> dict:
    """
    Parse <think>/<answer> tags out of a raw diagnosis string.

    Handles truncated output (unclosed tags) and tag-free text, and strips a
    leading "Final Answer:" marker from the answer if present.

    Parameters:
        raw_text: raw diagnosis text emitted by the model.

    Returns:
        dict with keys:
            thinking: content of <think>...</think> (or unclosed <think>), or None
            answer:   content of <answer>...</answer>, or a cleaned fallback
            raw:      the original input, unchanged
    """
    # NOTE: the redundant function-local `import re` was removed; `re` is
    # already imported at module level.

    # Try to match fully closed tags first.
    think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
    answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)

    thinking = None
    answer = None

    # <think> tag: fall back to an unclosed tag (truncated output).
    if think_match:
        thinking = think_match.group(1).strip()
    else:
        unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
        if unclosed_think:
            thinking = unclosed_think.group(1).strip()

    # <answer> tag: fall back to an unclosed tag.
    if answer_match:
        answer = answer_match.group(1).strip()
    else:
        unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
        if unclosed_answer:
            answer = unclosed_answer.group(1).strip()

    # Still no answer: clean the raw text and use it as the answer.
    if not answer:
        # Remove all tags and their (think) contents.
        cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
        cleaned = re.sub(r'<think>[\s\S]*', '', cleaned)  # unclosed <think>
        cleaned = re.sub(r'</?answer>', '', cleaned)      # stray answer tags
        cleaned = cleaned.strip()
        answer = cleaned if cleaned else raw_text

    # Strip any leftover tags from both parts.
    if answer:
        answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
    if thinking:
        thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()

    # Handle the "Final Answer:" format, keeping only what follows it.
    if answer:
        final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
        if final_answer_match:
            answer = final_answer_match.group(1).strip()

    return {
        "thinking": thinking if thinking else None,
        "answer": answer,
        "raw": raw_text
    }
124
+
125
print("Initializing Model Service...")
# Load the SkinGPT model once at import time; the instance is shared by all
# request handlers.
gpt_model = SkinGPTModel(MODEL_PATH)
print("Service Ready.")

# Initialize the DeepSeek service (async; invoked from the lifespan hook).
async def init_deepseek():
    """Create the global DeepSeek refinement service and report readiness."""
    global deepseek_service
    print("\nInitializing DeepSeek service...")
    deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
    if deepseek_service and deepseek_service.is_loaded:
        print("DeepSeek service is ready!")
    else:
        # App still works without DeepSeek: raw model output is returned.
        print("DeepSeek service not available, will return raw results")
139
+
140
@app.post("/v1/upload/{state_id}")
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
    """
    Receive an image upload for a session.

    Saves the file into the local temp directory and marks the session
    (state_id) as having one pending image, which the next /v1/predict call
    will consume.
    """
    try:
        # Derive the extension from the uploaded filename; default to jpg.
        ext = file.filename.split(".")[-1] if "." in file.filename else "jpg"
        dest_path = os.path.join(TEMP_DIR, f"{state_id}_{uuid.uuid4().hex}.{ext}")

        # Persist the upload to disk.
        with open(dest_path, "wb") as sink:
            shutil.copyfileobj(file.file, sink)

        # Remember the path until the next predict call consumes it.
        # Single-image mode: a newer upload simply replaces the previous one.
        pending_images[state_id] = dest_path

        # Ensure the session has a conversation history entry.
        chat_states.setdefault(state_id, [])

        return {"message": "Image uploaded successfully", "path": dest_path}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
167
+
168
@app.post("/v1/predict/{state_id}")
async def v1_predict(request: Request, state_id: str):
    """
    Accept a text message and run inference for the given session.

    If the session has a pending uploaded image (from /v1/upload), the image
    is combined with the text into a single multimodal user turn.

    Raises:
        HTTPException 400 on malformed JSON or a missing 'message' field;
        HTTPException 500 on inference failure (the user turn is rolled back).
    """
    try:
        data = await request.json()
    except Exception:
        # BUGFIX: was a bare `except:`, which would also swallow
        # SystemExit/KeyboardInterrupt/CancelledError.
        raise HTTPException(status_code=400, detail="Invalid JSON")

    user_message = data.get("message", "")
    if not user_message:
        raise HTTPException(status_code=400, detail="Missing 'message' field")

    # Fetch (or start) the conversation history for this session.
    history = chat_states.get(state_id, [])

    # Build the content list for the current user turn.
    current_content = []

    # 1. Attach a freshly uploaded image, if any (consumed exactly once).
    if state_id in pending_images:
        img_path = pending_images.pop(state_id)  # take and remove
        current_content.append({"type": "image", "image": img_path})

        # Prepend the system prompt on the very first turn.
        if not history:
            system_prompt = "You are a professional AI dermatology assistant. "
            user_message = f"{system_prompt}\n\n{user_message}"

    # 2. Add the text part.
    current_content.append({"type": "text", "text": user_message})

    # 3. Record the user turn.
    history.append({"role": "user", "content": current_content})
    chat_states[state_id] = history

    # 4. Run inference in the threadpool so the event loop is not blocked.
    try:
        response_text = await run_in_threadpool(
            gpt_model.generate_response,
            messages=history
        )
    except Exception as e:
        # Roll back the failed user turn so a retry starts clean.
        chat_states[state_id].pop()
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")

    # 5. Record the assistant turn.
    history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
    chat_states[state_id] = history

    return {"message": response_text}
222
+
223
@app.post("/v1/reset/{state_id}")
async def reset_chat(state_id: str):
    """Clear all session state: conversation history and any pending image."""
    if state_id in chat_states:
        del chat_states[state_id]
    if state_id in pending_images:
        # Best-effort removal of the temp file: a missing or locked file is
        # not worth failing the reset over.
        # BUGFIX: was a bare `except:` (would also swallow KeyboardInterrupt
        # and SystemExit); only filesystem errors are ignored now.
        try:
            os.remove(pending_images[state_id])
        except OSError:
            pass
        del pending_images[state_id]
    return {"message": "Chat history reset"}
236
+
237
@app.get("/")
async def root():
    """Service metadata for the root endpoint."""
    info = {
        "name": "SkinGPT-R1 皮肤诊断系统",
        "version": "1.0.0",
        "status": "running",
        "description": "智能皮肤诊断助手",
    }
    return info
246
+
247
@app.get("/health")
async def health_check():
    """Liveness probe for load balancers / uptime monitors."""
    status = {"status": "healthy", "model_loaded": True}
    return status
254
+
255
@app.post("/diagnose/stream")
async def diagnose_stream(
    image: Optional[UploadFile] = File(None),
    text: str = Form(...),
    language: str = Form("zh"),
):
    """
    SSE streaming diagnosis endpoint (used by the frontend).

    Accepts an optional image plus text, streams model deltas as SSE `data:`
    events, then emits one final event carrying the structured result
    (optionally refined through the DeepSeek API).

    Parameters:
        image: optional uploaded skin image.
        text: user's question / description.
        language: "zh" or "en"; anything else falls back to "zh".

    Returns:
        StreamingResponse with media type "text/event-stream".
    """
    from queue import Queue, Empty
    from threading import Thread

    # Only "zh" and "en" are supported.
    language = language if language in ("zh", "en") else "zh"

    # Decode the upload into a PIL image (RGB) if one was provided.
    pil_image = None
    temp_image_path = None

    if image:
        contents = await image.read()
        pil_image = Image.open(BytesIO(contents)).convert("RGB")

    # Queue used to hand chunks from the generation thread to the event loop.
    result_queue = Queue()
    # Holds the full response and parsed result produced by the worker thread.
    generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}

    def run_generation():
        """Run streaming generation in a background thread."""
        full_response = []

        try:
            # Build the (single-turn) message list.
            messages = []
            current_content = []

            # Language-specific system prompt.
            system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"

            # If an image was supplied, persist it to a temp file for the model.
            if pil_image:
                generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
                pil_image.save(generation_result["temp_image_path"])
                current_content.append({"type": "image", "image": generation_result["temp_image_path"]})

            # Add the text part.
            prompt = f"{system_prompt}\n\n{text}"
            current_content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": current_content})

            # Streaming generation - push every chunk onto the queue immediately.
            for chunk in gpt_model.generate_response_stream(
                messages=messages,
                max_new_tokens=2048,
                temperature=0.7
            ):
                full_response.append(chunk)
                result_queue.put(("delta", chunk))

            # Parse the completed response.
            response_text = "".join(full_response)
            parsed = parse_diagnosis_result(response_text)
            generation_result["full_response"] = full_response
            generation_result["parsed"] = parsed

            # Signal completion to the consumer coroutine.
            result_queue.put(("generation_done", None))

        except Exception as e:
            result_queue.put(("error", str(e)))

    async def event_generator():
        """Yield SSE events as chunks arrive from the worker thread."""
        # Start generation in a background thread (non-blocking).
        gen_thread = Thread(target=run_generation)
        gen_thread.start()

        # FIX: get_running_loop() is the correct call from inside a running
        # coroutine; get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()

        # Drain the queue and forward deltas to the client.
        while True:
            try:
                # Blocking get with a short timeout, run off the event loop.
                msg_type, data = await loop.run_in_executor(
                    None,
                    lambda: result_queue.get(timeout=0.1)
                )

                if msg_type == "generation_done":
                    # Streaming finished; move on to the final payload.
                    break
                elif msg_type == "delta":
                    yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
                    yield f"data: {yield_chunk}\n\n"
                elif msg_type == "error":
                    yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
                    gen_thread.join()
                    return

            except Empty:
                # Queue temporarily empty - keep waiting.
                await asyncio.sleep(0.01)
                continue

        gen_thread.join()

        # Fetch the parsed result prepared by the worker thread.
        parsed = generation_result["parsed"]
        if not parsed:
            yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
            return

        raw_thinking = parsed["thinking"]
        raw_answer = parsed["answer"]

        # Optionally refine the result via DeepSeek; fall back to raw output.
        refined_by_deepseek = False
        description = None
        thinking = raw_thinking
        answer = raw_answer

        if deepseek_service and deepseek_service.is_loaded:
            try:
                print(f"Calling DeepSeek to refine diagnosis (language={language})...")
                refined = await deepseek_service.refine_diagnosis(
                    raw_answer=raw_answer,
                    raw_thinking=raw_thinking,
                    language=language,
                )
                if refined["success"]:
                    description = refined["description"]
                    thinking = refined["analysis_process"]
                    answer = refined["diagnosis_result"]
                    refined_by_deepseek = True
                    print("DeepSeek refinement completed successfully")
            except Exception as e:
                print(f"DeepSeek refinement failed, using original: {e}")
        else:
            print("DeepSeek service not available, using raw results")

        success_msg = "Diagnosis completed" if language == "en" else "诊断完成"

        # Final payload shape kept in sync with the reference project.
        final_payload = {
            "description": description,                  # image description (from thinking)
            "thinking": thinking,                        # analysis process (refined)
            "answer": answer,                            # diagnosis result (refined)
            "raw": parsed["raw"],                        # raw model response
            "refined_by_deepseek": refined_by_deepseek,  # whether DeepSeek refined it
            "success": True,
            "message": success_msg
        }
        yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
        yield f"data: {yield_final}\n\n"

        # Clean up the temp image.
        # BUGFIX: was a bare `except:`; only filesystem errors are ignored.
        temp_path = generation_result.get("temp_image_path")
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass

    return StreamingResponse(event_generator(), media_type="text/event-stream")
421
+
422
if __name__ == '__main__':
    # Run the API directly with uvicorn when executed as a script.
    uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
inference/.ipynb_checkpoints/chat-checkpoint.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chat.py
2
+ import argparse
3
+ import os
4
+ from model_utils import SkinGPTModel
5
+
6
def main():
    """Interactive multi-turn CLI chat seeded with a single skin image."""
    parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
    parser.add_argument("--model_path", type=str, default="../checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to initial image")
    args = parser.parse_args()

    # Load the model once up front.
    bot = SkinGPTModel(args.model_path)

    # System instructions prepended to the first user turn.
    system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."

    # The image must exist before we start the session.
    if not args.image or not os.path.exists(args.image):
        print(f"Error: Image {args.image} not found.")
        return

    # Seed the conversation: one user turn carrying the image + prompt.
    first_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": args.image},
            {"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."},
        ],
    }
    history = [first_turn]

    print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
    print(f"Image loaded: {args.image}")

    # First diagnosis turn.
    print("\nModel is thinking...", end="", flush=True)
    reply = bot.generate_response(history)
    print(f"\rAssistant: {reply}\n")

    # Record the assistant's reply in the history.
    history.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

    # Follow-up turns until the user quits (exit/quit/Ctrl-C).
    while True:
        try:
            user_input = input("User: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            if not user_input.strip():
                continue

            # Append the user's new question.
            history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})

            print("Model is thinking...", end="", flush=True)
            reply = bot.generate_response(history)
            print(f"\rAssistant: {reply}\n")

            # Append the assistant's new reply.
            history.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        except KeyboardInterrupt:
            break
66
+
67
if __name__ == "__main__":
    # Script entry point.
    main()
inference/.ipynb_checkpoints/deepseek_service-checkpoint.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek API Service
3
+ Used to optimize and organize SkinGPT model output results
4
+ """
5
+
6
+ import os
7
+ import re
8
+ from typing import Optional
9
+ from openai import AsyncOpenAI
10
+
11
+
12
class DeepSeekService:
    """DeepSeek API service used to polish and organize SkinGPT output."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize DeepSeek service.

        Parameters:
            api_key: DeepSeek API key, reads from environment variable if not provided
        """
        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
        self.base_url = "https://api.deepseek.com"
        self.model = "deepseek-chat"  # Using deepseek-chat model

        self.client = None       # AsyncOpenAI-compatible client, created in load()
        self.is_loaded = False   # True once the client is ready

        print("DeepSeek API service initializing...")
        print(f"API Base URL: {self.base_url}")

    async def load(self):
        """Initialize DeepSeek API client."""
        try:
            if not self.api_key:
                print("DeepSeek API key not provided")
                self.is_loaded = False
                return

            # Initialize OpenAI compatible client
            self.client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )

            self.is_loaded = True
            print("DeepSeek API service is ready!")

        except Exception as e:
            print(f"DeepSeek API service initialization failed: {e}")
            self.is_loaded = False

    async def refine_diagnosis(
        self,
        raw_answer: str,
        raw_thinking: Optional[str] = None,
        language: str = "zh"
    ) -> dict:
        """
        Use DeepSeek API to optimize and organize diagnosis results.

        Parameters:
            raw_answer: Original diagnosis result
            raw_thinking: AI thinking process
            language: Language option ("zh" or "en")

        Returns:
            Dictionary with "success", "description", "analysis_process",
            "diagnosis_result", "original_diagnosis" and (on success)
            "raw_refined" / (on failure) "error" keys. Never raises: failures
            fall back to the original texts.
        """

        if not self.is_loaded or self.client is None:
            error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
            print("DeepSeek API not initialized, returning original result")
            return {
                "success": False,
                "description": "",
                "analysis_process": raw_thinking or error_msg,
                "diagnosis_result": raw_answer,
                "original_diagnosis": raw_answer,
                "error": "DeepSeek API not initialized"
            }

        try:
            # Build prompt
            prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)

            # Select system prompt based on language
            if language == "en":
                system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
            else:
                system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"

            # Call DeepSeek API (low temperature: formatting, not creativity)
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=2048,
                top_p=0.8,
            )

            # Extract generated text
            generated_text = response.choices[0].message.content

            # Parse the three-section output
            parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)

            return {
                "success": True,
                "description": parsed["description"],
                "analysis_process": parsed["analysis_process"],
                "diagnosis_result": parsed["diagnosis_result"],
                "original_diagnosis": raw_answer,
                "raw_refined": generated_text
            }

        except Exception as e:
            print(f"DeepSeek API call failed: {e}")
            error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
            return {
                "success": False,
                "description": "",
                "analysis_process": raw_thinking or error_msg,
                "diagnosis_result": raw_answer,
                "original_diagnosis": raw_answer,
                "error": str(e)
            }

    def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
        """
        Build the optimization prompt.

        Parameters:
            raw_answer: Original diagnosis result
            raw_thinking: AI thinking process
            language: Language option, "zh" for Chinese, "en" for English

        Returns:
            Built prompt
        """
        if language == "en":
            # English prompt - organize and polish while preserving meaning
            thinking_text = raw_thinking if raw_thinking else "No analysis process available."
            prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.

【Requirements】
- Preserve the original tone and expression style
- Text 1 contains the thinking process, Text 2 contains the diagnosis result
- Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
- For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
- If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
- If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
- DO NOT infer or add new medical information, DO NOT output any meta-commentary
- You may adjust unreasonable statements or remove redundant content to improve clarity

[Text 1]
{thinking_text}

[Text 2]
{raw_answer}

【Output】Only output three sections, do not output anything else:
## Description
(Extract all image observation content from the thinking process - include all factual descriptions of what was seen)

## Analysis Process
(Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)

## Diagnosis Result
(The organized diagnosis result from Text 2)

【Example】:
## Description
The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.

## Analysis Process
These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.

## Diagnosis Result
Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
"""
        else:
            # Chinese prompt - translate to Simplified Chinese AND organize/polish.
            # BUGFIX: a mojibake character in the "分析过程" requirement line
            # ("不�?包含") was repaired to "不要包含", matching the English
            # prompt's "DO NOT include".
            thinking_text = raw_thinking if raw_thinking else "No analysis process available."
            prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。

【要求】
- 保留原文的语气和表达方式
- 文本1是思考过程,文本2是诊断结果
- 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
- 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不要包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
- 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
- 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
- 【文本1】和【文本2】需要翻译成简体中文
- 禁止推断或添加新的医学信息,禁止输出任何元评论
- 可以调整不合理的语句或去除冗余内容以提高清晰度


【文本1】
{thinking_text}

【文本2】
{raw_answer}

【输出】只输出三个部分,不要输出其他任何内容:
## 图像描述
(从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)

## 分析过程
(提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)

## 诊断结果
(整理后的诊断结果)

【样例】:
## 图像描述
图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。

## 分析过程
这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。

## 诊断结果
可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
"""

        return prompt

    def _parse_refined_output(
        self,
        generated_text: str,
        raw_answer: str,
        raw_thinking: Optional[str] = None,
        language: str = "zh"
    ) -> dict:
        """
        Parse DeepSeek generated output into its three sections.

        Parameters:
            generated_text: DeepSeek generated text
            raw_answer: Original diagnosis (as fallback)
            raw_thinking: Original thinking process (as fallback)
            language: Language option

        Returns:
            Dictionary containing description, analysis_process and diagnosis_result
        """
        description = ""
        analysis_process = None
        diagnosis_result = None

        if language == "en":
            # English section patterns
            desc_match = re.search(
                r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
                generated_text,
                re.IGNORECASE
            )
            analysis_match = re.search(
                r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
                generated_text,
                re.IGNORECASE
            )
            result_match = re.search(
                r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
                generated_text,
                re.IGNORECASE
            )

            desc_header = "## Description"
            analysis_header = "## Analysis Process"
            result_header = "## Diagnosis Result"
        else:
            # Chinese section patterns
            desc_match = re.search(
                r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
                generated_text
            )
            analysis_match = re.search(
                r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
                generated_text
            )
            result_match = re.search(
                r'##\s*诊断结果\s*\n([\s\S]*?)$',
                generated_text
            )

            desc_header = "## 图像描述"
            analysis_header = "## 分析过程"
            result_header = "## 诊断结果"

        # Extract description
        if desc_match:
            description = desc_match.group(1).strip()
            print("Successfully parsed description")
        else:
            print("Description parsing failed")
            description = ""

        # Extract analysis process
        if analysis_match:
            analysis_process = analysis_match.group(1).strip()
            print("Successfully parsed analysis process")
        else:
            print("Analysis process parsing failed, trying other methods")
            # Fallback: take everything before the diagnosis-result header.
            result_pos = generated_text.find(result_header)
            # BUGFIX: was `result_pos > 0`, which missed the header at
            # position 0 (str.find returns 0, not a truthy index, there).
            if result_pos != -1:
                # Get content before diagnosis result
                analysis_process = generated_text[:result_pos].strip()
                # Remove possible headers
                for header in [desc_header, analysis_header]:
                    header_escaped = re.escape(header)
                    analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
            else:
                # If no format at all, try to get first half
                mid_point = len(generated_text) // 2
                analysis_process = generated_text[:mid_point].strip()

        # If still empty, use original content (final fallback)
        if not analysis_process and raw_thinking:
            print("Using original raw_thinking as fallback")
            analysis_process = raw_thinking

        # Extract diagnosis result
        if result_match:
            diagnosis_result = result_match.group(1).strip()
            print("Successfully parsed diagnosis result")
        else:
            print("Diagnosis result parsing failed, trying other methods")
            # Fallback: take everything from the diagnosis-result header on.
            result_pos = generated_text.find(result_header)
            # BUGFIX: was `result_pos > 0`; a header at position 0 (output
            # starting with "## Diagnosis Result") fell through to the
            # arbitrary mid-point split below.
            if result_pos != -1:
                diagnosis_result = generated_text[result_pos:].strip()
                # Remove possible header
                result_header_escaped = re.escape(result_header)
                diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
            else:
                # If no format at all, get second half
                mid_point = len(generated_text) // 2
                diagnosis_result = generated_text[mid_point:].strip()

        # If still empty, use original content (final fallback)
        if not diagnosis_result:
            print("Using original raw_answer as fallback")
            diagnosis_result = raw_answer

        return {
            "description": description,
            "analysis_process": analysis_process,
            "diagnosis_result": diagnosis_result
        }
355
+
356
+
357
# Global DeepSeek service instance (lazy loading)
_deepseek_service: Optional[DeepSeekService] = None


async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
    """
    Return the process-wide DeepSeekService singleton, creating it lazily.

    Parameters:
        api_key: Optional API key to use

    Returns:
        The DeepSeekService instance (possibly marked as not loaded when API
        initialization failed gracefully), or None if construction/loading
        raised an exception.
    """
    global _deepseek_service

    # Fast path: the singleton already exists.
    if _deepseek_service is not None:
        return _deepseek_service

    try:
        # Assign before load() so repeated calls reuse the same instance,
        # mirroring the original lazy-init behavior.
        _deepseek_service = DeepSeekService(api_key=api_key)
        await _deepseek_service.load()
    except Exception as e:
        print(f"DeepSeek service initialization failed: {e}")
        return None

    if not _deepseek_service.is_loaded:
        # Instance exists but API is unavailable; callers fall back to raw output.
        print("DeepSeek API service initialization failed, will use fallback mode")

    return _deepseek_service
inference/.ipynb_checkpoints/demo-checkpoint.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image

# === Configuration ===
MODEL_PATH = "../checkpoint"
IMAGE_PATH = "test_image.jpg"  # Please replace with your actual image path
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."

def main():
    """Run a single-image, single-turn diagnosis demo against the local checkpoint."""
    print(f"Loading model from {MODEL_PATH}...")

    # 1. Load model and processor; any failure here is fatal for the demo.
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # 2. Check the input image exists before doing any heavy work.
    if not os.path.exists(IMAGE_PATH):
        print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
        return

    # 3. Build the single-turn multimodal message (image first, then prompt).
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": IMAGE_PATH},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]

    print("Processing...")
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # 4. Generate. Fix: do_sample=True is required for temperature/top_p to
    #    take effect — without it transformers ignores both settings.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

    # 5. Decode only the newly generated tokens. Fix: the original decoded the
    #    whole sequence, which echoed the prompt back into the printed result
    #    (this matches the trimming done in model_utils.SkinGPTModel).
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print("\n=== Diagnosis Result ===")
    print(output_text[0])
    print("========================")

if __name__ == "__main__":
    main()
inference/.ipynb_checkpoints/inference-checkpoint.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import argparse
import os

from model_utils import SkinGPTModel


def main():
    """Command-line entry point: run one diagnosis on a single image."""
    parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to the image")
    parser.add_argument("--model_path", type=str, default="../checkpoint")
    parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
    args = parser.parse_args()

    if not os.path.exists(args.image):
        print(f"Error: Image not found at {args.image}")
        return

    # 1. Load the model via the shared helper so the transformers loading
    #    boilerplate lives in exactly one place (model_utils).
    bot = SkinGPTModel(args.model_path)

    # 2. Build a single-turn multimodal message: the image part comes first,
    #    then the user prompt prefixed with the system instruction.
    system_prompt = "You are a professional AI dermatology assistant."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": args.image},
                {"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
            ]
        }
    ]

    # 3. Run inference and print the result.
    print(f"\nAnalyzing {args.image}...")
    result = bot.generate_response(messages)

    print("-" * 40)
    print("Result:")
    print(result)
    print("-" * 40)

if __name__ == "__main__":
    main()
inference/.ipynb_checkpoints/model_utils-checkpoint.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# model_utils.py
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from PIL import Image
import os
from threading import Thread

class SkinGPTModel:
    """Wrapper around a fine-tuned Qwen2.5-VL checkpoint for skin diagnosis.

    Handles device selection, processor configuration, and both blocking and
    streaming multimodal generation.
    """

    def __init__(self, model_path, device=None):
        # device: explicit override ("cuda"/"mps"/"cpu"); auto-detected when None.
        self.model_path = model_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model from {model_path} on {self.device}...")

        # bfloat16 + flash-attention only off-CPU/on-CUDA; CPU gets float32.
        # device_map="auto" is skipped on MPS (the model is moved manually below).
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
            attn_implementation="flash_attention_2" if self.device == "cuda" else None,
            device_map="auto" if self.device != "mps" else None,
            trust_remote_code=True
        )

        if self.device == "mps":
            self.model = self.model.to(self.device)

        # min/max_pixels bound the per-image vision token budget.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            min_pixels=256*28*28,
            max_pixels=1280*28*28
        )
        print("Model loaded successfully.")

    def generate_response(self, messages, max_new_tokens=1024, temperature=0.7):
        """
        Run blocking generation over a multi-turn message history and return
        the decoded assistant reply.

        messages format:
        [
            {'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
            {'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
        ]
        """
        # Render the chat template into a single prompt string.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Extract the image/video inputs referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True
            )

        # Strip the prompt tokens so only the new completion is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return output_text[0]

    def generate_response_stream(self, messages, max_new_tokens=2048, temperature=0.7):
        """
        Stream a response for the given message history.

        Returns a generator that yields decoded text chunks as they are
        produced; generation runs on a background thread and is drained
        through a TextIteratorStreamer.
        """
        # Same preprocessing as the blocking path.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # skip_prompt=True so the echoed input is not streamed back.
        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # Generation parameters, wired to the streamer.
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.9,
            "do_sample": True,
            "streamer": streamer,
        }

        # model.generate blocks, so it runs in a worker thread while this
        # generator yields chunks from the streamer.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield each decoded chunk as soon as it is available.
        for text_chunk in streamer:
            yield text_chunk

        thread.join()
inference/__pycache__/app.cpython-311.pyc ADDED
Binary file (17.8 kB). View file
 
inference/__pycache__/deepseek_service.cpython-311.pyc ADDED
Binary file (18.3 kB). View file
 
inference/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (5.39 kB). View file
 
inference/app.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import uvicorn
3
+ import os
4
+ import shutil
5
+ import uuid
6
+ import json
7
+ import re
8
+ import asyncio
9
+ from typing import Optional
10
+ from io import BytesIO
11
+ from contextlib import asynccontextmanager
12
+ from PIL import Image
13
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import StreamingResponse
16
+ from fastapi.concurrency import run_in_threadpool
17
+ from model_utils import SkinGPTModel
18
+ from deepseek_service import get_deepseek_service, DeepSeekService
19
+
20
+ # === Configuration ===
21
+ MODEL_PATH = "../checkpoint"
22
+ TEMP_DIR = "./temp_uploads"
23
+ os.makedirs(TEMP_DIR, exist_ok=True)
24
+
25
# DeepSeek API Key.
# SECURITY: read only from the environment — never commit a real key as a
# fallback default. The key previously hardcoded here must be treated as
# leaked and revoked. When unset, get_deepseek_service() degrades gracefully
# and the app falls back to returning raw (unrefined) model output.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
27
+
28
+ # Global DeepSeek service instance
29
+ deepseek_service: Optional[DeepSeekService] = None
30
+
31
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialize the DeepSeek service on startup."""
    # Startup: bring up the DeepSeek refinement service (best effort).
    await init_deepseek()
    yield
    # Shutdown: nothing to tear down beyond logging.
    print("\nShutting down service...")

app = FastAPI(
    title="SkinGPT-R1 皮肤诊断系统",
    description="智能皮肤诊断助手",
    version="1.0.0",
    lifespan=lifespan
)

# CORS configuration - allow the frontend dev servers to call this API.
# NOTE(review): "*" alongside allow_credentials=True is permissive — confirm
# this is intended before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory per-session state (process-local, not persisted):
#   chat_states:    session id -> Qwen-format message history
#   pending_images: session id -> path of an uploaded image not yet sent to the LLM
chat_states = {}
pending_images = {}
60
+
61
def parse_diagnosis_result(raw_text: str) -> dict:
    """
    Parse the <think>/<answer> tags out of a raw diagnosis string.

    Tolerates truncated generations (unclosed tags) and falls back to the
    cleaned raw text when no <answer> section can be located.

    Parameters:
        raw_text: Raw diagnosis text produced by the model.

    Returns:
        dict with keys:
            thinking: content of the <think> section, or None if absent
            answer:   content of the <answer> section (or a cleaned fallback)
            raw:      the unmodified input text
    """
    # Fix: the redundant function-local `import re` was removed — `re` is
    # already imported at module level.

    # Try fully-closed tags first.
    think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
    answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)

    thinking = None
    answer = None

    # <think> section: accept an unclosed tag (truncated output).
    if think_match:
        thinking = think_match.group(1).strip()
    else:
        unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
        if unclosed_think:
            thinking = unclosed_think.group(1).strip()

    # <answer> section: likewise tolerate a missing closing tag.
    if answer_match:
        answer = answer_match.group(1).strip()
    else:
        unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
        if unclosed_answer:
            answer = unclosed_answer.group(1).strip()

    # No <answer> found at all: strip tag blocks from the raw text and use
    # whatever remains; if nothing remains, fall back to the raw text itself.
    if not answer:
        cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
        cleaned = re.sub(r'<think>[\s\S]*', '', cleaned)  # drop an unclosed <think>
        cleaned = re.sub(r'</?answer>', '', cleaned)      # drop stray answer tags
        cleaned = cleaned.strip()
        answer = cleaned if cleaned else raw_text

    # Remove any tags that leaked through into either section.
    if answer:
        answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
    if thinking:
        thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()

    # Honour an explicit "Final Answer:" marker inside the answer.
    if answer:
        final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
        if final_answer_match:
            answer = final_answer_match.group(1).strip()

    return {
        "thinking": thinking if thinking else None,
        "answer": answer,
        "raw": raw_text
    }
124
+
125
+ print("Initializing Model Service...")
126
+ # 全局加载模型
127
+ gpt_model = SkinGPTModel(MODEL_PATH)
128
+ print("Service Ready.")
129
+
130
# Initialize the DeepSeek refinement service (async; invoked from lifespan).
async def init_deepseek():
    """Create the global DeepSeek service instance; tolerate failure."""
    global deepseek_service
    print("\nInitializing DeepSeek service...")
    deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
    # The service may exist but be unusable (e.g. missing key) — check both.
    ready = bool(deepseek_service) and deepseek_service.is_loaded
    if ready:
        print("DeepSeek service is ready!")
    else:
        print("DeepSeek service not available, will return raw results")
139
+
140
@app.post("/v1/upload/{state_id}")
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
    """
    Accept an image upload for a chat session.

    The image is saved to the local temp directory and recorded as "pending"
    for state_id; the next /v1/predict call attaches it to the conversation.
    """
    try:
        # 1. Derive a safe file extension. The client fully controls
        #    `filename`, so the raw extension must be sanitized before it is
        #    embedded in a filesystem path (the original code let arbitrary
        #    characters through).
        raw_ext = file.filename.rsplit(".", 1)[-1] if file.filename and "." in file.filename else "jpg"
        file_extension = re.sub(r"[^A-Za-z0-9]", "", raw_ext) or "jpg"
        unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
        file_path = os.path.join(TEMP_DIR, unique_name)

        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # 2. Remember the image so the next predict call can consume it.
        #    Single-image mode: a new upload replaces any previous pending one.
        pending_images[state_id] = file_path

        # 3. Initialize conversation state for brand-new sessions.
        if state_id not in chat_states:
            chat_states[state_id] = []

        return {"message": "Image uploaded successfully", "path": file_path}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
167
+
168
@app.post("/v1/predict/{state_id}")
async def v1_predict(request: Request, state_id: str):
    """
    Run one chat turn: consume any pending uploaded image, append the user
    message to the session history, and generate the assistant reply.
    """
    try:
        data = await request.json()
    except Exception:
        # Fix: the original bare `except:` also swallowed BaseException
        # subclasses such as asyncio.CancelledError.
        raise HTTPException(status_code=400, detail="Invalid JSON")

    user_message = data.get("message", "")
    if not user_message:
        raise HTTPException(status_code=400, detail="Missing 'message' field")

    # Fetch or start the session history.
    history = chat_states.get(state_id, [])

    # Content parts for this turn (image first, then text — Qwen convention).
    current_content = []

    # 1. Attach a pending image, if one was uploaded since the last turn.
    if state_id in pending_images:
        img_path = pending_images.pop(state_id)  # consume exactly once
        current_content.append({"type": "image", "image": img_path})

    # Prepend the system prompt on the very first turn only.
    if not history:
        system_prompt = "You are a professional AI dermatology assistant. "
        user_message = f"{system_prompt}\n\n{user_message}"

    # 2. Add the text part.
    current_content.append({"type": "text", "text": user_message})

    # 3. Record the user turn.
    history.append({"role": "user", "content": current_content})
    chat_states[state_id] = history

    # 4. Run inference in the threadpool so the event loop is not blocked.
    try:
        response_text = await run_in_threadpool(
            gpt_model.generate_response,
            messages=history
        )
    except Exception as e:
        # Roll back the failed user turn so the session can retry cleanly.
        chat_states[state_id].pop()
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")

    # 5. Record the assistant turn.
    history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
    chat_states[state_id] = history

    return {"message": response_text}
222
+
223
@app.post("/v1/reset/{state_id}")
async def reset_chat(state_id: str):
    """Drop all state for a session: chat history plus any pending image file."""
    if state_id in chat_states:
        del chat_states[state_id]
    if state_id in pending_images:
        # Best-effort removal of the temp file. Fix: only filesystem errors
        # are swallowed — the original bare `except:` hid every exception type.
        try:
            os.remove(pending_images[state_id])
        except OSError:
            pass
        del pending_images[state_id]
    return {"message": "Chat history reset"}
236
+
237
@app.get("/")
async def root():
    """Service metadata returned at the API root."""
    info = {
        "name": "SkinGPT-R1 皮肤诊断系统",
        "version": "1.0.0",
        "status": "running",
        "description": "智能皮肤诊断助手",
    }
    return info
246
+
247
@app.get("/health")
async def health_check():
    """Liveness probe; the model is loaded at import time, so report healthy."""
    return {"status": "healthy", "model_loaded": True}
254
+
255
@app.post("/diagnose/stream")
async def diagnose_stream(
    image: Optional[UploadFile] = File(None),
    text: str = Form(...),
    language: str = Form("zh"),
):
    """
    SSE streaming diagnosis endpoint (used by the frontend).

    Accepts an optional image plus a text prompt, streams raw model deltas as
    they are generated, then emits one "final" event whose payload has been
    refined by the DeepSeek service when available.
    """
    from queue import Queue, Empty
    from threading import Thread

    language = language if language in ("zh", "en") else "zh"

    # Decode the uploaded image (if any) before streaming begins.
    pil_image = None
    temp_image_path = None

    if image:
        contents = await image.read()
        pil_image = Image.open(BytesIO(contents)).convert("RGB")

    # Queue bridges the generation worker thread and the async SSE generator.
    result_queue = Queue()
    # Shared slots: full transcript, parsed result, temp image path.
    generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}

    def run_generation():
        """Worker thread: run streaming generation, pushing deltas to the queue."""
        full_response = []

        try:
            # Build the single-turn multimodal message.
            messages = []
            current_content = []

            # Language-matched system prompt.
            system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"

            # The model utils expect an image *path*, so persist to disk.
            if pil_image:
                generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
                pil_image.save(generation_result["temp_image_path"])
                current_content.append({"type": "image", "image": generation_result["temp_image_path"]})

            # Add the text part.
            prompt = f"{system_prompt}\n\n{text}"
            current_content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": current_content})

            # Streaming generation — forward every chunk immediately.
            for chunk in gpt_model.generate_response_stream(
                messages=messages,
                max_new_tokens=2048,
                temperature=0.7
            ):
                full_response.append(chunk)
                result_queue.put(("delta", chunk))

            # Parse the finished transcript into thinking/answer sections.
            response_text = "".join(full_response)
            parsed = parse_diagnosis_result(response_text)
            generation_result["full_response"] = full_response
            generation_result["parsed"] = parsed

            # Signal completion to the consumer.
            result_queue.put(("generation_done", None))

        except Exception as e:
            result_queue.put(("error", str(e)))

    async def event_generator():
        """Drain the queue and emit SSE events without blocking the event loop."""
        # Start generation on a background thread (non-blocking).
        gen_thread = Thread(target=run_generation)
        gen_thread.start()

        # Fix: get_running_loop() is the correct, non-deprecated call inside a
        # coroutine (the original used asyncio.get_event_loop()).
        loop = asyncio.get_running_loop()

        # Relay queued messages as SSE events.
        while True:
            try:
                # Blocking queue read runs in the default executor so the
                # event loop stays responsive.
                msg_type, data = await loop.run_in_executor(
                    None,
                    lambda: result_queue.get(timeout=0.1)
                )

                if msg_type == "generation_done":
                    # Streaming finished; fall through to the final payload.
                    break
                elif msg_type == "delta":
                    yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
                    yield f"data: {yield_chunk}\n\n"
                elif msg_type == "error":
                    yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
                    gen_thread.join()
                    return

            except Empty:
                # Queue momentarily empty — yield control and retry.
                await asyncio.sleep(0.01)
                continue

        gen_thread.join()

        # Fetch the parsed transcript.
        parsed = generation_result["parsed"]
        if not parsed:
            yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
            return

        raw_thinking = parsed["thinking"]
        raw_answer = parsed["answer"]

        # Optionally refine the raw output through DeepSeek.
        refined_by_deepseek = False
        description = None
        thinking = raw_thinking
        answer = raw_answer

        if deepseek_service and deepseek_service.is_loaded:
            try:
                print(f"Calling DeepSeek to refine diagnosis (language={language})...")
                refined = await deepseek_service.refine_diagnosis(
                    raw_answer=raw_answer,
                    raw_thinking=raw_thinking,
                    language=language,
                )
                if refined["success"]:
                    description = refined["description"]
                    thinking = refined["analysis_process"]
                    answer = refined["diagnosis_result"]
                    refined_by_deepseek = True
                    print(f"DeepSeek refinement completed successfully")
            except Exception as e:
                print(f"DeepSeek refinement failed, using original: {e}")
        else:
            print("DeepSeek service not available, using raw results")

        success_msg = "Diagnosis completed" if language == "en" else "诊断完成"

        # Payload layout mirrors the reference project's schema.
        final_payload = {
            "description": description,                 # image description (from thinking)
            "thinking": thinking,                       # analysis process (refined)
            "answer": answer,                           # diagnosis result (refined)
            "raw": parsed["raw"],                       # untouched model output
            "refined_by_deepseek": refined_by_deepseek, # whether DeepSeek refined it
            "success": True,
            "message": success_msg
        }
        yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
        yield f"data: {yield_final}\n\n"

        # Clean up the temporary image. Fix: only filesystem errors are
        # swallowed — the original bare `except:` hid every exception type.
        temp_path = generation_result.get("temp_image_path")
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass

    return StreamingResponse(event_generator(), media_type="text/event-stream")
421
+
422
+ if __name__ == '__main__':
423
+ uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
inference/chat.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# chat.py
import argparse
import os

from model_utils import SkinGPTModel


def main():
    """Interactive multi-turn chat loop seeded with a single skin image."""
    parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
    parser.add_argument("--model_path", type=str, default="../checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to initial image")
    args = parser.parse_args()

    # Load the model first (matches the original startup order).
    bot = SkinGPTModel(args.model_path)

    # System instruction prepended to the opening user turn.
    system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."

    if not os.path.exists(args.image):
        print(f"Error: Image {args.image} not found.")
        return

    # Seed the history with the image plus the opening request.
    history = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": args.image},
                {"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."}
            ]
        }
    ]

    print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
    print(f"Image loaded: {args.image}")

    # First diagnosis turn.
    print("\nModel is thinking...", end="", flush=True)
    reply = bot.generate_response(history)
    print(f"\rAssistant: {reply}\n")
    history.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

    # Interactive loop: each iteration appends one user and one assistant turn.
    while True:
        try:
            user_input = input("User: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            if not user_input.strip():
                continue

            history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})

            print("Model is thinking...", end="", flush=True)
            reply = bot.generate_response(history)
            print(f"\rAssistant: {reply}\n")
            history.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        except KeyboardInterrupt:
            break


if __name__ == "__main__":
    main()
inference/deepseek_service.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek API Service
3
+ Used to optimize and organize SkinGPT model output results
4
+ """
5
+
6
+ import os
7
+ import re
8
+ from typing import Optional
9
+ from openai import AsyncOpenAI
10
+
11
+
12
+ class DeepSeekService:
13
+ """DeepSeek API Service Class"""
14
+
15
    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize DeepSeek service.

        Parameters:
            api_key: DeepSeek API key; falls back to the DEEPSEEK_API_KEY
                environment variable when not provided.
        """
        # Explicit key argument wins; otherwise consult the environment.
        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
        self.base_url = "https://api.deepseek.com"
        self.model = "deepseek-chat"  # Using deepseek-chat model

        # The client is created lazily in load(); is_loaded reflects whether
        # that initialization succeeded.
        self.client = None
        self.is_loaded = False

        print(f"DeepSeek API service initializing...")
        print(f"API Base URL: {self.base_url}")
31
+
32
+ async def load(self):
33
+ """Initialize DeepSeek API client"""
34
+ try:
35
+ if not self.api_key:
36
+ print("DeepSeek API key not provided")
37
+ self.is_loaded = False
38
+ return
39
+
40
+ # Initialize OpenAI compatible client
41
+ self.client = AsyncOpenAI(
42
+ api_key=self.api_key,
43
+ base_url=self.base_url
44
+ )
45
+
46
+ self.is_loaded = True
47
+ print("DeepSeek API service is ready!")
48
+
49
+ except Exception as e:
50
+ print(f"DeepSeek API service initialization failed: {e}")
51
+ self.is_loaded = False
52
+
53
    async def refine_diagnosis(
        self,
        raw_answer: str,
        raw_thinking: Optional[str] = None,
        language: str = "zh"
    ) -> dict:
        """
        Use DeepSeek API to optimize and organize diagnosis results.

        Parameters:
            raw_answer: Original diagnosis result
            raw_thinking: AI thinking process (None when the model produced
                no <think> section)
            language: Language option ("zh" or "en")

        Returns:
            dict with keys: success, description, analysis_process,
            diagnosis_result, original_diagnosis, plus raw_refined on success
            or error on failure. On any failure the original raw_answer /
            raw_thinking are passed through unchanged so callers can degrade
            gracefully.
        """

        # Fallback path: API never initialized — return the raw inputs.
        if not self.is_loaded or self.client is None:
            error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
            print("DeepSeek API not initialized, returning original result")
            return {
                "success": False,
                "description": "",
                "analysis_process": raw_thinking or error_msg,
                "diagnosis_result": raw_answer,
                "original_diagnosis": raw_answer,
                "error": "DeepSeek API not initialized"
            }

        try:
            # Build prompt
            prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)

            # Select system prompt based on language
            if language == "en":
                system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
            else:
                system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"

            # Call DeepSeek API. Low temperature keeps the rewrite close to
            # the source text — this is reformatting, not free generation.
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=2048,
                top_p=0.8,
            )

            # Extract generated text
            generated_text = response.choices[0].message.content

            # Parse output into description / analysis / diagnosis sections.
            parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)

            return {
                "success": True,
                "description": parsed["description"],
                "analysis_process": parsed["analysis_process"],
                "diagnosis_result": parsed["diagnosis_result"],
                "original_diagnosis": raw_answer,
                "raw_refined": generated_text
            }

        except Exception as e:
            # Network/API failure: degrade to the unrefined text.
            print(f"DeepSeek API call failed: {e}")
            error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
            return {
                "success": False,
                "description": "",
                "analysis_process": raw_thinking or error_msg,
                "diagnosis_result": raw_answer,
                "original_diagnosis": raw_answer,
                "error": str(e)
            }
131
+
132
+ def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
133
+ """
134
+ Build optimization prompt
135
+
136
+ Parameters:
137
+ raw_answer: Original diagnosis result
138
+ raw_thinking: AI thinking process
139
+ language: Language option, "zh" for Chinese, "en" for English
140
+
141
+ Returns:
142
+ Built prompt
143
+ """
144
+ if language == "en":
145
+ # English prompt - organize and polish while preserving meaning
146
+ thinking_text = raw_thinking if raw_thinking else "No analysis process available."
147
+ prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
148
+
149
+ 【Requirements】
150
+ - Preserve the original tone and expression style
151
+ - Text 1 contains the thinking process, Text 2 contains the diagnosis result
152
+ - Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
153
+ - For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
154
+ - If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
155
+ - If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
156
+ - DO NOT infer or add new medical information, DO NOT output any meta-commentary
157
+ - You may adjust unreasonable statements or remove redundant content to improve clarity
158
+
159
+ [Text 1]
160
+ {thinking_text}
161
+
162
+ [Text 2]
163
+ {raw_answer}
164
+
165
+ 【Output】Only output three sections, do not output anything else:
166
+ ## Description
167
+ (Extract all image observation content from the thinking process - include all factual descriptions of what was seen)
168
+
169
+ ## Analysis Process
170
+ (Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)
171
+
172
+ ## Diagnosis Result
173
+ (The organized diagnosis result from Text 2)
174
+
175
+ 【Example】:
176
+ ## Description
177
+ The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
178
+
179
+ ## Analysis Process
180
+ These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
181
+
182
+ ## Diagnosis Result
183
+ Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
184
+ """
185
+ else:
186
+ # Chinese prompt - translate to Simplified Chinese AND organize/polish
187
+ thinking_text = raw_thinking if raw_thinking else "No analysis process available."
188
+ prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
189
+
190
+ 【要求】
191
+ - 保留原文的语气和表达方式
192
+ - 文本1是思考过程,文本2是诊断结果
193
+ - 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
194
+ - 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不���包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
195
+ - 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
196
+ - 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
197
+ - 【文本1】和【文本2】需要翻译成简体中文
198
+ - 禁止推断或添加新的医学信息,禁止输出任何元评论
199
+ - 可以调整不合理的语句或去除冗余内容以提高清晰度
200
+
201
+
202
+ 【文本1】
203
+ {thinking_text}
204
+
205
+ 【文本2】
206
+ {raw_answer}
207
+
208
+ 【输出】只输出三个部分,不要输出其他任何内容:
209
+ ## 图像描述
210
+ (从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)
211
+
212
+ ## 分析过程
213
+ (提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)
214
+
215
+ ## 诊断结果
216
+ (整理后的诊断结果)
217
+
218
+ 【样例】:
219
+ ## 图像描述
220
+ 图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
221
+
222
+ ## 分析过程
223
+ 这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
224
+
225
+ ## 诊断结果
226
+ 可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
227
+ """
228
+
229
+ return prompt
230
+
231
+ def _parse_refined_output(
232
+ self,
233
+ generated_text: str,
234
+ raw_answer: str,
235
+ raw_thinking: Optional[str] = None,
236
+ language: str = "zh"
237
+ ) -> dict:
238
+ """
239
+ Parse DeepSeek generated output
240
+
241
+ Parameters:
242
+ generated_text: DeepSeek generated text
243
+ raw_answer: Original diagnosis (as fallback)
244
+ raw_thinking: Original thinking process (as fallback)
245
+ language: Language option
246
+
247
+ Returns:
248
+ Dictionary containing description, analysis_process and diagnosis_result
249
+ """
250
+ description = ""
251
+ analysis_process = None
252
+ diagnosis_result = None
253
+
254
+ if language == "en":
255
+ # English patterns
256
+ desc_match = re.search(
257
+ r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
258
+ generated_text,
259
+ re.IGNORECASE
260
+ )
261
+ analysis_match = re.search(
262
+ r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
263
+ generated_text,
264
+ re.IGNORECASE
265
+ )
266
+ result_match = re.search(
267
+ r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
268
+ generated_text,
269
+ re.IGNORECASE
270
+ )
271
+
272
+ desc_header = "## Description"
273
+ analysis_header = "## Analysis Process"
274
+ result_header = "## Diagnosis Result"
275
+ else:
276
+ # Chinese patterns
277
+ desc_match = re.search(
278
+ r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
279
+ generated_text
280
+ )
281
+ analysis_match = re.search(
282
+ r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
283
+ generated_text
284
+ )
285
+ result_match = re.search(
286
+ r'##\s*诊断结果\s*\n([\s\S]*?)$',
287
+ generated_text
288
+ )
289
+
290
+ desc_header = "## 图像描述"
291
+ analysis_header = "## 分析过程"
292
+ result_header = "## 诊断结果"
293
+
294
+ # Extract description
295
+ if desc_match:
296
+ description = desc_match.group(1).strip()
297
+ print(f"Successfully parsed description")
298
+ else:
299
+ print(f"Description parsing failed")
300
+ description = ""
301
+
302
+ # Extract analysis process
303
+ if analysis_match:
304
+ analysis_process = analysis_match.group(1).strip()
305
+ print(f"Successfully parsed analysis process")
306
+ else:
307
+ print(f"Analysis process parsing failed, trying other methods")
308
+ # Try to extract from generated text
309
+ result_pos = generated_text.find(result_header)
310
+ if result_pos > 0:
311
+ # Get content before diagnosis result
312
+ analysis_process = generated_text[:result_pos].strip()
313
+ # Remove possible headers
314
+ for header in [desc_header, analysis_header]:
315
+ header_escaped = re.escape(header)
316
+ analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
317
+ else:
318
+ # If no format at all, try to get first half
319
+ mid_point = len(generated_text) // 2
320
+ analysis_process = generated_text[:mid_point].strip()
321
+
322
+ # If still empty, use original content (final fallback)
323
+ if not analysis_process and raw_thinking:
324
+ print(f"Using original raw_thinking as fallback")
325
+ analysis_process = raw_thinking
326
+
327
+ # Extract diagnosis result
328
+ if result_match:
329
+ diagnosis_result = result_match.group(1).strip()
330
+ print(f"Successfully parsed diagnosis result")
331
+ else:
332
+ print(f"Diagnosis result parsing failed, trying other methods")
333
+ # Try to extract from generated text
334
+ result_pos = generated_text.find(result_header)
335
+ if result_pos > 0:
336
+ diagnosis_result = generated_text[result_pos:].strip()
337
+ # Remove possible header
338
+ result_header_escaped = re.escape(result_header)
339
+ diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
340
+ else:
341
+ # If no format at all, get second half
342
+ mid_point = len(generated_text) // 2
343
+ diagnosis_result = generated_text[mid_point:].strip()
344
+
345
+ # If still empty, use original content (final fallback)
346
+ if not diagnosis_result:
347
+ print(f"Using original raw_answer as fallback")
348
+ diagnosis_result = raw_answer
349
+
350
+ return {
351
+ "description": description,
352
+ "analysis_process": analysis_process,
353
+ "diagnosis_result": diagnosis_result
354
+ }
355
+
356
+
357
# Module-level singleton holding the lazily created DeepSeek service.
_deepseek_service: Optional[DeepSeekService] = None


async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
    """
    Return the shared DeepSeekService instance, creating it on first use.

    Parameters:
        api_key: Optional API key forwarded to the service constructor.

    Returns:
        The singleton DeepSeekService — possibly with ``is_loaded`` False,
        meaning callers should use fallback mode — or None when
        construction/loading raised.
    """
    global _deepseek_service

    # Fast path: a previous call already built the singleton.
    if _deepseek_service is not None:
        return _deepseek_service

    try:
        # Assign before load() so a partially initialized instance is kept,
        # mirroring the lazy-singleton contract used elsewhere in the app.
        _deepseek_service = DeepSeekService(api_key=api_key)
        await _deepseek_service.load()
    except Exception as exc:
        print(f"DeepSeek service initialization failed: {exc}")
        return None

    if not _deepseek_service.is_loaded:
        # Return the instance anyway so callers can detect fallback mode.
        print("DeepSeek API service initialization failed, will use fallback mode")
    return _deepseek_service
inference/demo.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
3
+ from qwen_vl_utils import process_vision_info
4
+ from PIL import Image
5
+
6
# === Configuration ===
# Directory holding the fine-tuned Qwen2.5-VL checkpoint (relative to this file).
MODEL_PATH = "../checkpoint"
# Input image for the demo run.
IMAGE_PATH = "test_image.jpg"  # Please replace with your actual image path
# Instruction sent to the model together with the image.
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
10
+
11
def main():
    """Load the SkinGPT checkpoint, run one image through it and print the diagnosis.

    Reads MODEL_PATH / IMAGE_PATH / PROMPT module constants; exits early with a
    message if the model cannot be loaded or the image file is missing.
    """
    import os

    print(f"Loading model from {MODEL_PATH}...")

    # 1. Load model and processor; bail out on any loading failure.
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # 2. Check the demo image exists before doing any work.
    if not os.path.exists(IMAGE_PATH):
        print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
        return

    # 3. Build a single-turn chat message: one image plus the text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": IMAGE_PATH},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]

    print("Processing...")
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # 4. Generate (sampling parameters match model_utils.SkinGPTModel).
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            top_p=0.9,
            do_sample=True
        )

    # 5. Decode only the newly generated tokens. BUG FIX: previously the
    # whole sequence (prompt included) was decoded, so the printed
    # "Diagnosis Result" echoed the prompt — inconsistent with
    # SkinGPTModel.generate_response, which trims the input ids.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print("\n=== Diagnosis Result ===")
    print(output_text[0])
    print("========================")

if __name__ == "__main__":
    main()
inference/inference.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ from model_utils import SkinGPTModel
4
+ import os
5
+
6
def main():
    """CLI entry point: diagnose a single skin image with SkinGPT-R1.

    Flags: --image (required path), --model_path (checkpoint directory),
    --prompt (instruction text).
    """
    parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to the image")
    parser.add_argument("--model_path", type=str, default="../checkpoint")
    parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
    args = parser.parse_args()

    if not os.path.exists(args.image):
        print(f"Error: Image not found at {args.image}")
        return

    # 1. Load the model via model_utils so the transformers loading
    #    boilerplate lives in exactly one place.
    bot = SkinGPTModel(args.model_path)

    # 2. Assemble a single-turn conversation: the system instruction is
    #    prepended to the user prompt, alongside the image.
    system_prompt = "You are a professional AI dermatology assistant."
    user_content = [
        {"type": "image", "image": args.image},
        {"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
    ]
    messages = [{"role": "user", "content": user_content}]

    # 3. Run inference and print the framed result.
    print(f"\nAnalyzing {args.image}...")
    response = bot.generate_response(messages)

    separator = "-" * 40
    print(separator)
    print("Result:")
    print(response)
    print(separator)

if __name__ == "__main__":
    main()
inference/model_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_utils.py
2
+ import torch
3
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
4
+ from qwen_vl_utils import process_vision_info
5
+ from PIL import Image
6
+ import os
7
+ from threading import Thread
8
+
9
class SkinGPTModel:
    """Wrapper around the fine-tuned Qwen2.5-VL checkpoint used by SkinGPT.

    Selects a device (CUDA / MPS / CPU), loads the model and processor once
    in __init__, and exposes a blocking helper (generate_response) and a
    streaming helper (generate_response_stream).
    """

    def __init__(self, model_path, device=None):
        # Honour an explicit device choice; otherwise prefer CUDA over CPU.
        self.model_path = model_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model from {model_path} on {self.device}...")

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            # bfloat16 on accelerators; full float32 precision on CPU.
            torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
            # flash-attention is only available on CUDA builds.
            attn_implementation="flash_attention_2" if self.device == "cuda" else None,
            # device_map="auto" is not used for MPS; the model is moved
            # there manually below instead.
            device_map="auto" if self.device != "mps" else None,
            trust_remote_code=True
        )

        if self.device == "mps":
            self.model = self.model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            # Bound the visual token budget (units of 28x28-pixel patches).
            min_pixels=256*28*28,
            max_pixels=1280*28*28
        )
        print("Model loaded successfully.")

    def generate_response(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
        """
        Run one blocking generation over a multi-turn message history.

        messages format:
        [
            {'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
            {'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
        ]

        Returns the decoded assistant reply with the prompt tokens stripped.
        """
        # Render the chat template into a single prompt string.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Collect the image/video inputs referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                top_p=0.9,
                do_sample=True
            )

        # Decode only the newly generated tokens (drop the prompt prefix).
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return output_text[0]

    def generate_response_stream(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
        """
        Stream a generation response.

        Returns a generator that yields decoded text chunks one by one.
        Accepts the same message format as generate_response.
        """
        # Render the chat template into a single prompt string.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Collect the image/video inputs referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # TextIteratorStreamer turns generate() output into an iterator of
        # text chunks; skip_prompt drops the echoed input tokens.
        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # Prepare generation parameters (sampling settings mirror
        # generate_response, plus the streamer sink).
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "top_p": 0.9,
            "do_sample": True,
            "streamer": streamer,
        }

        # generate() blocks, so run it on a worker thread while this
        # generator consumes chunks from the streamer.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield chunks as they arrive.
        for text_chunk in streamer:
            yield text_chunk

        thread.join()
inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg ADDED
inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg ADDED
inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg ADDED
inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg ADDED