| | import os |
| | import sys |
| |
|
# Absolute directory of this script; the other paths below are built from it.
SERVER_DIR = os.path.dirname(os.path.abspath(__file__))
# Optional "source" directory next to the script. When it exists it is placed
# at the front of sys.path, so modules inside it win over same-named modules
# found later on the path.
SOURCE_DIR = os.path.join(SERVER_DIR, "source")
if os.path.isdir(SOURCE_DIR):
    sys.path.insert(0, SOURCE_DIR)
| |
|
| | import re |
| | import uuid |
| | from contextlib import asynccontextmanager |
| |
|
| | import torch |
| | import uvicorn |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from starlette.responses import HTMLResponse, StreamingResponse |
| | from pydantic import BaseModel, Field |
| | from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams |
| |
|
# Runtime configuration; each value can be overridden through the environment.
MODEL_PATH = os.environ.get("MODEL_PATH", "/root/openbiollm-model")  # passed to vLLM as the model to load
HOST = os.environ.get("HOST", "0.0.0.0")  # bind address given to uvicorn
PORT = int(os.environ.get("PORT", "8001"))  # int() raises ValueError at import time for a non-numeric PORT
# Directory from which serve_dashboard() reads index.html.
DASHBOARD_DIR = os.path.join(SERVER_DIR, "dashboard")
| |
|
# Default system message. build_prompt() inserts it only when the incoming
# conversation does not already contain a message with role "system".
SYSTEM_PROMPT = (
    "You are OpenBioLLM, a medical AI assistant. You provide helpful, accurate, "
    "and evidence-based medical information.\n\n"
    "Response format rules:\n\n"
    "1. Start with a clear, direct one-sentence answer to the question.\n\n"
    "2. Then organize the rest of your response into labeled sections. "
    "Use section headers like 'Definition:', 'Common uses:', 'Drug class:', "
    "'Symptoms:', 'Causes:', 'Treatment:', 'Safety:', 'Key points:' etc. "
    "Put each section header on its own line followed by a newline.\n\n"
    "3. Under each section, list items one per line using '- ' bullet points.\n\n"
    "4. Leave a blank line between each section.\n\n"
    "5. Keep each bullet point short and clear (one idea per bullet).\n\n"
    "6. At the end, add a 'Safety note:' or 'Important:' section for warnings.\n\n"
    "7. Stay on topic. Only answer what was asked. "
    "Do NOT generate unrelated content, fictional patient cases, or diagnosis codes.\n\n"
    "8. Stop when you have fully answered. Do not keep writing.\n\n"
    "9. End with a one-line disclaimer that this is for informational purposes only."
)
| |
|
# Literal substrings that mark the model drifting off-topic or leaking chat
# template tokens. The list is used twice in this file: as the `stop` argument
# of SamplingParams in chat(), and again in _truncate_hallucination(), which
# cuts the text at a pattern found past the first 15% of the response.
# Matching in _truncate_hallucination() is case-sensitive (plain str.find).
STOP_PATTERNS = [
    "The following is a case",
    "The patient is a",
    "Diagnosis code:",
    "Treatment code:",
    "The following sections provide",
    "## Further reading",
    "## References",
    "Further reading",
    "End of interaction",
    "System Response",
    "The following is an example",
    "This sample response",
    "<|eot_id|>",
    "<|start_header_id|>",
]
| |
|
| |
|
# Recognises a known section header followed by a colon. The header must sit
# at the very start of the text (no re.MULTILINE, so `^` is start-of-string
# only), after a period plus whitespace, or after a newline.
# The single capture group holds the header together with its colon and the
# surrounding whitespace, so re.split() keeps headers in its result list; the
# leading delimiter (period/newline) is in a non-capturing group and is
# therefore dropped by re.split(). Matching is case-insensitive.
_SECTION_RE = re.compile(
    r'(?:^|\.\s+|\n\s*)'
    r'((?:Definition|Common uses?|Drug class|Symptoms?|Causes?|Diagnosis'
    r'|Treatments?|Safety(?: note)?|Important|Mechanism|Side effects?'
    r'|Precautions?|Dosage|Key points?|Overview|Risk factors?'
    r'|Complications?|Prevention|When to see a doctor|Warning'
    r'|How it works|What it(?:\'s| is) used for|Disclaimer)\s*:\s*)',
    re.IGNORECASE,
)
| |
|
# Recognises an inline enumeration: an introducer ("such as", "including",
# "like", "e.g.", "for example", "include") followed by comma-separated items.
# Group 1 is the item text: a first run of at least 10 characters containing
# no '.', '!' or '?', then one or more ", item" / ", and item" continuations.
# Matching is case-insensitive.
_INLINE_LIST_RE = re.compile(
    r'(?:(?:such as|including|like|e\.g\.|for example|include)\s+)'
    r'([^.!?]{10,}(?:,\s*(?:and\s+)?[^.!?]+)+)',
    re.IGNORECASE,
)
| |
|
# Matches a leading "preamble": anchored at the start of the text, it lazily
# consumes whole lines up to and including the first line that mentions one of
# the template phrases below, plus the newline(s) that follow that line.
# The phrase line must be newline-terminated, otherwise nothing matches.
# Used with .sub('', ...) to drop everything before the real answer.
_PREAMBLE_RE = re.compile(
    r'^(?:[^\n]*\n)*?[^\n]*'
    r'(?:System Response|sample response|example (?:system )?response|'
    r'The exact response will vary|Responses? (?:may|should|will) differ)'
    r'[^\n]*\n+',
    re.IGNORECASE,
)
| |
|
| |
|
def _truncate_hallucination(text: str) -> str:
    """Cut off-topic tails and strip a leading template preamble.

    For every entry of ``STOP_PATTERNS`` the first occurrence is located with
    ``str.find``. If it lies beyond the first 15% of the text as it stands at
    that moment, everything from the occurrence onward is discarded; an
    occurrence inside the first 15% leaves the text untouched. Finally
    ``_PREAMBLE_RE`` removes a leading preamble and the result is stripped.
    """
    for marker in STOP_PATTERNS:
        cut_at = text.find(marker)
        # A miss gives -1, which is never greater than the non-negative limit.
        limit = len(text) * 0.15
        if cut_at > limit:
            text = text[:cut_at]

    without_preamble = _PREAMBLE_RE.sub('', text)
    return without_preamble.strip()
