import os
import sys

SERVER_DIR = os.path.dirname(os.path.abspath(__file__))
SOURCE_DIR = os.path.join(SERVER_DIR, "source")
# Put ./source first on sys.path so modules there take precedence over
# installed packages with the same name.
if os.path.isdir(SOURCE_DIR):
    sys.path.insert(0, SOURCE_DIR)

import re
import uuid
import asyncio
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams

MODEL_PATH = os.environ.get("MODEL_PATH", "/root/model")
HOST = os.environ.get("HOST", "0.0.0.0")
PORT = int(os.environ.get("PORT", "8000"))
DASHBOARD_DIR = os.path.join(SERVER_DIR, "dashboard")

# Default system prompt; build_prompt() injects it only when the client did
# not send a "system" message of its own.
SYSTEM_PROMPT = (
    "You are MedFound, a medical AI assistant. You provide helpful, accurate, "
    "and evidence-based medical information.\n\n"
    "Response format rules:\n\n"
    "1. Start with a clear, direct one-sentence answer to the question.\n\n"
    "2. Then organize the rest of your response into labeled sections. "
    "Use section headers like 'Definition:', 'Common uses:', 'Drug class:', "
    "'Symptoms:', 'Causes:', 'Treatment:', 'Safety:', 'Key points:' etc. "
    "Put each section header on its own line followed by a newline.\n\n"
    "3. Under each section, list items one per line using '- ' bullet points.\n\n"
    "4. Leave a blank line between each section.\n\n"
    "5. Keep each bullet point short and clear (one idea per bullet).\n\n"
    "6. At the end, add a 'Safety note:' or 'Important:' section for warnings.\n\n"
    "7. Stay on topic. Only answer what was asked. "
    "Do NOT generate unrelated content, fictional patient cases, or diagnosis codes.\n\n"
    "8. Stop when you have fully answered. Do not keep writing.\n\n"
    "9. End with a one-line disclaimer that this is for informational purposes only."
)

# Phrases that signal the model drifting into fabricated or unrelated
# content. Used as vLLM stop strings in chat() and again by
# _truncate_hallucination() as a post-processing safety net.
STOP_PATTERNS = [
    "The following is a case",
    "The patient is a",
    "Diagnosis code:",
    "Treatment code:",
    "The following sections provide",
    "### User:",
    "### System:",
    "## Further reading",
    "## References",
    "Further reading",
    "End of interaction",
    "System Response",
    "The following is an example",
    "This sample response",
    "Asbury et al",
    "Lewis & Loftus",
]

# A known section header ("Symptoms:", "Safety note:" ...) at the start of the
# text, after whitespace that follows a period, or after a newline. The
# lookbehind keeps the previous sentence's period out of the match. The single
# capturing group is the header incl. its colon, so re.split() returns
# [preamble, header1, body1, header2, body2, ...].
_SECTION_RE = re.compile(
    r'(?:^|(?<=\.)\s+|\n\s*)'
    r'((?:Definition|Common uses?|Drug class|Symptoms?|Causes?|Diagnosis'
    r'|Treatments?|Safety(?: note)?|Important|Mechanism|Side effects?'
    r'|Precautions?|Dosage|Key points?|Overview|Risk factors?'
    r'|Complications?|Prevention|When to see a doctor|Warning'
    r'|How it works|What it(?:\'s| is) used for|Disclaimer)\s*:\s*)',
    re.IGNORECASE,
)

# "such as A, B, and C" style inline lists. Group 1 is the comma-separated
# items (up to the sentence end). The trailing sentence punctuation and
# whitespace are consumed too, so whatever follows can start a new paragraph.
# \b stops triggers firing inside words (e.g. "like" in "unlike").
_INLINE_LIST_RE = re.compile(
    r'\b(?:(?:such as|including|like|e\.g\.|for example|include)\s+)'
    r'([^.!?]{10,}(?:,\s*(?:and\s+)?[^.!?]+)+)'
    r'(?:[.!?](?=\s|$))?\s*',
    re.IGNORECASE,
)

