| import os |
| import sys |
|
|
# Directory holding this server file; other paths (dashboard, vendored sources) hang off it.
SERVER_DIR = os.path.split(os.path.abspath(__file__))[0]

# Optional vendored source tree. When it exists it is put at the very front of
# sys.path, so imports executed after this point resolve there first.
SOURCE_DIR = os.path.join(SERVER_DIR, "source")
if os.path.isdir(SOURCE_DIR):
    sys.path.insert(0, SOURCE_DIR)
|
|
| import re |
| import uuid |
| import asyncio |
| from contextlib import asynccontextmanager |
|
|
| import torch |
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from starlette.responses import HTMLResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
| from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams |
|
|
# Runtime configuration; each value can be overridden through the environment.
MODEL_PATH = os.getenv("MODEL_PATH", "/root/model")  # passed to vLLM as the model to load
HOST = os.getenv("HOST", "0.0.0.0")  # bind address for uvicorn
PORT = int(os.getenv("PORT", "8000"))  # bind port for uvicorn
DASHBOARD_DIR = os.path.join(SERVER_DIR, "dashboard")  # holds the index.html served at "/"
|
|
# Default system instructions, prepended by build_prompt() when the client sends
# no system message. Paragraphs are separated by blank lines ("\n\n").
SYSTEM_PROMPT = "\n\n".join(
    [
        "You are MedFound, a medical AI assistant. You provide helpful, "
        "accurate, and evidence-based medical information.",
        "Response format rules:",
        "1. Start with a clear, direct one-sentence answer to the question.",
        "2. Then organize the rest of your response into labeled sections. "
        "Use section headers like 'Definition:', 'Common uses:', "
        "'Drug class:', 'Symptoms:', 'Causes:', 'Treatment:', 'Safety:', "
        "'Key points:' etc. Put each section header on its own line "
        "followed by a newline.",
        "3. Under each section, list items one per line using '- ' bullet points.",
        "4. Leave a blank line between each section.",
        "5. Keep each bullet point short and clear (one idea per bullet).",
        "6. At the end, add a 'Safety note:' or 'Important:' section for warnings.",
        "7. Stay on topic. Only answer what was asked. Do NOT generate "
        "unrelated content, fictional patient cases, or diagnosis codes.",
        "8. Stop when you have fully answered. Do not keep writing.",
        "9. End with a one-line disclaimer that this is for informational "
        "purposes only.",
    ]
)
|
|
# Phrases that mark the model drifting past a useful answer. The list is passed
# to vLLM as stop strings in chat() and is also scanned, in order, by
# _truncate_hallucination().
STOP_PATTERNS = (
    # Invented clinical cases, coding output and section-outline filler.
    [
        "The following is a case",
        "The patient is a",
        "Diagnosis code:",
        "Treatment code:",
        "The following sections provide",
    ]
    # Role markers of the "### Role:" prompt format used by build_prompt().
    + ["### User:", "### System:"]
    # Reference / reading-list sections.
    + ["## Further reading", "## References", "Further reading"]
    # Meta commentary about example or sample responses.
    + [
        "End of interaction",
        "System Response",
        "The following is an example",
        "This sample response",
    ]
    # Citation-like fragments.
    + ["Asbury et al", "Lewis & Loftus"]
)
|
|
|
|
# Recognised section headers ("Definition:", "Side effects:", ...), matched
# case-insensitively at the start of the text, after sentence-ending ". " or
# at the start of a line. Group 1 (the only capturing group) is the header
# including its colon and trailing whitespace, so re.split() yields
# [preamble, header1, body1, header2, body2, ...].
#
# The "after a period" separator is a lookbehind, so the period stays with the
# preceding body instead of being consumed by the split.
_SECTION_RE = re.compile(
    r'(?:^|(?<=\.)\s+|\n\s*)'
    r'((?:Definition|Common uses?|Drug class|Symptoms?|Causes?|Diagnosis'
    r'|Treatments?|Safety(?: note)?|Important|Mechanism|Side effects?'
    r'|Precautions?|Dosage|Key points?|Overview|Risk factors?'
    r'|Complications?|Prevention|When to see a doctor|Warning'
    r'|How it works|What it(?:\'s| is) used for|Disclaimer)\s*:\s*)',
    re.IGNORECASE,
)
|
|
# Inline enumeration introduced by a trigger phrase, e.g.
# "such as fever, cough, and chills". Group 1 is the item list: at least 10
# non-terminal characters followed by one or more ", item" / ", and item" parts,
# running up to the next '.', '!' or '?'.
#
# The leading \b anchors the trigger to a word start, so "unlike", "dislike" or
# "alike" are not mistaken for the trigger "like".
_INLINE_LIST_RE = re.compile(
    r'\b(?:(?:such as|including|like|e\.g\.|for example|include)\s+)'
    r'([^.!?]{10,}(?:,\s*(?:and\s+)?[^.!?]+)+)',
    re.IGNORECASE,
)
|
|
|
|
# Leading "this is a sample/example response" style preamble: everything from
# the start of the text through the first line containing one of the marker
# phrases, plus the newline(s) after it. The marker line must be within the
# first four lines of the text (at most three lines before it). Without that
# bound, a phrase such as "individual responses may differ" deep inside a real
# answer would delete all of the answer above it.
_PREAMBLE_RE = re.compile(
    r'^(?:[^\n]*\n){0,3}?[^\n]*'
    r'(?:System Response|sample response|example (?:system )?response|'
    r'The exact response will vary|Responses? (?:may|should|will) differ)'
    r'[^\n]*\n+',
    re.IGNORECASE,
)
|
|
|
|
def _truncate_hallucination(text: str) -> str:
    """Cut *text* at known runaway markers, then drop a leading preamble.

    The patterns in ``STOP_PATTERNS`` are tried in list order. For each one,
    only its first occurrence is considered, and it truncates the text only
    when it lies beyond the first 15% of the text as it stands at that
    moment (earlier cuts shorten it). Occurrences inside that leading 15% are
    ignored. Afterwards a preamble matched by ``_PREAMBLE_RE`` is removed and
    the result is stripped of surrounding whitespace.
    """
    for marker in STOP_PATTERNS:
        cut_at = text.find(marker)  # -1 when absent, which never passes the check
        if cut_at > 0.15 * len(text):
            text = text[:cut_at]
    return _PREAMBLE_RE.sub('', text).strip()
