"""FastAPI server exposing an OpenBioLLM medical chatbot backed by vLLM.

The module builds Llama 3 chat-format prompts, generates completions through
an ``AsyncLLMEngine`` and post-processes the raw model text into sectioned,
bulleted output before returning it (optionally as a server-sent-event stream).
"""
import os
import sys

SERVER_DIR = os.path.dirname(os.path.abspath(__file__))
SOURCE_DIR = os.path.join(SERVER_DIR, "source")
# Make a sibling "source" directory importable before the remaining imports run.
if os.path.isdir(SOURCE_DIR):
    sys.path.insert(0, SOURCE_DIR)

import re
import uuid
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams

MODEL_PATH = os.environ.get("MODEL_PATH", "/root/openbiollm-model")
HOST = os.environ.get("HOST", "0.0.0.0")
PORT = int(os.environ.get("PORT", "8001"))
DASHBOARD_DIR = os.path.join(SERVER_DIR, "dashboard")

# Default system message; build_prompt() injects it only when the caller
# supplied no system message of their own.
SYSTEM_PROMPT = (
    "You are OpenBioLLM, a medical AI assistant. You provide helpful, accurate, "
    "and evidence-based medical information.\n\n"
    "Response format rules:\n\n"
    "1. Start with a clear, direct one-sentence answer to the question.\n\n"
    "2. Then organize the rest of your response into labeled sections. "
    "Use section headers like 'Definition:', 'Common uses:', 'Drug class:', "
    "'Symptoms:', 'Causes:', 'Treatment:', 'Safety:', 'Key points:' etc. "
    "Put each section header on its own line followed by a newline.\n\n"
    "3. Under each section, list items one per line using '- ' bullet points.\n\n"
    "4. Leave a blank line between each section.\n\n"
    "5. Keep each bullet point short and clear (one idea per bullet).\n\n"
    "6. At the end, add a 'Safety note:' or 'Important:' section for warnings.\n\n"
    "7. Stay on topic. Only answer what was asked. "
    "Do NOT generate unrelated content, fictional patient cases, or diagnosis codes.\n\n"
    "8. Stop when you have fully answered. Do not keep writing.\n\n"
    "9. End with a one-line disclaimer that this is for informational purposes only."
)

# Used twice: as vLLM stop strings and as truncation markers in
# _truncate_hallucination().
STOP_PATTERNS = [
    "The following is a case",
    "The patient is a",
    "Diagnosis code:",
    "Treatment code:",
    "The following sections provide",
    "## Further reading",
    "## References",
    "Further reading",
    "End of interaction",
    "System Response",
    "The following is an example",
    "This sample response",
    "<|eot_id|>",
    "<|start_header_id|>",
]

# Matches a known section header ("Symptoms:", "Key points:", ...) at the start
# of the text, after a sentence-ending period, or at the start of a line.  The
# period is matched with a lookbehind so that splitting on this pattern does not
# strip the final period off the preceding sentence.
_SECTION_RE = re.compile(
    r'(?:^|(?<=\.)\s+|\n\s*)'
    r'((?:Definition|Common uses?|Drug class|Symptoms?|Causes?|Diagnosis'
    r'|Treatments?|Safety(?: note)?|Important|Mechanism|Side effects?'
    r'|Precautions?|Dosage|Key points?|Overview|Risk factors?'
    r'|Complications?|Prevention|When to see a doctor|Warning'
    r'|How it works|What it(?:\'s| is) used for|Disclaimer)\s*:\s*)',
    re.IGNORECASE,
)

# Matches "such as a, b, and c"-style enumerations; group 1 is the item list.
_INLINE_LIST_RE = re.compile(
    r'(?:(?:such as|including|like|e\.g\.|for example|include)\s+)'
    r'([^.!?]{10,}(?:,\s*(?:and\s+)?[^.!?]+)+)',
    re.IGNORECASE,
)

# Matches everything from the start of the text up to and including a line that
# contains one of the listed meta-commentary phrases.
_PREAMBLE_RE = re.compile(
    r'^(?:[^\n]*\n)*?[^\n]*'
    r'(?:System Response|sample response|example (?:system )?response|'
    r'The exact response will vary|Responses? (?:may|should|will) differ)'
    r'[^\n]*\n+',
    re.IGNORECASE,
)

_SENTENCE_ENDINGS = (".", "!", "?")

# A line that is a bullet ("- x", "* x") or a numbered item ("1. x", "1) x").
_LIST_ITEM_RE = re.compile(r'\s*(?:[-*\u2022]|\d+[.)])\s+\S')