| |
|
| |
|
def _trim_incomplete(text: str) -> str:
    """Drop a dangling, unfinished last sentence.

    Text that is empty or already ends in '.', '!' or '?' is returned as is.
    Otherwise the text is cut just after the last such terminator, provided
    one exists at an index greater than 0; if not, the text is unchanged.
    """
    terminators = (".", "!", "?")
    if not text or text.endswith(terminators):
        return text

    cut = max(text.rfind(mark) for mark in terminators)
    # Index 0 (or -1 for "not found") means there is no complete sentence to keep.
    return text[:cut + 1] if cut > 0 else text
| |
|
| |
|
def _split_into_sections(text: str) -> list[tuple[str, str]]:
    """Split *text* into ``(header, body)`` pairs using ``_SECTION_RE``.

    ``_SECTION_RE.split`` returns the text before the first header followed by
    alternating header / body pieces. The leading piece is emitted as a
    header-less section ``("", text)`` when it is non-empty. Headers are
    stripped of whitespace and trailing colons; a header with no body gets
    an empty string as its body.
    """
    pieces = _SECTION_RE.split(text)
    found: list[tuple[str, str]] = []

    # The leading piece is treated as an introduction unless, with a colon
    # appended, it would itself be recognised as a section header.
    # NOTE(review): when it is NOT removed here, the header/body pairing below
    # starts one position early — confirm whether that case can occur.
    if pieces and not _SECTION_RE.match(pieces[0].strip() + ":"):
        lead = pieces[0].strip()
        pieces = pieces[1:]
        if lead:
            found.append(("", lead))

    for pos in range(0, len(pieces), 2):
        header = pieces[pos].strip().rstrip(":")
        has_body = pos + 1 < len(pieces)
        body = pieces[pos + 1].strip() if has_body else ""
        if header:
            found.append((header, body))

    return found