|
|
|
|
def _trim_incomplete(text: str) -> str:
    """Drop a trailing unfinished sentence.

    Text that is empty or already ends with '.', '!' or '?' is returned as-is.
    Otherwise everything after the last such terminator is removed, unless the
    only terminator sits at index 0 or none exists, in which case the text is
    left unchanged.
    """
    if not text or text.endswith((".", "!", "?")):
        return text
    last_terminator = max(text.rfind(mark) for mark in ".!?")
    return text[:last_terminator + 1] if last_terminator > 0 else text
|
|
|
|
def _split_into_sections(text: str) -> list[tuple[str, str]]:
    """Split *text* into ``(header, body)`` pairs using ``_SECTION_RE``.

    Text before the first recognised header becomes a first section with an
    empty header; it is omitted when blank. Headers are returned without their
    trailing colon or surrounding whitespace. A header with nothing after it
    gets an empty body.
    """
    # _SECTION_RE has exactly one capturing group, so re.split returns
    # [preamble, header1, body1, header2, body2, ...]. parts[0] is always the
    # text *before* the first header (possibly ""), never a header itself, so
    # it is consumed unconditionally. Testing whether it "looks like" a header
    # would misalign every following header/body pair.
    parts = _SECTION_RE.split(text)
    sections: list[tuple[str, str]] = []

    preamble = parts[0].strip()
    if preamble:
        sections.append(("", preamble))

    for i in range(1, len(parts), 2):
        # Strip again after removing ':' so "Warning :" yields "Warning",
        # not "Warning " (which would break the later **bold** markup).
        header = parts[i].strip().rstrip(":").strip()
        body = parts[i + 1].strip() if i + 1 < len(parts) else ""
        if header:
            sections.append((header, body))

    return sections