# A list-style number marker: digits + period + whitespace, not glued to a
# preceding word character or period (so "3.5. " and "step1. " do not count).
_NUM_MARKER_RE = re.compile(r'(?<![\w.])(\d+)\.\s')


def _truncate_hallucination(text: str) -> str:
    """Cut the text at the first stop pattern and drop meta-commentary preambles.

    A pattern only truncates when it occurs past the first 15% of the (current)
    text; patterns are applied in STOP_PATTERNS order, each on the result of
    the previous one.  The result is stripped of surrounding whitespace.
    """
    for pattern in STOP_PATTERNS:
        idx = text.find(pattern)
        # find() returns -1 when absent, which never exceeds the threshold.
        if idx > len(text) * 0.15:
            text = text[:idx]
    text = _PREAMBLE_RE.sub('', text)
    return text.strip()


def _trim_incomplete(text: str) -> str:
    """Remove a trailing unfinished fragment from generated text.

    Text that already ends in ``.``, ``!`` or ``?`` is returned unchanged, as is
    text containing no sentence terminator at all.  A final line that is a list
    item is kept as-is, because bullets normally carry no terminal punctuation;
    previously such responses were cut back to the last sentence *before* the
    list, discarding every bullet.  Otherwise only the trailing fragment is
    removed: the text is cut after the last terminator when that terminator is
    on the final line, else the whole final line is dropped.
    """
    if not text or text.endswith(_SENTENCE_ENDINGS):
        return text

    last_line = text.rsplit("\n", 1)[-1]
    if _LIST_ITEM_RE.match(last_line):
        return text

    last_end = max(text.rfind(mark) for mark in _SENTENCE_ENDINGS)
    if last_end <= 0:
        return text

    line_start = text.rfind("\n") + 1  # 0 for single-line text
    if last_end >= line_start:
        return text[:last_end + 1]
    # The final line has no sentence end at all: drop just that line and keep
    # the complete lines (including any list items) before it.
    return text[:line_start].rstrip()


def _split_into_sections(text: str) -> list[tuple[str, str]]:
    """Split text into ``(header, body)`` pairs using the known section headers.

    Text before the first header is returned as a pair with an empty header.
    Headers have surrounding whitespace and the trailing colon removed; bodies
    are stripped and may be empty.
    """
    # With one capturing group, split() yields [preamble, header, body, ...].
    parts = _SECTION_RE.split(text)
    sections: list[tuple[str, str]] = []

    # parts[0] is always the text preceding the first header.  It is handled
    # unconditionally: guessing whether it "looks like" a header misaligned the
    # header/body pairing when the preamble was a bare keyword such as "Overview".
    preamble = parts[0].strip()
    if preamble:
        sections.append(("", preamble))

    for raw_header, raw_body in zip(parts[1::2], parts[2::2]):
        header = raw_header.strip().rstrip(":")
        if header:
            sections.append((header, raw_body.strip()))
    return sections


def _expand_inline_lists(text: str) -> str:
    """Rewrite "such as a, b, and c" enumerations as ``- `` bullet lists.

    The introducing phrase is kept, followed by a blank line and one bullet per
    item.  Only the first character of each item is upper-cased so acronyms and
    proper names inside an item (e.g. "HIV", "ACE inhibitors") keep their case.
    """
    def _replacer(m: re.Match) -> str:
        items = re.split(r',\s*(?:and\s+)?', m.group(1))
        items = [it.strip().rstrip(".").strip() for it in items if it.strip()]
        if len(items) < 2:
            return m.group(0)
        prefix = m.group(0)[:m.start(1) - m.start(0)]
        bullet_block = "\n".join(f"- {it[:1].upper()}{it[1:]}" for it in items)
        return f"{prefix.rstrip()}\n\n{bullet_block}"

    return _INLINE_LIST_RE.sub(_replacer, text)