| |
|
| |
|
def _expand_inline_lists(text: str) -> str:
    """Rewrite inline enumerations as '- ' bullet lists.

    Every match of ``_INLINE_LIST_RE`` ("such as a, b, and c") is replaced by
    the introducer phrase, a blank line, and one bullet per item. Items are
    split on commas (with an optional following "and"), stripped, and have
    trailing periods removed. A match yielding fewer than two items is left
    unchanged.

    Only the first character of each item is upper-cased. ``str.capitalize``
    is deliberately avoided because it lower-cases the rest of the string and
    would mangle abbreviations and proper names ("ACE inhibitors", "NSAIDs").
    """
    def _replacer(m: re.Match) -> str:
        raw_items = re.split(r',\s*(?:and\s+)?', m.group(1))
        items = [it.strip().rstrip(".").strip() for it in raw_items if it.strip()]
        if len(items) < 2:
            return m.group(0)
        # Everything in the match that precedes group 1 is the introducer.
        prefix = m.group(0)[:m.start(1) - m.start(0)]
        bullet_block = "\n".join(f"- {it[:1].upper()}{it[1:]}" for it in items)
        return f"{prefix.rstrip()}\n\n{bullet_block}"

    return _INLINE_LIST_RE.sub(_replacer, text)
| |
|
| |
|
def _sentences_to_bullets(text: str) -> str:
    """Turn a run of sentences into one '- ' bullet per sentence.

    The text is split after '.', '!' or '?' followed by whitespace. With fewer
    than three pieces the text is returned unchanged. Each sentence is stripped
    and loses trailing periods; empty results are skipped. A sentence whose
    text before the first colon is shorter than 40 characters is rendered as
    ``- **label**: rest``; any other sentence as ``- sentence``.
    """
    chunks = re.split(r'(?<=[.!?])\s+', text)
    if len(chunks) < 3:
        return text

    def render(sentence: str) -> str:
        label, colon, detail = sentence.partition(":")
        if colon and len(label) < 40:
            return f"- **{label.strip()}**: {detail.strip()}"
        return f"- {sentence}"

    trimmed = (chunk.strip().rstrip(".") for chunk in chunks)
    return "\n".join(render(sentence) for sentence in trimmed if sentence)
| |
|
| |
|
def _fix_numbered_list(text: str) -> str:
    """Put each item of an inline numbered list on its own line, renumbered from 1.

    Items are first split where a sentence terminator is followed by a
    "<digits>. " marker; if that finds nothing, the text is split on
    whitespace before such a marker. Pieces that start with a marker are
    renumbered consecutively; other pieces (e.g. an introduction) are kept
    verbatim. The pieces are joined with newlines.

    The text is returned unchanged unless it really looks like a list:
    at least two numbered pieces, the first of which is numbered 1. This keeps
    ordinary prose that merely ends a sentence with a number ("... in 1928.
    It ...", "... is 4000. Exceeding ...") from having that number replaced
    by a list index.
    """
    if not re.search(r'\d+\.\s', text):
        return text

    items = re.split(r'(?<=[.!?])\s*(?=\d+\.\s)', text)
    if len(items) < 2:
        items = re.split(r'\s+(?=\d+\.\s)', text)
    elif not re.match(r'\d+\.\s', items[0].lstrip()):
        # The first item may be glued to an introduction ("Steps: 1. Rest.");
        # detach it so numbering does not restart at the second item.
        items[0:1] = re.split(r'\s+(?=1\.\s)', items[0], maxsplit=1)
    if len(items) < 2:
        return text

    items = [item.strip() for item in items]
    markers = [re.match(r'(\d+)\.\s*', item) for item in items]
    numbered = [marker for marker in markers if marker]
    if len(numbered) < 2 or numbered[0].group(1) != "1":
        return text

    result: list[str] = []
    counter = 0
    for item, marker in zip(items, markers):
        if not item:
            continue
        if marker:
            counter += 1
            result.append(f"{counter}. {item[marker.end():]}")
        else:
            result.append(item)

    return "\n".join(result)