|
|
|
|
def _expand_inline_lists(text: str) -> str:
    """Convert inline enumerations ('such as X, Y, Z, and W') into bullet points.

    Each match of ``_INLINE_LIST_RE`` whose item list splits into at least two
    non-empty items is replaced by:
    - the trigger phrase ("such as", "including", ...),
    - a blank line,
    - one ``- `` bullet per item.

    The punctuation that ends the enumeration is dropped. Any text after it
    starts a new paragraph instead of being glued onto the last bullet. Only
    the first character of each item is upper-cased, so acronyms and existing
    casing ("ACE inhibitors", "vitamin B12") are preserved; ``str.capitalize``
    would lowercase them. Matches with fewer than two items are left untouched.
    """
    pieces: list[str] = []
    pos = 0
    for m in _INLINE_LIST_RE.finditer(text):
        if m.start() < pos:
            continue
        items = [it.strip() for it in re.split(r',\s*(?:and\s+)?', m.group(1))]
        items = [it for it in items if it]
        if len(items) < 2:
            continue

        trigger = text[m.start():m.start(1)].rstrip()
        bullets = "\n".join(f"- {it[:1].upper()}{it[1:]}" for it in items)
        pieces.append(text[pos:m.start()])
        pieces.append(f"{trigger}\n\n{bullets}")

        # Skip the punctuation closing the enumeration and the whitespace after it.
        end = m.end()
        if end < len(text) and text[end] in ".!?":
            end += 1
        while end < len(text) and text[end].isspace():
            end += 1
        if end < len(text):
            pieces.append("\n\n")  # remaining text becomes its own paragraph
        pos = end

    pieces.append(text[pos:])
    return "".join(pieces)