# Everything from the start of the text through the first line that looks like
# meta-commentary about the response ("System Response", "sample response"...).
_PREAMBLE_RE = re.compile(
    r'^(?:[^\n]*\n)*?[^\n]*'
    r'(?:System Response|sample response|example (?:system )?response|'
    r'The exact response will vary|Responses? (?:may|should|will) differ)'
    r'[^\n]*\n+',
    re.IGNORECASE,
)

# Sentence-ending punctuation plus any closing quotes/brackets/markdown
# emphasis that belong to the same sentence (e.g. 'advice.)' or 'done.**').
_SENTENCE_END_RE = re.compile(r'[.!?]["\'”’)\]*_]*')

# Detects an existing disclaimer so the standard one is not added twice.
_DISCLAIMER_RE = re.compile(
    r'(?i)disclaimer|informational purposes|not.{0,20}replace.{0,30}medical advice'
)
_DISCLAIMER = (
    "\n\n⚠️ *This information is for educational purposes only "
    "and should not replace professional medical advice.*"
)


def _truncate_hallucination(text: str) -> str:
    """Cut hallucinated continuations and strip a meta-commentary preamble.

    Each STOP_PATTERNS hit truncates the text, but only when it occurs past
    the first 15% of the (possibly already shortened) text. An early,
    legitimate use of the phrase therefore does not wipe the whole answer.
    """
    for pattern in STOP_PATTERNS:
        idx = text.find(pattern)
        if idx > len(text) * 0.15:
            text = text[:idx]
    # NOTE(review): _PREAMBLE_RE is anchored at the start but may skip any
    # number of lines, so a matching line late in an answer also removes
    # everything above it — confirm this reach is intended.
    text = _PREAMBLE_RE.sub('', text)
    return text.strip()


def _trim_incomplete(text: str) -> str:
    """Drop a trailing sentence fragment (e.g. after hitting max_tokens).

    Text ending in sentence punctuation (optionally followed by closing
    quotes, brackets or markdown emphasis) is returned unchanged. Otherwise
    it is cut after the last such sentence end. If there is none, or the
    only one is at index 0, the text is returned as-is.
    """
    if not text:
        return text
    ends = list(_SENTENCE_END_RE.finditer(text))
    if not ends:
        return text
    last = ends[-1]
    if last.end() == len(text):
        return text
    if last.start() > 0:
        return text[:last.end()]
    return text


def _split_into_sections(text: str) -> list[tuple[str, str]]:
    """Split text into (header, body) pairs.

    The first pair has an empty header and holds any text before the first
    recognised header. It is omitted when that text is empty.
    """
    parts = _SECTION_RE.split(text)
    # re.split with one capturing group always yields an odd-length list
    # [preamble, h1, b1, h2, b2, ...]. parts[0] is always the preamble, even
    # when it looks like a bare header word ("Overview"). Treating it as a
    # header would shift every following header/body pair by one.
    sections: list[tuple[str, str]] = []
    preamble = parts[0].strip()
    if preamble:
        sections.append(("", preamble))
    for raw_header, raw_body in zip(parts[1::2], parts[2::2]):
        # Remove the colon and any space before it ("Safety note :").
        header = raw_header.strip().rstrip(":").rstrip()
        sections.append((header, raw_body.strip()))
    return sections


def _expand_inline_lists(text: str) -> str:
    """Convert 'such as X, Y, Z, and W' into bullet points.

    The trigger phrase stays at the end of the preceding text. Each item
    becomes a '- Item' line, and any text after the list starts a new
    paragraph instead of being appended to the last bullet.
    """
    def _replacer(m: re.Match) -> str:
        raw_items = re.split(r',\s*(?:(?:and|or)\s+)?', m.group(1))
        items = [it.strip().rstrip(".").strip() for it in raw_items]
        items = [it for it in items if it]
        if len(items) < 2:
            return m.group(0)
        prefix = m.string[m.start(0):m.start(1)].rstrip()
        # Upper-case only the first character: str.capitalize() would also
        # lower-case the rest and mangle acronyms ("MRI" -> "Mri").
        bullet_block = "\n".join(f"- {it[:1].upper()}{it[1:]}" for it in items)
        trailer = "\n\n" if m.end() < len(m.string) else ""
        return f"{prefix}\n\n{bullet_block}{trailer}"

    return _INLINE_LIST_RE.sub(_replacer, text)