| |
|
| |
|
def format_response(text: str) -> str:
    """Clean raw model output and lay it out as sections and bullet lists.

    Steps, in order:
      1. Truncate hallucinated tails / preambles; return early if nothing is left.
      2. Trim a dangling last sentence and normalise blank lines and inline
         " - " separators into real bullet lines.
      3. If the text already has blank-line separation and '- ' bullets,
         return it as is.
      4. Otherwise split it into sections. A single section is returned as a
         numbered list, a bullet list, or intro + bullets, whichever applies
         first. Several sections are rendered as bold headers with bodies.
      5. Only the multi-section result gets a disclaimer appended, and only
         when it does not already mention one.

    NOTE(review): every early return above (steps 3 and 4) bypasses the
    disclaimer check in step 5 — confirm that this is intended.
    """
    text = _truncate_hallucination(text)
    if not text:
        return text

    text = _trim_incomplete(text)
    # Collapse runs of 3+ newlines to a single blank line.
    text = re.sub(r'\n{3,}', '\n\n', text)
    # "foo  - bar" (two or more whitespace chars before the dash) becomes a bullet line.
    text = re.sub(r'\s{2,}-\s+', '\n- ', text)

    # Already structured by the model: has a blank line and at least one bullet.
    if '\n\n' in text and re.search(r'\n- ', text):
        return text.strip()

    sections = _split_into_sections(text)

    # Exactly one section: format its body only (its header, if any, is not emitted).
    if len(sections) <= 1 and sections:
        header, body = sections[0]
        body = _fix_numbered_list(body)

        # A numbered list that now spans several lines is final.
        if re.search(r'\d+\.\s', body) and "\n" in body:
            return body.strip()

        body = _expand_inline_lists(body)
        if "\n- " in body:
            return body.strip()

        # Four or more sentences: keep the first as an intro, bullet the rest.
        sentences = re.split(r'(?<=[.!?])\s+', body)
        if len(sentences) >= 4:
            intro = sentences[0]
            rest_text = " ".join(sentences[1:])
            bullets = _sentences_to_bullets(rest_text)
            return f"{intro}\n\n{bullets}".strip()

        return body.strip()

    output_parts: list[str] = []
    for header, body in sections:
        # Header-less text (the introduction) is passed through untouched.
        if not header and body:
            output_parts.append(body)
            continue

        body = _fix_numbered_list(body)
        body = _expand_inline_lists(body)

        # Single-line prose bodies with 3+ sentences are converted to bullets.
        if "\n" not in body and not body.startswith("-"):
            sentences = re.split(r'(?<=[.!?])\s+', body)
            if len(sentences) >= 3:
                body = _sentences_to_bullets(body)

        section_text = f"**{header}**\n\n{body}" if body else f"**{header}**"
        output_parts.append(section_text)

    result = "\n\n".join(output_parts)

    # Append a disclaimer unless the text already contains something like one.
    if not re.search(
        r'(?i)disclaimer|informational purposes|not.{0,20}replace.{0,30}medical advice',
        result,
    ):
        result += (
            "\n\n⚠️ *This information is for educational purposes only "
            "and should not replace professional medical advice.*"
        )

    return result.strip()