|
|
|
|
def _sentences_to_bullets(text: str) -> str:
    """Turn a run of three or more sentences into a ``- `` bullet list.

    The text is split after '.', '!' or '?' followed by whitespace. With fewer
    than three pieces it is returned unchanged. Otherwise each non-empty
    sentence, minus trailing periods, becomes one bullet, and bullets are
    joined with newlines.

    A sentence shaped like ``Label: detail`` becomes ``- **Label**: detail``
    when all of these hold:
    - the label is non-blank and under 40 characters;
    - the colon is followed by whitespace or ends the sentence.

    Colons inside times, ratios or URLs ("8:30", "1:1", "https://") are
    therefore left alone.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) < 3:
        return text

    bullets: list[str] = []
    for raw in sentences:
        sentence = raw.strip().rstrip(".")
        if not sentence:
            continue
        label, colon, detail = sentence.partition(":")
        is_label = bool(
            colon
            and len(label) < 40
            and label.strip()
            and (not detail or detail[0].isspace())
        )
        if is_label:
            bullets.append(f"- **{label.strip()}**: {detail.strip()}".rstrip())
        else:
            bullets.append(f"- {sentence}")

    return "\n".join(bullets)
|
|
|
|
def _fix_numbered_list(text: str) -> str:
    """Put an inline numbered list ("1. X 2. Y 3. Z") one item per line.

    A list marker is a number followed by '.' and whitespace. It must stand at
    the start of the text or after whitespace or sentence punctuation.

    Only markers numbered consecutively from 1 (1, 2, 3, ...) count as list
    items; other numbers stay where they are. This stops an ordinary sentence
    that ends in a number ("a rate of 60 to 100. After exercise ...") from
    being turned into a list with its number rewritten. At least two such
    markers are required; otherwise *text* is returned unchanged.

    Text before the first marker is kept as an intro line. Each item is
    emitted as "<n>. <item>" on its own line.
    """
    markers: list[tuple[int, int]] = []  # (marker start, item text start)
    expected = 1
    # (?<![^\s.!?]) -> preceded by start of text, whitespace or . ! ?
    # (in particular not by another digit, so "100." never yields "00.").
    for m in re.finditer(r'(?<![^\s.!?])(\d+)\.\s+', text):
        if int(m.group(1)) == expected:
            markers.append((m.start(), m.end()))
            expected += 1
    if len(markers) < 2:
        return text

    lines: list[str] = []
    intro = text[:markers[0][0]].strip()
    if intro:
        lines.append(intro)
    for n, (_, item_start) in enumerate(markers, start=1):
        item_end = markers[n][0] if n < len(markers) else len(text)
        lines.append(f"{n}. {text[item_start:item_end].strip()}")

    return "\n".join(lines)
|
|
|
|
def _format_unsectioned(body: str) -> str:
    """Lay out text that contains no recognised section headers.

    Tries, in order, stopping at the first that applies:
    1. An inline numbered list is put one item per line.
    2. An inline "such as X, Y, Z" enumeration is turned into bullets.
    3. Four or more sentences become the first sentence as an intro,
       followed by the rest as bullets.

    If none applies, the stripped text is returned.
    """
    body = _fix_numbered_list(body)
    if re.search(r'\d+\.\s', body) and "\n" in body:
        return body.strip()

    body = _expand_inline_lists(body)
    if "\n- " in body:
        return body.strip()

    sentences = re.split(r'(?<=[.!?])\s+', body)
    if len(sentences) >= 4:
        intro = sentences[0]
        bullets = _sentences_to_bullets(" ".join(sentences[1:]))
        return f"{intro}\n\n{bullets}".strip()

    return body.strip()


def format_response(text: str) -> str:
    """Post-process raw model output into clean, GPT-style structured markdown.

    Steps:
    1. Cut hallucinated continuations and preambles, trim a trailing
       incomplete sentence, collapse runs of blank lines and put
       "  - item" bullets on their own lines.
    2. If the text already has paragraph breaks and ``- `` bullets, return it.
    3. If no section header was recognised, format the plain text (lists,
       enumerations, sentence bullets) and return it.
    4. Otherwise render each section as a bold header followed by its body,
       and append a disclaimer when none is detected.

    Args:
        text: Raw completion text from the model.

    Returns:
        The formatted text, or an empty string if nothing survives truncation.
    """
    text = _truncate_hallucination(text)
    if not text:
        return text

    text = _trim_incomplete(text)
    text = re.sub(r'\n{3,}', '\n\n', text)
    text = re.sub(r'\s{2,}-\s+', '\n- ', text)

    # The model already produced paragraphs and bullets: leave them as they are.
    if '\n\n' in text and re.search(r'\n- ', text):
        return text.strip()

    sections = _split_into_sections(text)
    if not sections:
        return text.strip()

    # Only take the plain-text path when there is no header to lose. A single
    # headed section (e.g. "Warning: ...") is rendered below so its header survives.
    if len(sections) == 1 and not sections[0][0]:
        return _format_unsectioned(sections[0][1])

    output_parts: list[str] = []
    for header, body in sections:
        if not header:
            # Intro text before the first header is emitted unchanged.
            if body:
                output_parts.append(body)
            continue

        body = _fix_numbered_list(body)
        body = _expand_inline_lists(body)

        # Single-line, multi-sentence bodies read better as bullets.
        if "\n" not in body and not body.startswith("-"):
            if len(re.split(r'(?<=[.!?])\s+', body)) >= 3:
                body = _sentences_to_bullets(body)

        output_parts.append(f"**{header}**\n\n{body}" if body else f"**{header}**")

    result = "\n\n".join(output_parts)

    if not re.search(
        r'(?i)disclaimer|informational purposes|not.{0,20}replace.{0,30}medical advice',
        result,
    ):
        result += (
            "\n\n⚠️ *This information is for educational purposes only "
            "and should not replace professional medical advice.*"
        )

    return result.strip()
|
|
|
|
# Shared vLLM engine: created by lifespan() at startup, reset to None at
# shutdown, and read by the request handlers (chat, _stream_response).
engine: "AsyncLLMEngine | None" = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared vLLM engine on startup and release it on shutdown.

    The engine is stored in the module-level ``engine`` global that the
    request handlers read. Cleanup runs in ``finally``, so the global is
    cleared and the CUDA cache emptied even when the application exits
    through an exception.
    """
    global engine
    print(f"Loading model from {MODEL_PATH} via vLLM ...")

    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=2048,  # context window: prompt + completion tokens
        gpu_memory_utilization=0.40,  # fraction of GPU memory vLLM may use
        enforce_eager=True,  # eager PyTorch execution, no CUDA graph capture
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    print("vLLM engine ready")
    print(f"Dashboard available at http://{HOST}:{PORT}/")
    try:
        yield
    finally:
        engine = None
        torch.cuda.empty_cache()