def _sentences_to_bullets(text: str) -> str:
    """Convert run-on sentences into bullet points when there are many
    comma-separated or short-sentence items.

    Text with fewer than 3 sentences is returned unchanged. A sentence of
    the form 'Label: rest' with a label under 40 chars becomes
    '- **Label**: rest'.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) < 3:
        return text
    bullets: list[str] = []
    for s in sentences:
        s = s.strip().rstrip(".")
        if not s:
            continue
        parts = s.split(":", 1)
        if len(parts) == 2 and len(parts[0]) < 40:
            bullets.append(f"- **{parts[0].strip()}**: {parts[1].strip()}")
        else:
            bullets.append(f"- {s}")
    return "\n".join(bullets)


def _fix_numbered_list(text: str) -> str:
    """Fix broken inline numbered lists (1. X 2. Y) into proper line-separated format.

    Items are renumbered consecutively from 1. Text without an 'N. ' marker,
    or that does not split into at least two pieces, is returned unchanged.
    """
    if not re.search(r'\d+\.\s', text):
        return text
    # Prefer splitting right after sentence punctuation; fall back to any
    # whitespace before an 'N. ' marker.
    items = re.split(r'(?<=[.!?])\s*(?=\d+\.\s)', text)
    if len(items) < 2:
        items = re.split(r'\s+(?=\d+\.\s)', text)
    if len(items) < 2:
        return text
    result: list[str] = []
    counter = 0
    for item in items:
        item = item.strip()
        if not item:
            continue
        cleaned = re.sub(r'^\d+\.\s*', '', item)
        if cleaned != item:
            counter += 1
            result.append(f"{counter}. {cleaned}")
        else:
            result.append(item)
    return "\n".join(result)


def _ensure_disclaimer(text: str) -> str:
    """Return *text* stripped, with the standard disclaimer appended when non-empty text lacks one."""
    text = text.strip()
    if text and not _DISCLAIMER_RE.search(text):
        text += _DISCLAIMER
    return text


def _format_plain_body(body: str) -> str:
    """Structure a single header-less paragraph (numbered list, inline list,
    or intro sentence followed by bullets when it has 4+ sentences)."""
    body = _fix_numbered_list(body)
    if re.search(r'\d+\.\s', body) and "\n" in body:
        return body.strip()
    body = _expand_inline_lists(body)
    if "\n- " in body:
        return body.strip()
    sentences = re.split(r'(?<=[.!?])\s+', body)
    if len(sentences) >= 4:
        intro = sentences[0]
        bullets = _sentences_to_bullets(" ".join(sentences[1:]))
        return f"{intro}\n\n{bullets}".strip()
    return body.strip()


def _format_sections(sections: list[tuple[str, str]]) -> str:
    """Render (header, body) pairs as '**Header**' blocks separated by blank lines.

    A header-less (preamble) pair is emitted as plain text. Single-line
    bodies with 3+ sentences are turned into bullets.
    """
    output_parts: list[str] = []
    for header, body in sections:
        if not header:
            if body:
                output_parts.append(body)
            continue
        body = _fix_numbered_list(body)
        body = _expand_inline_lists(body)
        if "\n" not in body and not body.startswith("-"):
            if len(re.split(r'(?<=[.!?])\s+', body)) >= 3:
                body = _sentences_to_bullets(body)
        output_parts.append(f"**{header}**\n\n{body}" if body else f"**{header}**")
    return "\n\n".join(output_parts)


def format_response(text: str, *, trim_incomplete: bool = True) -> str:
    """Post-process model output into clean, GPT-style structured formatting.

    Steps:
        1. Cut hallucinated tails and meta preambles.
        2. Optionally drop a trailing sentence fragment.
        3. Normalise blank lines and inline bullets.
        4. Layout: keep text that is already paragraphed and bulleted,
           restructure a single header-less paragraph, or render each
           detected section as a bold header plus a (bulleted) body.
        5. Append a short disclaimer to any non-empty result that lacks one.

    Args:
        text: Raw model output.
        trim_incomplete: Drop a trailing sentence fragment. Pass False when
            the caller knows generation ended naturally (EOS). Complete
            answers whose last line is an unpunctuated bullet are then
            kept whole.

    Returns:
        The formatted text, or "" if nothing survives cleanup.
    """
    text = _truncate_hallucination(text)
    if not text:
        return text
    if trim_incomplete:
        text = _trim_incomplete(text)
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Re-split bullets that were flattened onto one line ("a  - b  - c").
    text = re.sub(r'\s{2,}-\s+', '\n- ', text)

    # Already paragraphed and bulleted: keep the model's own layout.
    if '\n\n' in text and re.search(r'\n- ', text):
        return _ensure_disclaimer(text)

    sections = _split_into_sections(text)
    if not sections:
        return text.strip()
    if len(sections) == 1 and not sections[0][0]:
        return _ensure_disclaimer(_format_plain_body(sections[0][1]))
    # Headed sections (including a single headed one, whose header used to
    # be dropped) are rendered with bold headers.
    return _ensure_disclaimer(_format_sections(sections))


# The shared vLLM engine. It is created in lifespan() at startup and reset to
# None at shutdown.
engine = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the vLLM engine before serving and release it on shutdown."""
    global engine
    print(f"Loading model from {MODEL_PATH} via vLLM ...")
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=2048,
        gpu_memory_utilization=0.40,
        enforce_eager=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("vLLM engine ready")
    print(f"Dashboard available at http://{HOST}:{PORT}/")
    yield
    engine = None
    torch.cuda.empty_cache()