def _sentences_to_bullets(text: str) -> str:
    """Turn each sentence into a ``- `` bullet when there are at least three.

    A sentence of the form "Label: rest" whose label is shorter than 40
    characters is rendered as ``- **Label**: rest``.  Trailing periods are
    removed from each bullet.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) < 3:
        return text

    bullets: list[str] = []
    for sentence in sentences:
        sentence = sentence.strip().rstrip(".")
        if not sentence:
            continue
        label, colon, rest = sentence.partition(":")
        if colon and len(label) < 40:
            bullets.append(f"- **{label.strip()}**: {rest.strip()}")
        else:
            bullets.append(f"- {sentence}")
    return "\n".join(bullets)


def _fix_numbered_list(text: str) -> str:
    """Put each item of an inline numbered list on its own line and renumber it.

    The text is treated as a numbered list only when it contains at least two
    number markers and the first one is ``1``.  Without that check an ordinary
    sentence ending in a number ("The dose is 10. Take with food.") was split
    and renumbered, silently replacing the number in the text.
    """
    markers = _NUM_MARKER_RE.findall(text)
    if len(markers) < 2 or markers[0] != "1":
        return text

    # Prefer splitting where an item follows sentence punctuation; fall back to
    # splitting on whitespace before any marker.
    items = re.split(r'(?<=[.!?])\s*(?=\d+\.\s)', text)
    if len(items) < 2:
        items = re.split(r'\s+(?=\d+\.\s)', text)
    if len(items) < 2:
        return text

    result: list[str] = []
    counter = 0
    for item in items:
        item = item.strip()
        if not item:
            continue
        cleaned = re.sub(r'^\d+\.\s*', '', item)
        if cleaned != item:
            # The chunk started with a marker: renumber sequentially.
            counter += 1
            result.append(f"{counter}. {cleaned}")
        else:
            result.append(item)
    return "\n".join(result)


def format_response(text: str) -> str:
    """Clean raw model output and lay it out as sections and bullet lists.

    Steps: truncate at stop patterns, trim a trailing fragment, normalise blank
    lines and inline bullets, then either return text that is already laid out,
    format a single block (numbered list / inline list / sentence bullets), or
    render every detected section as ``**Header**`` followed by its body.  In
    the multi-section path a disclaimer is appended unless one is present.
    """
    text = _truncate_hallucination(text)
    if not text:
        return text

    text = _trim_incomplete(text)
    text = re.sub(r'\n{3,}', '\n\n', text)
    text = re.sub(r'\s{2,}-\s+', '\n- ', text)

    # Already contains paragraph breaks and bullets: leave the layout alone.
    if '\n\n' in text and re.search(r'\n- ', text):
        return text.strip()

    sections = _split_into_sections(text)

    if len(sections) <= 1 and sections:
        # Single block: any header it carried is not rendered.
        header, body = sections[0]
        body = _fix_numbered_list(body)
        if re.search(r'\d+\.\s', body) and "\n" in body:
            return body.strip()
        body = _expand_inline_lists(body)
        if "\n- " in body:
            return body.strip()
        sentences = re.split(r'(?<=[.!?])\s+', body)
        if len(sentences) >= 4:
            intro = sentences[0]
            rest_text = " ".join(sentences[1:])
            bullets = _sentences_to_bullets(rest_text)
            return f"{intro}\n\n{bullets}".strip()
        return body.strip()

    output_parts: list[str] = []
    for header, body in sections:
        if not header and body:
            output_parts.append(body)
            continue
        body = _fix_numbered_list(body)
        body = _expand_inline_lists(body)
        if "\n" not in body and not body.startswith("-"):
            sentences = re.split(r'(?<=[.!?])\s+', body)
            if len(sentences) >= 3:
                body = _sentences_to_bullets(body)
        section_text = f"**{header}**\n\n{body}" if body else f"**{header}**"
        output_parts.append(section_text)

    result = "\n\n".join(output_parts)
    if not re.search(
        r'(?i)disclaimer|informational purposes|not.{0,20}replace.{0,30}medical advice',
        result,
    ):
        result += (
            "\n\n⚠️ *This information is for educational purposes only "
            "and should not replace professional medical advice.*"
        )
    return result.strip()


# Set by lifespan() on startup and reset to None on shutdown.
engine = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the vLLM engine on startup and release it on shutdown."""
    global engine
    print(f"Loading model from {MODEL_PATH} via vLLM ...")
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=2048,
        gpu_memory_utilization=0.40,
        enforce_eager=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("vLLM engine ready")
    print(f"Dashboard available at http://{HOST}:{PORT}/")
    try:
        yield
    finally:
        # Run the cleanup even when the application exits through an exception.
        engine = None
        torch.cuda.empty_cache()


