File size: 11,593 Bytes
52a881a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
from __future__ import annotations

import asyncio
import json
import os
import shutil
import sys
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from queue import Empty, Queue
from threading import Thread
from typing import Optional

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image

try:
    from .model_utils import (
        DEFAULT_DO_SAMPLE,
        DEFAULT_MAX_NEW_TOKENS,
        DEFAULT_MODEL_PATH,
        DEFAULT_REPETITION_PENALTY,
        QuantizedSkinGPTModel,
    )
except ImportError:
    from model_utils import (
        DEFAULT_DO_SAMPLE,
        DEFAULT_MAX_NEW_TOKENS,
        DEFAULT_MODEL_PATH,
        DEFAULT_REPETITION_PENALTY,
        QuantizedSkinGPTModel,
    )

try:
    from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
except ImportError:
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service

# Scratch directory for uploaded / temporary images, located one level above
# this file's directory. It is created at import time if it does not exist.
TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
# API key handed to get_deepseek_service(); None when the variable is unset.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")

# Populated by init_deepseek() during application startup; stays None until then.
deepseek_service: Optional[DeepSeekService] = None


def parse_diagnosis_result(raw_text: str) -> dict:
    """Split raw model output into its reasoning and answer parts.

    The text is expected to carry ``<think>...</think>`` and
    ``<answer>...</answer>`` sections, but either closing tag may be missing
    (e.g. truncated generation), so unclosed sections are accepted as well.
    When no answer section exists at all, the text left after removing the
    reasoning section is used, falling back to the raw text itself. A leading
    "Final Answer:" label (any case) is dropped from the answer.

    Returns:
        A dict with keys ``thinking`` (str or None), ``answer`` (str) and
        ``raw`` (the unmodified input).
    """
    import re

    tag_pattern = r"</?think>|</?answer>"

    def first_group(pattern):
        # Stripped first capture group of `pattern` in raw_text, or None.
        found = re.search(pattern, raw_text)
        return found.group(1).strip() if found else None

    thinking = first_group(r"<think>([\s\S]*?)</think>")
    answer = first_group(r"<answer>([\s\S]*?)</answer>")

    # Fall back to sections whose closing tag never arrived.
    if not thinking:
        open_thinking = first_group(r"<think>([\s\S]*?)(?=<answer>|$)")
        if open_thinking is not None:
            thinking = open_thinking

    if not answer:
        open_answer = first_group(r"<answer>([\s\S]*?)$")
        if open_answer is not None:
            answer = open_answer

    # Still nothing: treat whatever is outside the reasoning as the answer.
    if not answer:
        remainder = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
        remainder = re.sub(r"<think>[\s\S]*", "", remainder)
        remainder = re.sub(r"</?answer>", "", remainder).strip()
        answer = remainder if remainder else raw_text

    if answer:
        answer = re.sub(tag_pattern, "", answer).strip()
        labelled = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
        if labelled:
            answer = labelled.group(1).strip()

    if thinking:
        thinking = re.sub(tag_pattern, "", thinking).strip()

    return {"thinking": thinking or None, "answer": answer, "raw": raw_text}


# The model is constructed at import time, so importing this module blocks
# until QuantizedSkinGPTModel(DEFAULT_MODEL_PATH) returns. The single
# instance is shared by the /v1/predict and /diagnose/stream handlers.
print("Initializing INT4 Model Service...")
gpt_model = QuantizedSkinGPTModel(DEFAULT_MODEL_PATH)
print("INT4 service ready.")


