File size: 15,051 Bytes
9fb76f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
# app.py
import uvicorn
import os
import shutil
import uuid
import json
import re
import asyncio
from typing import Optional
from io import BytesIO
from contextlib import asynccontextmanager
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from model_utils import SkinGPTModel
from deepseek_service import get_deepseek_service, DeepSeekService

# === Configuration ===
MODEL_PATH = "../checkpoint"
TEMP_DIR = "./temp_uploads"
os.makedirs(TEMP_DIR, exist_ok=True)

# DeepSeek API key, read from the environment only. A secret must never be
# hard-coded as a fallback: it ships with every copy of the source code.
# NOTE(review): a key previously committed here should be treated as leaked
# and rotated. Confirm get_deepseek_service copes with an empty key (reports
# is_loaded=False rather than raising).
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")

# Global DeepSeek service instance; assigned by init_deepseek() at startup.
deepseek_service: Optional[DeepSeekService] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup: initialize the global DeepSeek service via ``init_deepseek``.
    Shutdown: print a shutdown notice.
    """
    # --- startup ---
    await init_deepseek()
    # --- serving requests ---
    yield
    # --- shutdown ---
    shutdown_notice = "\nShutting down service..."
    print(shutdown_notice)

app = FastAPI(
    title="SkinGPT-R1 皮肤诊断系统",
    description="智能皮肤诊断助手",
    version="1.0.0",
    lifespan=lifespan
)

# CORS configuration: let the frontend dev servers call this API.
# The list also contains "*", so every origin is accepted. When "*" is
# combined with allow_credentials=True, Starlette's CORSMiddleware echoes the
# caller's Origin back and allows credentials, letting any site make
# credentialed requests. No endpoint here reads cookies or auth headers, so
# credentials are only enabled when the list holds explicit origins only.
CORS_ALLOW_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
    "*",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory, process-local session state (lost on restart).
# chat_states: state_id -> list of chat messages passed to the model.
# pending_images: state_id -> path of an uploaded image not yet sent to the model.
chat_states = {}
pending_images = {}

def parse_diagnosis_result(raw_text: str) -> dict:
    """
    Split a model response into its reasoning and answer parts.

    The model is expected to emit ``<think>...</think><answer>...</answer>``.
    Output may be truncated (unclosed tags) or may start without the opening
    ``<think>`` tag (only ``</think>`` present); both cases are recovered.

    Args:
        raw_text: Raw text generated by the model.

    Returns:
        dict with keys:
            "thinking": the reasoning text, or None if none was found.
            "answer": the answer text. Falls back to ``raw_text`` with the
                reasoning and tags removed, or to ``raw_text`` itself if
                that leaves nothing. If it contains "Final Answer:" followed
                by text, only that trailing text is kept.
            "raw": ``raw_text`` unchanged.
    """
    thinking = None
    answer = None

    # --- reasoning section ---
    think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
    if think_match:
        thinking = think_match.group(1).strip()
    else:
        # Unclosed <think> (output truncated): take text up to <answer> or the end.
        unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
        if unclosed_think:
            thinking = unclosed_think.group(1).strip()
        elif '</think>' in raw_text:
            # Opening tag missing (e.g. if the prompt template pre-fills it):
            # everything before the first </think> is reasoning.
            thinking = raw_text.split('</think>', 1)[0].strip()

    # --- answer section ---
    answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
    if answer_match:
        answer = answer_match.group(1).strip()
    else:
        # Unclosed <answer> (output truncated): take everything after it.
        unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
        if unclosed_answer:
            answer = unclosed_answer.group(1).strip()

    if not answer:
        # No usable <answer>: strip reasoning and tags from the raw text.
        cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
        cleaned = re.sub(r'<think>[\s\S]*', '', cleaned)   # unclosed <think>
        cleaned = re.sub(r'^[\s\S]*</think>', '', cleaned)  # reasoning before a stray </think>
        cleaned = re.sub(r'</?answer>', '', cleaned).strip()
        answer = cleaned if cleaned else raw_text

    # Drop any tags that remain inside the extracted sections.
    answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
    if thinking:
        thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()

    # "Final Answer: X" -> keep only X. An empty trailing section is ignored
    # so the answer is never blanked out.
    final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
    if final_answer_match and final_answer_match.group(1).strip():
        answer = final_answer_match.group(1).strip()

    return {
        "thinking": thinking or None,
        "answer": answer,
        "raw": raw_text,
    }

def _load_skin_model(path):
    """Instantiate SkinGPTModel from ``path``, printing progress to stdout."""
    print("Initializing Model Service...")
    loaded = SkinGPTModel(path)
    print("Service Ready.")
    return loaded


# Load the model once at import time; the request handlers share this instance.
gpt_model = _load_skin_model(MODEL_PATH)

# Initialize the DeepSeek service (async); awaited from the app lifespan.
async def init_deepseek():
    """
    Create the global DeepSeek refinement service.

    Assigns the module-level ``deepseek_service``. Initialization failure is
    not fatal: the service is left as None and the streaming endpoint falls
    back to returning raw (unrefined) model output.
    """
    global deepseek_service
    print("\nInitializing DeepSeek service...")
    try:
        deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
    except Exception as e:
        # DeepSeek is optional; an init error must not abort app startup.
        print(f"DeepSeek initialization failed: {e}")
        deepseek_service = None
    if deepseek_service and deepseek_service.is_loaded:
        print("DeepSeek service is ready!")
    else:
        print("DeepSeek service not available, will return raw results")

@app.post("/v1/upload/{state_id}")
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
    """
    Accept an image upload for a chat session.

    The image is saved under TEMP_DIR and recorded as the pending image for
    ``state_id``; the next /v1/predict call for that session attaches it to
    the user message. Only one pending image is kept per session.

    Args:
        state_id: Session identifier from the URL path.
        file: The uploaded image.
        survey: Optional form field; accepted but currently unused.

    Returns:
        dict with a status message and the saved file path.

    Raises:
        HTTPException(500): if saving the file fails.
    """
    try:
        # Keep only a short ASCII-alphanumeric extension from the
        # client-supplied filename (which may be None) so it cannot inject
        # path characters; default to "jpg".
        _, ext = os.path.splitext(file.filename or "")
        ext = ext.lstrip(".")
        file_extension = ext if ext.isascii() and ext.isalnum() and len(ext) <= 10 else "jpg"
        unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
        file_path = os.path.join(TEMP_DIR, unique_name)

        def _save():
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        # Disk I/O is blocking; run it off the event loop.
        await run_in_threadpool(_save)

        # A newer upload replaces the pending one. The superseded file has
        # not been used by any predict call yet (predict pops it), so delete
        # it instead of leaking it in TEMP_DIR.
        previous = pending_images.get(state_id)
        pending_images[state_id] = file_path
        if previous:
            try:
                os.remove(previous)
            except OSError:
                pass  # best-effort cleanup

        # Start an empty history for a new session.
        chat_states.setdefault(state_id, [])

        return {"message": "Image uploaded successfully", "path": file_path}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")

@app.post("/v1/predict/{state_id}")
async def v1_predict(request: Request, state_id: str):
    """
    Run one chat turn for a session.

    Reads ``{"message": ...}`` from the JSON body. If an image was uploaded
    for this session since the last turn, it is attached to this turn's user
    message; on a session's first turn with an image, a system prompt is
    prepended to the text. The user turn and the model reply are appended to
    the session history.

    Returns:
        {"message": <model reply>}

    Raises:
        HTTPException(400): body is not a JSON object or lacks "message".
        HTTPException(500): model inference failed; the turn is rolled back.
    """
    # Catch Exception, not everything: a bare except would also turn
    # asyncio.CancelledError (client disconnect) into a 400.
    try:
        data = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    user_message = data.get("message", "")
    if not user_message:
        raise HTTPException(status_code=400, detail="Missing 'message' field")

    history = chat_states.get(state_id, [])
    current_content = []

    # 1. Attach the pending image, if any (consumed by this turn).
    img_path = pending_images.pop(state_id, None)
    if img_path is not None:
        current_content.append({"type": "image", "image": img_path})

        # First turn of the conversation: prepend the system prompt.
        if not history:
            system_prompt = "You are a professional AI dermatology assistant. "
            user_message = f"{system_prompt}\n\n{user_message}"

    # 2. Add the text.
    current_content.append({"type": "text", "text": user_message})

    # 3. Record the user turn.
    history.append({"role": "user", "content": current_content})
    chat_states[state_id] = history

    # 4. Run inference in the thread pool so the event loop is not blocked.
    try:
        response_text = await run_in_threadpool(
            gpt_model.generate_response,
            messages=history
        )
    except Exception as e:
        # Roll back on the local list: the session may have been reset (and
        # removed from chat_states) while inference was running.
        history.pop()
        # If the session is still live, re-queue the image so a retry can use
        # it, unless a newer upload has arrived meanwhile.
        if img_path is not None and chat_states.get(state_id) is history:
            pending_images.setdefault(state_id, img_path)
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")

    # 5. Record the reply. ``history`` is the list stored in chat_states, so
    # it is not re-assigned; re-assigning would resurrect a session that was
    # reset during inference.
    history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})

    return {"message": response_text}

@app.post("/v1/reset/{state_id}")
async def reset_chat(state_id: str):
    """
    Clear a chat session.

    Removes the session history and any pending (not yet used) upload, then
    deletes the uploaded image files they reference. Image paths in history
    come only from /v1/upload, so they are files this service created under
    TEMP_DIR. Deletion is best-effort.
    """
    history = chat_states.pop(state_id, None) or []

    # Collect images already consumed into the conversation history.
    image_paths = []
    for message in history:
        for part in message.get("content", []):
            if isinstance(part, dict) and part.get("type") == "image":
                image_paths.append(part["image"])

    pending = pending_images.pop(state_id, None)
    if pending is not None:
        image_paths.append(pending)

    for path in image_paths:
        try:
            os.remove(path)
        except OSError:
            pass  # file may already be gone; cleanup is best-effort

    return {"message": "Chat history reset"}

@app.get("/")
async def root():
    """Return basic service metadata (name, version, status, description)."""
    info = dict(
        name="SkinGPT-R1 皮肤诊断系统",
        version="1.0.0",
        status="running",
        description="智能皮肤诊断助手",
    )
    return info

@app.get("/health")
async def health_check():
    """Health probe.

    The model is loaded at module import, before the app can serve requests,
    so ``model_loaded`` is reported as a constant True.
    """
    payload = dict(status="healthy", model_loaded=True)
    return payload

@app.post("/diagnose/stream")
async def diagnose_stream(
    image: Optional[UploadFile] = File(None),
    text: str = Form(...),
    language: str = Form("zh"),
):
    """
    SSE streaming diagnosis endpoint used by the frontend.

    Accepts an optional image plus a text prompt. While the local model
    generates, each chunk is streamed as a ``{"type": "delta"}`` event. Then
    a single ``{"type": "final"}`` event is emitted (refined through DeepSeek
    when that service is available), or a ``{"type": "error"}`` event.

    Args:
        image: Optional uploaded image; converted to RGB.
        text: User prompt text.
        language: "zh" or "en"; any other value falls back to "zh".

    Returns:
        StreamingResponse with media type text/event-stream.

    Raises:
        HTTPException(400): if the uploaded image cannot be decoded.
    """
    from queue import Queue, Empty
    from threading import Thread

    language = language if language in ("zh", "en") else "zh"

    # Decode the optional image up front so a bad upload fails fast with a
    # 400 instead of an unhandled 500.
    pil_image = None
    if image:
        contents = await image.read()
        try:
            pil_image = Image.open(BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}")

    # Worker-thread -> event-generator channel: ("delta", chunk) items, then
    # exactly one terminal ("generation_done", None) or ("error", message).
    result_queue = Queue()
    # Written by the worker thread; read by the generator after it finishes.
    generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}

    def sse(payload):
        """Format one Server-Sent Events data frame."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    def run_generation():
        """Run streaming generation in a worker thread, feeding result_queue."""
        full_response = []
        try:
            system_prompt = (
                "You are a professional AI dermatology assistant."
                if language == "en"
                else "你是一个专业的AI皮肤科助手。"
            )
            current_content = []

            # Messages reference images by file path, so persist it first.
            if pil_image is not None:
                temp_path = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
                generation_result["temp_image_path"] = temp_path
                pil_image.save(temp_path)
                current_content.append({"type": "image", "image": temp_path})

            current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
            messages = [{"role": "user", "content": current_content}]

            # Forward every chunk to the event generator as soon as it arrives.
            for chunk in gpt_model.generate_response_stream(
                messages=messages,
                max_new_tokens=2048,
                temperature=0.7
            ):
                full_response.append(chunk)
                result_queue.put(("delta", chunk))

            generation_result["full_response"] = full_response
            generation_result["parsed"] = parse_diagnosis_result("".join(full_response))
            result_queue.put(("generation_done", None))

        except Exception as e:
            result_queue.put(("error", str(e)))
        finally:
            # The image is only needed during generation. Deleting it here
            # covers every exit path, including errors and client disconnects.
            temp_path = generation_result["temp_image_path"]
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass  # best-effort cleanup

    async def event_generator():
        """Relay worker output as SSE events, then emit the final result."""
        gen_thread = Thread(target=run_generation)
        gen_thread.start()

        loop = asyncio.get_running_loop()

        while True:
            try:
                # Wait briefly in an executor thread so the event loop stays free.
                msg_type, data = await loop.run_in_executor(
                    None,
                    lambda: result_queue.get(timeout=0.1)
                )
            except Empty:
                # The worker puts its terminal message before exiting, so a dead
                # worker plus an empty queue means it died without one (e.g. a
                # BaseException). Stop instead of polling forever.
                if not gen_thread.is_alive() and result_queue.empty():
                    yield sse({"type": "error", "message": "Generation stopped unexpectedly"})
                    return
                await asyncio.sleep(0.01)
                continue

            if msg_type == "generation_done":
                break
            if msg_type == "delta":
                yield sse({"type": "delta", "text": data})
            elif msg_type == "error":
                yield sse({"type": "error", "message": data})
                await loop.run_in_executor(None, gen_thread.join)
                return

        # The worker has sent its terminal message; join it without blocking the loop.
        await loop.run_in_executor(None, gen_thread.join)

        parsed = generation_result["parsed"]
        if not parsed:
            yield sse({"type": "error", "message": "Failed to parse response"})
            return

        raw_thinking = parsed["thinking"]
        raw_answer = parsed["answer"]

        # Refine with DeepSeek when available; fall back to the raw parse.
        refined_by_deepseek = False
        description = None
        thinking = raw_thinking
        answer = raw_answer

        if deepseek_service and deepseek_service.is_loaded:
            try:
                print(f"Calling DeepSeek to refine diagnosis (language={language})...")
                refined = await deepseek_service.refine_diagnosis(
                    raw_answer=raw_answer,
                    raw_thinking=raw_thinking,
                    language=language,
                )
                if refined["success"]:
                    description = refined["description"]
                    thinking = refined["analysis_process"]
                    answer = refined["diagnosis_result"]
                    refined_by_deepseek = True
                    print("DeepSeek refinement completed successfully")
            except Exception as e:
                print(f"DeepSeek refinement failed, using original: {e}")
        else:
            print("DeepSeek service not available, using raw results")

        success_msg = "Diagnosis completed" if language == "en" else "诊断完成"

        # Payload shape kept consistent with the reference project.
        final_payload = {
            "description": description,                  # image description (from DeepSeek, else None)
            "thinking": thinking,                        # analysis process
            "answer": answer,                            # diagnosis result
            "raw": parsed["raw"],                        # raw model response
            "refined_by_deepseek": refined_by_deepseek,  # whether DeepSeek refined it
            "success": True,
            "message": success_msg
        }
        yield sse({"type": "final", "result": final_payload})

    return StreamingResponse(event_generator(), media_type="text/event-stream")

if __name__ == '__main__':
    # Pass the app object rather than the "app:app" import string. With an
    # import string, uvicorn re-imports this file as module "app", which runs
    # the module-level SkinGPTModel(MODEL_PATH) load a second time. An app
    # object requires reload=False (and a single worker).
    uvicorn.run(app, host="0.0.0.0", port=5900, reload=False)