|
|
|
|
app = FastAPI(
    title="MedFound Medical Chatbot",
    description="Medical AI chatbot powered by MedFound-Llama3-8B (vLLM)",
    version="2.0.0",
    lifespan=lifespan,
)


# Open CORS so browser clients on any origin can call the API.
# Credentials stay disabled because FastAPI documents that allow_origins=["*"]
# cannot be combined with allow_credentials=True. No endpoint in this file
# reads cookies or authorization headers, so nothing here needs credentialed
# cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
class Message(BaseModel):
    # One conversation turn as sent by the client.
    # role is restricted to the three roles build_prompt() knows how to render.
    role: str = Field(..., pattern="^(user|assistant|system)$")
    content: str  # turn text; build_prompt() inserts it into the prompt as-is
|
|
|
|
class ChatRequest(BaseModel):
    # Conversation so far; the chat endpoints require the last turn to be from the user.
    messages: list[Message]
    # Upper bound on generated tokens (forwarded to vLLM as max_tokens).
    max_new_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    # vLLM's SamplingParams only accepts top_p in (0, 1]. Rejecting 0 here yields
    # a 422 validation error instead of a failure inside SamplingParams.
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)
    # When true, the reply is sent as a text/event-stream response.
    stream: bool = False
|
|
|
|
class ChatResponse(BaseModel):
    # Non-streaming reply body returned by POST /v1/chat (and /v1/diagnose).
    id: str  # request id: "med-" followed by 12 hex characters
    content: str  # model output after format_response() post-processing
    finish_reason: str  # vLLM's finish reason, or "stop" when vLLM reports none
    usage: dict  # prompt_tokens / completion_tokens / total_tokens counts