| |
|
| |
|
# Module-wide vLLM engine handle. lifespan() assigns it at application startup
# and resets it to None at shutdown; the request handlers read it.
engine = None
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the vLLM engine at startup and release it at shutdown.

    Before the ``yield`` the engine is built from ``MODEL_PATH`` and stored in
    the module-level ``engine``. After the ``yield`` the reference is dropped
    and the CUDA cache is emptied. The cleanup sits in ``finally`` so it also
    runs when an exception is thrown into the generator at the ``yield``.
    """
    global engine
    print(f"Loading model from {MODEL_PATH} via vLLM ...")

    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=2048,
        gpu_memory_utilization=0.40,
        enforce_eager=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    print("vLLM engine ready")
    print(f"Dashboard available at http://{HOST}:{PORT}/")
    try:
        yield
    finally:
        engine = None
        torch.cuda.empty_cache()
| |
|
| |
|
# The ASGI application. `lifespan` is run by FastAPI around the server's
# lifetime and is where the vLLM engine is created and released.
app = FastAPI(
    title="OpenBioLLM Medical Chatbot",
    description="Medical AI chatbot powered by OpenBioLLM-8B (vLLM)",
    version="1.0.0",
    lifespan=lifespan,
)
| |
|
# CORS: the API is open to every origin, method and header. Credentials are
# NOT allowed: a wildcard origin must not be combined with credentialed
# requests (FastAPI documents that allow_origins/allow_methods/allow_headers
# cannot be ["*"] when allow_credentials is True). To support cookies or
# Authorization headers cross-origin, list the trusted origins explicitly
# and only then enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
class Message(BaseModel):
    # One turn of the conversation as sent by the client.
    role: str = Field(..., pattern="^(user|assistant|system)$")  # required; must be exactly one of the three names
    content: str  # message text; build_prompt() places it inside the prompt
| |
|
| |
|
class ChatRequest(BaseModel):
    # Request body of the /v1/chat and /v1/diagnose endpoints.
    messages: list[Message]  # conversation so far; the handlers reject it unless the last role is "user"
    max_new_tokens: int = Field(default=512, ge=1, le=2048)  # forwarded as SamplingParams(max_tokens=...)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)  # forwarded to SamplingParams
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)  # forwarded to SamplingParams
    stream: bool = False  # True selects the text/event-stream response instead of JSON
| |
|
| |
|
class ChatResponse(BaseModel):
    # JSON body returned by the non-streaming chat path.
    id: str  # request id generated in chat() ("bio-" + 12 hex characters)
    content: str  # model output after format_response()
    finish_reason: str  # finish reason reported by the engine, or "stop" when it reports none
    usage: dict  # keys set by chat(): prompt_tokens, completion_tokens, total_tokens
| |
|
| |
|
# Text shaped like a Llama 3 control token, e.g. <|eot_id|> or <|start_header_id|>.
_CONTROL_TOKEN_RE = re.compile(r"<\|[A-Za-z0-9_]+\|>")


def _strip_control_tokens(content: str) -> str:
    """Remove control-token look-alikes from client-supplied message text."""
    previous = None
    # Repeat until stable: removing one token can splice a new one together
    # (e.g. "<|eot<|eot_id|>_id|>" collapses to "<|eot_id|>").
    while previous != content:
        previous = content
        content = _CONTROL_TOKEN_RE.sub("", content)
    return content


def build_prompt(messages: list[Message]) -> str:
    """Build Llama 3 chat-format prompt.

    The prompt starts with ``<|begin_of_text|>``. When no message has the
    role ``system``, ``SYSTEM_PROMPT`` is inserted as the system turn. Every
    message is then emitted as a header/content/``<|eot_id|>`` turn in the
    order given, and the prompt ends with an open assistant header for the
    model to complete.

    Message content comes from the client and is spliced into the template
    as plain text, so anything in it that looks like a control token is
    removed first. Otherwise a message could close its own turn and forge
    additional system or assistant turns.
    """
    parts: list[str] = ["<|begin_of_text|>"]

    has_system = any(m.role == "system" for m in messages)
    if not has_system:
        parts.append(
            "<|start_header_id|>system<|end_header_id|>\n\n"
            f"{SYSTEM_PROMPT}<|eot_id|>"
        )

    for msg in messages:
        parts.append(
            f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n"
            f"{_strip_control_tokens(msg.content)}<|eot_id|>"
        )

    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)
| |
|
| |
|
@app.get("/health")
async def health():
    """Report service status, the configured model path and GPU memory use.

    ``gpu_memory_used_mb`` is ``torch.cuda.memory_allocated()`` converted to
    MiB and rounded to one decimal, or ``None`` when CUDA is unavailable.
    """
    gpu_memory_mb = None
    if torch.cuda.is_available():
        allocated_bytes = torch.cuda.memory_allocated()
        gpu_memory_mb = round(allocated_bytes / 1024 / 1024, 1)

    return {
        "status": "ok",
        "model": MODEL_PATH,
        "engine": "vLLM",
        "gpu_memory_used_mb": gpu_memory_mb,
    }
| |
|
| |
|
@app.post("/v1/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Generate a reply for the conversation in *req*.

    Returns a ``ChatResponse`` with the formatted text and token counts, or,
    when ``req.stream`` is true, the streaming response built by
    ``_stream_response``.

    Raises:
        HTTPException(400): the conversation is empty or does not end with a
            user message.
        HTTPException(503): the engine has not been created (or was already
            released).
        HTTPException(500): the engine finished without yielding any output.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    if engine is None:
        raise HTTPException(503, "Model engine is not loaded.")

    prompt = build_prompt(req.messages)
    request_id = f"bio-{uuid.uuid4().hex[:12]}"

    sampling_params = SamplingParams(
        max_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stop=STOP_PATTERNS,
        repetition_penalty=1.15,
    )

    if req.stream:
        return _stream_response(request_id, prompt, sampling_params)

    # Only the last result yielded by the engine is used.
    final = None
    async for result in engine.generate(prompt, sampling_params, request_id):
        final = result
    if final is None or not final.outputs:
        raise HTTPException(500, "Model returned no output.")

    output = final.outputs[0]
    prompt_tokens = len(final.prompt_token_ids)
    completion_tokens = len(output.token_ids)

    return ChatResponse(
        id=request_id,
        content=format_response(output.text),
        finish_reason=output.finish_reason or "stop",
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
| |
|
| |
|
def _sse_event(payload: str) -> str:
    """Frame *payload* as one server-sent event.

    Each line of the payload gets its own ``data:`` field and the event ends
    with a blank line. A payload without line breaks therefore produces
    exactly ``data: <payload>`` followed by a blank line.
    """
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


def _stream_response(request_id: str, prompt: str, sampling_params: SamplingParams):
    """Stream the formatted reply as server-sent events.

    For every result yielded by the engine the text generated so far is run
    through ``format_response`` and the part beyond what was already sent is
    emitted as one event. The stream ends with a ``[DONE]`` event.

    Chunks may contain newlines (section breaks, bullets). They are framed
    with one ``data:`` line per text line, because a raw newline inside a
    single ``data:`` field would end the field, or the whole event, early.

    NOTE(review): the "already sent" position is a character count into the
    previous formatted text. That is only exact while each formatted result
    extends the previous one — confirm this holds for format_response.
    """
    async def token_generator():
        sent_len = 0
        async for result in engine.generate(prompt, sampling_params, request_id):
            formatted = format_response(result.outputs[0].text)
            new_text = formatted[sent_len:]
            if new_text:
                sent_len = len(formatted)
                yield _sse_event(new_text)
        yield "data: [DONE]\n\n"

    return StreamingResponse(token_generator(), media_type="text/event-stream")
| |
|
| |
|
@app.post("/v1/diagnose")
async def diagnose(req: ChatRequest):
    """Run a chat request with a diagnostic-analysis instruction appended.

    The instruction is added, after a blank line, to the content of the last
    message of *req* (the request object is modified in place). The request
    is then handled by ``chat`` and its result is returned.

    Raises:
        HTTPException(400): the conversation is empty or does not end with a
            user message.
    """
    conversation = req.messages
    if not conversation or conversation[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")

    instruction = (
        "Please provide a detailed and comprehensive diagnostic analysis of this medical record."
    )
    final_turn = conversation[-1]
    final_turn.content = f"{final_turn.content}\n\n{instruction}"
    return await chat(req)
| |
|
| |
|
@app.get("/", include_in_schema=False)
async def serve_dashboard():
    """Serve the dashboard page (``index.html`` from ``DASHBOARD_DIR``).

    The file is read as UTF-8 explicitly, so the result does not depend on
    the platform's default encoding.

    Raises:
        HTTPException(404): the dashboard file does not exist.
    """
    dashboard_path = os.path.join(DASHBOARD_DIR, "index.html")
    try:
        with open(dashboard_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError:
        raise HTTPException(404, "Dashboard not found.") from None
    return HTMLResponse(content=html)
| |
|
| |
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on the configured address.
    server_options = {"host": HOST, "port": PORT, "log_level": "info"}
    uvicorn.run(app, **server_options)
| |
|