app = FastAPI(
    title="OpenBioLLM Medical Chatbot",
    description="Medical AI chatbot powered by OpenBioLLM-8B (vLLM)",
    version="1.0.0",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS setup possible; restrict allow_origins if this server is
# reachable from untrusted networks.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Message(BaseModel):
    """One chat turn; ``role`` must be ``user``, ``assistant`` or ``system``."""

    role: str = Field(..., pattern="^(user|assistant|system)$")
    content: str


class ChatRequest(BaseModel):
    """Request body shared by the chat and diagnose endpoints."""

    messages: list[Message]
    max_new_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    stream: bool = False  # when true the endpoint answers with an SSE stream


class ChatResponse(BaseModel):
    """Non-streaming chat result."""

    id: str
    content: str
    finish_reason: str
    usage: dict  # prompt_tokens / completion_tokens / total_tokens


def build_prompt(messages: list[Message]) -> str:
    """Build Llama 3 chat-format prompt.

    SYSTEM_PROMPT is inserted first unless the caller already supplied a system
    message.  The prompt ends with an open assistant header so the model writes
    the next assistant turn.
    """
    # NOTE(review): message content is inserted verbatim, and a client-supplied
    # system message replaces SYSTEM_PROMPT — confirm both are intended.
    parts: list[str] = ["<|begin_of_text|>"]
    has_system = any(m.role == "system" for m in messages)
    if not has_system:
        parts.append(
            "<|start_header_id|>system<|end_header_id|>\n\n"
            f"{SYSTEM_PROMPT}<|eot_id|>"
        )
    for msg in messages:
        parts.append(
            f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n"
            f"{msg.content}<|eot_id|>"
        )
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


@app.get("/health")
async def health():
    """Report service status, model path and allocated GPU memory (MB) if CUDA is available."""
    return {
        "status": "ok",
        "model": MODEL_PATH,
        "engine": "vLLM",
        "gpu_memory_used_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 1)
        if torch.cuda.is_available()
        else None,
    }


@app.post("/v1/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Generate a formatted answer for the conversation in ``req``.

    Returns a ``ChatResponse``, or a server-sent-event stream when
    ``req.stream`` is true.

    Raises:
        HTTPException: 400 when the last message is not from the user, 503 when
            the engine is not loaded, 500 when generation yields no output.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    if engine is None:
        raise HTTPException(503, "Model is not loaded.")

    prompt = build_prompt(req.messages)
    request_id = f"bio-{uuid.uuid4().hex[:12]}"
    sampling_params = SamplingParams(
        max_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stop=STOP_PATTERNS,
        repetition_penalty=1.15,
    )

    if req.stream:
        return _stream_response(request_id, prompt, sampling_params)

    # Only the last yielded result is used for the response, so the stream is
    # drained first and the text is formatted once.
    final = None
    async for result in engine.generate(prompt, sampling_params, request_id):
        final = result
    if final is None or not final.outputs:
        # Previously this path failed with an unbound-name error.
        raise HTTPException(500, "Model returned no output.")

    output = final.outputs[0]
    prompt_tokens = len(final.prompt_token_ids)
    completion_tokens = len(output.token_ids)
    return ChatResponse(
        id=request_id,
        content=format_response(output.text),
        finish_reason=output.finish_reason or "stop",
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )


def _sse_event(payload: str) -> str:
    """Frame ``payload`` as one server-sent event, one ``data:`` line per text line."""
    return "".join(f"data: {line}\n" for line in payload.split("\n")) + "\n"


def _stream_response(
    request_id: str, prompt: str, sampling_params: SamplingParams
) -> StreamingResponse:
    """Stream newly formatted text as server-sent events, ending with ``[DONE]``.

    Each payload line gets its own ``data:`` prefix.  Formatted text contains
    newlines and blank lines; written after a single ``data:`` prefix those
    would terminate the event early and leave the remaining lines unprefixed.
    """
    async def token_generator():
        sent_len = 0
        async for result in engine.generate(prompt, sampling_params, request_id):
            formatted = format_response(result.outputs[0].text)
            # NOTE(review): format_response() output is not guaranteed to be a
            # prefix of its output for a longer text, so this length-based delta
            # can misalign when the layout changes between steps.
            new_text = formatted[sent_len:]
            if new_text:
                sent_len = len(formatted)
                yield _sse_event(new_text)
        yield "data: [DONE]\n\n"

    return StreamingResponse(token_generator(), media_type="text/event-stream")


@app.post("/v1/diagnose")
async def diagnose(req: ChatRequest):
    """Append a diagnostic-analysis instruction to the last user message and run chat().

    Raises:
        HTTPException: 400 when the last message is not from the user.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    last = req.messages[-1]
    # The request object is modified in place before being handed to chat().
    last.content = (
        f"{last.content}\n\n"
        "Please provide a detailed and comprehensive diagnostic analysis of this medical record."
    )
    return await chat(req)


@app.get("/", include_in_schema=False)
async def serve_dashboard():
    """Serve ``index.html`` from DASHBOARD_DIR.

    Raises:
        HTTPException: 404 when the dashboard file does not exist.
    """
    dashboard_path = os.path.join(DASHBOARD_DIR, "index.html")
    try:
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(dashboard_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError:
        raise HTTPException(404, "Dashboard is not installed.") from None
    return HTMLResponse(content=html)


if __name__ == "__main__":
    uvicorn.run(
        app,
        host=HOST,
        port=PORT,
        log_level="info",
    )