app = FastAPI(
    title="MedFound Medical Chatbot",
    description="Medical AI chatbot powered by MedFound-Llama3-8B (vLLM)",
    version="2.0.0",
    lifespan=lifespan,
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests, i.e. any
# site may call this API with the user's cookies. That is harmless while no
# cookie/session auth exists. Restrict allow_origins before adding any.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Message(BaseModel):
    """One chat turn; role must be 'user', 'assistant' or 'system'."""

    role: str = Field(..., pattern="^(user|assistant|system)$")
    content: str


class ChatRequest(BaseModel):
    """Request body for /v1/chat and /v1/diagnose."""

    messages: list[Message]
    max_new_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    stream: bool = False  # True -> text/event-stream response


class ChatResponse(BaseModel):
    """Non-streaming chat reply with token usage counts."""

    id: str
    content: str
    finish_reason: str
    usage: dict


def build_prompt(messages: list[Message]) -> str:
    """Render messages into the '### Role: content' prompt format.

    SYSTEM_PROMPT is prepended when no system message is present. The
    prompt ends with an open '### Assistant:' turn for the model to complete.
    """
    parts: list[str] = []
    has_system = any(m.role == "system" for m in messages)
    if not has_system:
        parts.append(f"### System: {SYSTEM_PROMPT}\n")
    for msg in messages:
        if msg.role == "system":
            parts.append(f"### System: {msg.content}\n")
        elif msg.role == "user":
            parts.append(f"### User: {msg.content}\n")
        elif msg.role == "assistant":
            parts.append(f"### Assistant: {msg.content}\n")
    parts.append("### Assistant:")
    return "\n".join(parts)


def _require_engine():
    """Return the loaded vLLM engine, or raise HTTP 503 if it is not available."""
    if engine is None:
        raise HTTPException(503, "Model engine is not ready.")
    return engine


def _sse_data(payload: str) -> str:
    """Frame *payload* as a single Server-Sent Event.

    Each line of the payload gets its own 'data:' field. A raw line break
    inside one field would end it early, and a blank line would end the
    event. Spec-compliant clients join the fields back with newlines.
    """
    lines = re.split(r'\r\n|\r|\n', payload)
    return "".join(f"data: {line}\n" for line in lines) + "\n"


@app.get("/health")
async def health():
    """Report basic service status, model path and CUDA memory use."""
    return {
        "status": "ok",
        "model": MODEL_PATH,
        "engine": "vLLM",
        "engine_ready": engine is not None,
        # NOTE(review): counts only memory allocated by torch in THIS process;
        # if vLLM runs the model in a separate worker this may read ~0.
        "gpu_memory_used_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 1)
        if torch.cuda.is_available()
        else None,
    }


@app.post("/v1/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Generate a reply to a conversation whose last message is from the user.

    Returns a ChatResponse with the formatted reply and token usage. When
    ``req.stream`` is true it returns a text/event-stream of raw text deltas
    instead (see _stream_response).

    Raises:
        HTTPException(400): no messages, or the last one is not from the user.
        HTTPException(503): the vLLM engine is not loaded.
        HTTPException(500): generation finished without producing any output.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    llm = _require_engine()

    prompt = build_prompt(req.messages)
    request_id = f"med-{uuid.uuid4().hex[:12]}"
    sampling_params = SamplingParams(
        max_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stop=STOP_PATTERNS,
        repetition_penalty=1.15,
    )

    if req.stream:
        return _stream_response(request_id, prompt, sampling_params)

    # Each yielded result carries the text generated so far. Only the last
    # one is needed, so format once at the end rather than on every step.
    final = None
    async for result in llm.generate(prompt, sampling_params, request_id):
        final = result
    if final is None or not final.outputs:
        raise HTTPException(500, "Model produced no output.")

    output = final.outputs[0]
    # finish_reason "stop" with stop_reason None means the model emitted EOS.
    # Any other case may have cut a sentence short, so trim it:
    #   - a stop_reason is set: a stop string fired mid-text;
    #   - finish_reason "length": max_tokens was reached.
    # If stop_reason is missing, getattr returns "" and we trim (old behaviour).
    ended_naturally = (
        output.finish_reason == "stop"
        and getattr(output, "stop_reason", "") is None
    )
    content = format_response(output.text, trim_incomplete=not ended_naturally)
    prompt_tokens = len(final.prompt_token_ids or [])
    completion_tokens = len(output.token_ids)

    return ChatResponse(
        id=request_id,
        content=content,
        finish_reason=output.finish_reason or "stop",
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )


def _stream_response(request_id: str, prompt: str, sampling_params: SamplingParams):
    """Stream a generation as Server-Sent Events.

    The stream sends, in order:
        1. One event per step with the raw text produced since the previous
           event.
        2. The standard disclaimer, if the generated text lacks one.
        3. A final 'data: [DONE]' event.

    Raw deltas are sent because format_response() is not append-only. As the
    text grows it restructures earlier content (sentences -> bullets, added
    headers, a trailing disclaimer). Slicing each new formatted snapshot at
    the previously sent length therefore produced garbled output.
    """
    llm = _require_engine()

    async def token_generator():
        sent_text = ""
        async for result in llm.generate(prompt, sampling_params, request_id):
            snapshot = result.outputs[0].text
            delta = snapshot[len(sent_text):]
            if delta:
                sent_text = snapshot
                yield _sse_data(delta)
        if sent_text and not _DISCLAIMER_RE.search(sent_text):
            yield _sse_data(_DISCLAIMER)
        yield "data: [DONE]\n\n"

    return StreamingResponse(token_generator(), media_type="text/event-stream")


@app.post("/v1/diagnose")
async def diagnose(req: ChatRequest):
    """Like /v1/chat, but asks the model for a detailed diagnostic analysis.

    Appends a fixed instruction to the final user message (mutating the
    parsed request in place), then delegates to chat(). The response is
    JSON or an SSE stream, depending on req.stream.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    last = req.messages[-1]
    last.content = (
        f"{last.content}\n\n"
        "Please provide a detailed and comprehensive diagnostic analysis of this medical record."
    )
    return await chat(req)


@app.get("/", include_in_schema=False)
def serve_dashboard():
    """Serve dashboard/index.html, or return 404 if it is not deployed.

    Declared sync so FastAPI runs the blocking file read in its threadpool
    instead of on the event loop.
    """
    dashboard_path = os.path.join(DASHBOARD_DIR, "index.html")
    try:
        with open(dashboard_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError:
        raise HTTPException(404, "Dashboard not found.") from None
    return HTMLResponse(content=html)


if __name__ == "__main__":
    uvicorn.run(
        app,
        host=HOST,
        port=PORT,
        log_level="info",
    )