|
|
|
|
def build_prompt(messages: list[Message]) -> str:
    """Render a conversation into the "### Role: text" prompt format.

    SYSTEM_PROMPT is prepended as a system turn only when the conversation
    has no system message of its own. Each turn becomes
    "### <Role>: <content>" followed by a newline, and turns are joined with
    another newline, so there is a blank line between them. Messages with
    any other role are skipped. The prompt ends with an open
    "### Assistant:" turn for the model to complete.
    """
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}

    client_sent_system = any(m.role == "system" for m in messages)
    turns = [] if client_sent_system else [f"### System: {SYSTEM_PROMPT}\n"]
    turns.extend(
        f"### {role_labels[m.role]}: {m.content}\n"
        for m in messages
        if m.role in role_labels
    )
    turns.append("### Assistant:")
    return "\n".join(turns)
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: report status, model path, engine kind and CUDA memory.

    gpu_memory_used_mb is ``torch.cuda.memory_allocated()`` in MiB (rounded to
    one decimal), or None when CUDA is unavailable.
    NOTE(review): memory_allocated() only counts PyTorch allocations made by
    this process; if vLLM runs its workers in other processes this may not
    reflect the model's memory. Confirm.
    """
    gpu_mb = None
    if torch.cuda.is_available():
        gpu_mb = round(torch.cuda.memory_allocated() / 1024 / 1024, 1)

    return {
        "status": "ok",
        "model": MODEL_PATH,
        "engine": "vLLM",
        "gpu_memory_used_mb": gpu_mb,
    }
|
|
|
|
@app.post("/v1/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Generate an assistant reply for a conversation.

    How the request is handled:
    - The conversation is rendered with build_prompt().
    - It is sampled by vLLM using the request's max_new_tokens, temperature
      and top_p, a fixed repetition penalty of 1.15, and STOP_PATTERNS as
      stop strings.
    - With ``stream=True`` the SSE response from _stream_response() is returned.
    - Otherwise the final output is passed through format_response() and
      returned with token usage.

    Raises:
        HTTPException(400): the conversation is empty or its last turn is not from the user.
        HTTPException(503): the vLLM engine is not loaded (startup unfinished or shut down).
        HTTPException(500): the engine produced no output for the request.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    if engine is None:
        raise HTTPException(503, "Model is not loaded.")

    prompt = build_prompt(req.messages)
    request_id = f"med-{uuid.uuid4().hex[:12]}"

    sampling_params = SamplingParams(
        max_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stop=STOP_PATTERNS,
        repetition_penalty=1.15,
    )

    if req.stream:
        return _stream_response(request_id, prompt, sampling_params)

    # The last result yielded by the generator is used as the final output.
    final = None
    async for result in engine.generate(prompt, sampling_params, request_id):
        final = result
    if final is None or not final.outputs:
        raise HTTPException(500, "The model returned no output.")

    output = final.outputs[0]
    prompt_tokens = len(final.prompt_token_ids or [])
    completion_tokens = len(output.token_ids)

    return ChatResponse(
        id=request_id,
        content=format_response(output.text),
        finish_reason=output.finish_reason or "stop",
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
|
|
|
|
def _stream_response(request_id: str, prompt: str, sampling_params: SamplingParams):
    """Return a text/event-stream ``StreamingResponse`` for a chat request.

    format_response() rewrites the whole answer. It trims a trailing partial
    sentence, regroups sentences into bullets, and adds headers and a
    disclaimer. Its output for a partial completion is therefore generally
    not a prefix of its output for a longer one, and slicing successive
    formatted snapshots by length yields garbled text.

    Instead, the completion is generated to the end, formatted once, and sent
    as a single SSE event, followed by ``data: [DONE]``. A payload containing
    line breaks is sent as one ``data:`` line per text line; SSE clients join
    consecutive data lines with a newline.
    """
    def _sse_event(payload: str) -> str:
        # Split on every SSE line terminator (CRLF, CR, LF) so no raw line
        # break ends up inside a data field.
        lines = re.split(r'\r\n|\r|\n', payload)
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    async def token_generator():
        final_text = ""
        async for result in engine.generate(prompt, sampling_params, request_id):
            final_text = result.outputs[0].text
        formatted = format_response(final_text)
        if formatted:
            yield _sse_event(formatted)
        yield "data: [DONE]\n\n"

    return StreamingResponse(token_generator(), media_type="text/event-stream")
|
|
|
|
@app.post("/v1/diagnose")
async def diagnose(req: ChatRequest):
    """Ask for a diagnostic analysis of the medical record in the last user turn.

    The fixed analysis instruction is appended (after a blank line) to the
    last message in place, and the request is then handled by chat().

    Raises:
        HTTPException(400): the conversation is empty or its last turn is not from the user.
    """
    if not (req.messages and req.messages[-1].role == "user"):
        raise HTTPException(400, "Last message must be from the user.")

    record_turn = req.messages[-1]
    instruction = (
        "Please provide a detailed and comprehensive diagnostic analysis of this medical record."
    )
    record_turn.content = record_turn.content + "\n\n" + instruction
    return await chat(req)
|
|
|
|
@app.get("/", include_in_schema=False)
async def serve_dashboard():
    """Serve the dashboard page, DASHBOARD_DIR/index.html.

    Raises:
        HTTPException(404): the dashboard file does not exist.
    """
    dashboard_path = os.path.join(DASHBOARD_DIR, "index.html")
    try:
        # Explicit encoding: open()'s default depends on the host locale.
        with open(dashboard_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError:
        raise HTTPException(404, "Dashboard not found.") from None
    return HTMLResponse(content=html)
|
|
|
|
if __name__ == "__main__":
    # Run the ASGI app directly with uvicorn when this file is executed as a script.
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")
|
|