async def init_deepseek():
    """Create the DeepSeek service and store it in the module-level global.

    Prints whether the service reports itself as loaded; the streaming
    diagnosis endpoint checks the same flag before attempting refinement.
    """
    global deepseek_service
    print("\nInitializing DeepSeek service...")
    deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)

    service_ready = deepseek_service and deepseek_service.is_loaded
    if not service_ready:
        print("DeepSeek service not available, will return raw results")
        return
    print("DeepSeek service is ready!")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook passed to FastAPI: set up DeepSeek, then log on shutdown.

    Everything before ``yield`` runs before the app starts serving; the code
    after it runs when the app shuts down. ``app`` is not used here.
    """
    await init_deepseek()
    yield
    print("\nShutting down INT4 service...")


# ASGI application; `lifespan` initialises the DeepSeek service at startup.
app = FastAPI(
    title="SkinGPT-R1 INT4 API",
    description="INT4 quantized dermatology assistant backend",
    version="1.1.0",
    lifespan=lifespan,
)

# NOTE(review): the trailing "*" appears to make the explicit origins
# redundant, and a wildcard origin together with allow_credentials=True is
# usually not what a production deployment wants - confirm against the
# Starlette CORSMiddleware documentation before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory, per-process conversation state keyed by the client's state_id.
# chat_states: state_id -> list of chat messages sent to the model.
# pending_images: state_id -> path of an uploaded image not yet attached to
# a message (consumed by the next /v1/predict call for that state_id).
chat_states = {}
pending_images = {}


@app.post("/v1/upload/{state_id}")
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
    """Store an uploaded image and queue it for the next predict call.

    The file is written to TEMP_DIR under a unique name and registered in
    ``pending_images`` for ``state_id``; an empty chat history is created for
    the state if none exists yet.

    Args:
        state_id: Conversation identifier taken from the URL.
        file: The uploaded image.
        survey: Accepted as a form field but not used.

    Returns:
        A dict with a confirmation ``message`` and the stored file ``path``.

    Raises:
        HTTPException: 500 if the file cannot be stored.
    """
    del survey
    try:
        # The filename is client-controlled and may be missing; only keep a
        # plain ASCII alphanumeric extension so it cannot carry path
        # separators or other surprises into the stored file name.
        filename = file.filename or ""
        file_extension = filename.rsplit(".", 1)[-1] if "." in filename else "jpg"
        if not (file_extension.isascii() and file_extension.isalnum()):
            file_extension = "jpg"
        unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
        file_path = TEMP_DIR / unique_name

        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        previous_path = pending_images.get(state_id)
        pending_images[state_id] = str(file_path)
        chat_states.setdefault(state_id, [])

        # An image that was uploaded but never consumed is now unreachable;
        # remove it so repeated uploads do not pile up in TEMP_DIR.
        if previous_path:
            try:
                Path(previous_path).unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; the new upload already succeeded

        return {"message": "Image uploaded successfully", "path": str(file_path)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc


@app.post("/v1/predict/{state_id}")
async def v1_predict(request: Request, state_id: str):
    """Run one chat turn for ``state_id`` and return the model's reply.

    The JSON body must be an object with a non-empty ``message``. An image
    previously uploaded for this state is attached to the turn; when that
    happens on the very first turn, a short role instruction is prepended to
    the message. Both the user turn and the reply are appended to the stored
    history.

    Returns:
        ``{"message": <model reply>}``.

    Raises:
        HTTPException: 400 for an unparsable body or a missing message,
            500 if inference fails (the user turn is rolled back).
    """
    try:
        data = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON") from exc

    # Valid JSON that is not an object (list, string, number) has no .get();
    # report it as a client error instead of crashing with a 500.
    user_message = data.get("message", "") if isinstance(data, dict) else ""
    if not user_message:
        raise HTTPException(status_code=400, detail="Missing 'message' field")

    history = chat_states.get(state_id, [])
    current_content = []

    img_path = pending_images.pop(state_id, None)
    if img_path is not None:
        current_content.append({"type": "image", "image": img_path})
        if not history:
            user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"

    current_content.append({"type": "text", "text": user_message})
    history.append({"role": "user", "content": current_content})
    chat_states[state_id] = history

    try:
        response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
    except Exception as exc:
        # Roll back via the local list: chat_states[state_id] may have been
        # removed by a reset while inference was running.
        history.pop()
        # Re-queue the image so a retry still sees it, unless a newer upload
        # has arrived in the meantime.
        if img_path is not None:
            pending_images.setdefault(state_id, img_path)
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc

    history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
    chat_states[state_id] = history
    return {"message": response_text}


@app.post("/v1/reset/{state_id}")
async def reset_chat(state_id: str):
    """Forget the conversation for ``state_id`` and delete its image files.

    Removes the stored history and any pending upload. Image files are
    deleted both for the pending upload and for images already attached to
    history turns, since nothing references them once the history is gone.
    Unknown state ids are not an error.

    Returns:
        ``{"message": "Chat history reset"}``.
    """
    history = chat_states.pop(state_id, None) or []
    stale_paths = []

    pending_path = pending_images.pop(state_id, None)
    if pending_path:
        stale_paths.append(pending_path)

    # Images consumed by earlier predict calls live inside the history turns.
    for turn in history:
        for part in turn.get("content", []):
            if part.get("type") == "image" and part.get("image"):
                stale_paths.append(part["image"])

    for stale_path in stale_paths:
        try:
            Path(stale_path).unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the reset itself has already happened

    return {"message": "Chat history reset"}


@app.get("/")
async def root():
    """Return static metadata describing this service."""
    service_info = {
        "name": "SkinGPT-R1 INT4 API",
        "version": "1.1.0",
        "status": "running",
        "description": "INT4 quantized dermatology assistant",
    }
    return service_info


@app.get("/health")
async def health_check():
    """Return a fixed health report; ``model_loaded`` is always True here."""
    report = {"status": "healthy"}
    report["model_loaded"] = True
    return report


@app.post("/diagnose/stream")
async def diagnose_stream(
    image: Optional[UploadFile] = File(None),
    text: str = Form(...),
    language: str = Form("zh"),
):
    """Stream a single-turn diagnosis as server-sent events.

    Generation runs in a worker thread that pushes chunks onto a queue; the
    response generator forwards them as ``delta`` events, then parses the
    full text, optionally refines it through DeepSeek, and emits one
    ``final`` event. Failures are reported as an ``error`` event.

    Args:
        image: Optional image to include in the prompt.
        text: The user's question.
        language: "zh" or "en"; any other value is treated as "zh".

    Returns:
        A ``text/event-stream`` StreamingResponse.

    Raises:
        HTTPException: 400 if the uploaded image cannot be decoded.
    """
    language = language if language in ("zh", "en") else "zh"
    pil_image = None

    if image:
        contents = await image.read()
        try:
            pil_image = Image.open(BytesIO(contents)).convert("RGB")
        except Exception as exc:
            # An undecodable upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc

    result_queue = Queue()
    generation_result = {"full_response": [], "parsed": None}

    def run_generation():
        # Worker thread: produce chunks, then the parsed result, via the queue.
        full_response = []
        temp_image_path = None
        try:
            messages = []
            current_content = []
            system_prompt = (
                "You are a professional AI dermatology assistant."
                if language == "en"
                else "你是一个专业的AI皮肤科助手。"
            )

            if pil_image:
                temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
                pil_image.save(temp_image_path)
                current_content.append({"type": "image", "image": str(temp_image_path)})

            current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
            messages.append({"role": "user", "content": current_content})

            for chunk in gpt_model.generate_response_stream(
                messages=messages,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=DEFAULT_DO_SAMPLE,
                repetition_penalty=DEFAULT_REPETITION_PENALTY,
            ):
                full_response.append(chunk)
                result_queue.put(("delta", chunk))

            response_text = "".join(full_response)
            generation_result["full_response"] = full_response
            generation_result["parsed"] = parse_diagnosis_result(response_text)
            result_queue.put(("generation_done", None))
        except Exception as exc:
            result_queue.put(("error", str(exc)))
        finally:
            # The image file is only needed while the model is generating.
            # Removing it here covers every exit: success, generation error,
            # and a client that disconnected while the thread kept running.
            if temp_image_path is not None:
                try:
                    temp_image_path.unlink(missing_ok=True)
                except OSError:
                    pass  # best-effort cleanup

    async def event_generator():
        gen_thread = Thread(target=run_generation)
        gen_thread.start()

        loop = asyncio.get_running_loop()
        while True:
            try:
                # Poll the thread-safe queue off the event loop so other
                # requests keep being served while we wait for a chunk.
                msg_type, data = await loop.run_in_executor(
                    None,
                    lambda: result_queue.get(timeout=0.1),
                )
                if msg_type == "generation_done":
                    break
                if msg_type == "delta":
                    yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
                elif msg_type == "error":
                    yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
                    gen_thread.join()
                    return
            except Empty:
                await asyncio.sleep(0.01)

        gen_thread.join()
        parsed = generation_result["parsed"]
        if not parsed:
            yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
            return

        raw_thinking = parsed["thinking"]
        raw_answer = parsed["answer"]
        refined_by_deepseek = False
        description = None
        thinking = raw_thinking
        answer = raw_answer

        if deepseek_service and deepseek_service.is_loaded:
            try:
                refined = await deepseek_service.refine_diagnosis(
                    raw_answer=raw_answer,
                    raw_thinking=raw_thinking,
                    language=language,
                )
                if refined["success"]:
                    description = refined["description"]
                    thinking = refined["analysis_process"]
                    answer = refined["diagnosis_result"]
                    refined_by_deepseek = True
            except Exception as exc:
                # Refinement is optional; fall back to the model's own output.
                print(f"DeepSeek refinement failed, using original: {exc}")
        else:
            print("DeepSeek service not available, using raw results")

        final_payload = {
            "description": description,
            "thinking": thinking,
            "answer": answer,
            "raw": parsed["raw"],
            "refined_by_deepseek": refined_by_deepseek,
            "success": True,
            "message": "Diagnosis completed" if language == "en" else "诊断完成",
        }
        yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


def main() -> None:
    """Serve the API with uvicorn on 0.0.0.0:5901.

    The application object is passed directly instead of the "app:app"
    import string. With the string form, running this file as a script makes
    uvicorn import the module a second time under the name ``app``, which
    repeats the module-level model construction, and it only works when the
    file is named ``app.py`` and is importable from the working directory.
    """
    uvicorn.run(app, host="0.0.0.0", port=5901, reload=False)


if __name__ == "__main__":
    main()