# Provenance (Hugging Face Hub file page, not code): server / server.py,
# uploaded by Harmony18090 with the commit message "Upload server.py with
# huggingface_hub", commit f3e3988 (verified), 13.3 kB.
import os
import sys
# Directory that contains this file; other paths are resolved relative to it.
SERVER_DIR = os.path.split(os.path.abspath(__file__))[0]
SOURCE_DIR = os.path.join(SERVER_DIR, "source")
# When a ./source directory ships next to this file, put it first on the
# import path so modules there take precedence over installed packages.
if os.path.isdir(SOURCE_DIR):
    sys.path.insert(0, SOURCE_DIR)
import re
import uuid
import asyncio
from contextlib import asynccontextmanager
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
# Runtime configuration; each value can be overridden via an environment
# variable of the same name.
MODEL_PATH = os.getenv("MODEL_PATH", "/root/model")
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
# Static dashboard files are served from ./dashboard next to this file.
DASHBOARD_DIR = os.path.join(SERVER_DIR, "dashboard")
# Default system instruction. build_prompt() prepends it whenever the request
# contains no system message of its own. It asks the model for a one-sentence
# answer followed by labeled sections with '- ' bullets and a closing
# disclaimer; format_response() post-processes output toward the same layout.
SYSTEM_PROMPT = (
"You are MedFound, a medical AI assistant. You provide helpful, accurate, "
"and evidence-based medical information.\n\n"
"Response format rules:\n\n"
"1. Start with a clear, direct one-sentence answer to the question.\n\n"
"2. Then organize the rest of your response into labeled sections. "
"Use section headers like 'Definition:', 'Common uses:', 'Drug class:', "
"'Symptoms:', 'Causes:', 'Treatment:', 'Safety:', 'Key points:' etc. "
"Put each section header on its own line followed by a newline.\n\n"
"3. Under each section, list items one per line using '- ' bullet points.\n\n"
"4. Leave a blank line between each section.\n\n"
"5. Keep each bullet point short and clear (one idea per bullet).\n\n"
"6. At the end, add a 'Safety note:' or 'Important:' section for warnings.\n\n"
"7. Stay on topic. Only answer what was asked. "
"Do NOT generate unrelated content, fictional patient cases, or diagnosis codes.\n\n"
"8. Stop when you have fully answered. Do not keep writing.\n\n"
"9. End with a one-line disclaimer that this is for informational purposes only."
)
# Marker strings that signal an off-topic or hallucinated continuation
# (invented cases, fake turns, reference lists). They are used twice:
#   * chat() passes them to vLLM as `stop` strings, so generation halts once
#     one of them is produced;
#   * _truncate_hallucination() searches the text for them as a
#     post-processing safety net.
# NOTE(review): "The patient is a" is plausible legitimate wording in answers
# to /v1/diagnose (which asks for an analysis of a medical record); as a vLLM
# stop string it may end such answers almost immediately — confirm intended.
# "Further reading" already covers "## Further reading".
STOP_PATTERNS = [
"The following is a case",
"The patient is a",
"Diagnosis code:",
"Treatment code:",
"The following sections provide",
"### User:",
"### System:",
"## Further reading",
"## References",
"Further reading",
"End of interaction",
"System Response",
"The following is an example",
"This sample response",
"Asbury et al",
"Lewis & Loftus",
]
# Finds a known section label ("Definition:", "Side effects:", "Key points:",
# ...) that is at the very start of the text (no MULTILINE flag, so ^ is the
# string start), follows a period plus whitespace, or begins a new line.
# Group 1 captures the label together with its colon and surrounding spaces;
# the separator in front of it (including a preceding period) is matched
# outside the group. Matching is case-insensitive.
_SECTION_RE = re.compile(
r'(?:^|\.\s+|\n\s*)'
r'((?:Definition|Common uses?|Drug class|Symptoms?|Causes?|Diagnosis'
r'|Treatments?|Safety(?: note)?|Important|Mechanism|Side effects?'
r'|Precautions?|Dosage|Key points?|Overview|Risk factors?'
r'|Complications?|Prevention|When to see a doctor|Warning'
r'|How it works|What it(?:\'s| is) used for|Disclaimer)\s*:\s*)',
re.IGNORECASE,
)
# Matches an inline enumeration introduced by a trigger phrase ("such as",
# "including", "like", "e.g.", "for example", "include") followed by a run of
# comma-separated items inside one sentence (no '.', '!' or '?'). Group 1 is
# the item run, e.g. "fever, cough, and fatigue". Case-insensitive.
# The leading \b stops the trigger from matching inside a longer word, such
# as "like" in "unlike" or "alike", which previously produced bogus lists.
_INLINE_LIST_RE = re.compile(
    r'\b(?:(?:such as|including|like|e\.g\.|for example|include)\s+)'
    r'([^.!?]{10,}(?:,\s*(?:and\s+)?[^.!?]+)+)',
    re.IGNORECASE,
)
# Matches everything from the start of the text (no MULTILINE flag, so ^ is
# the string start) through the end of the first line that contains a meta
# phrase such as "System Response", "sample response", "example response",
# "The exact response will vary" or "Responses may differ", plus the
# newline(s) after it. _truncate_hallucination() deletes that span.
# The trigger line must be followed by a newline to match. Case-insensitive.
# NOTE(review): the lazy (?:[^\n]*\n)*? prefix lets the trigger line be
# arbitrarily far into the text, so an ordinary sentence like "Individual
# responses may differ ..." on a later line would delete every line before
# it — consider bounding how many leading lines may be removed.
_PREAMBLE_RE = re.compile(
r'^(?:[^\n]*\n)*?[^\n]*'
r'(?:System Response|sample response|example (?:system )?response|'
r'The exact response will vary|Responses? (?:may|should|will) differ)'
r'[^\n]*\n+',
re.IGNORECASE,
)
def _truncate_hallucination(text: str) -> str:
    """Cut hallucinated continuations and meta preambles from model output.

    For every marker in STOP_PATTERNS, the first occurrence that starts after
    the first 15% of *text* is a candidate cut point. Text is truncated at
    the earliest candidate. Occurrences inside the leading 15% are kept,
    because they are treated as part of the actual answer. Afterwards any
    leading meta preamble matched by _PREAMBLE_RE is removed and the result
    is stripped.

    The 15% threshold is computed once, on the untruncated text. That makes
    the result independent of the order of STOP_PATTERNS. Searching past the
    threshold also means a marker that appears both early and late is still
    cut at its late occurrence.
    """
    # idx > len * 0.15  <=>  idx >= floor(len * 0.15) + 1 for integer idx.
    min_start = int(len(text) * 0.15) + 1
    cut = len(text)
    for pattern in STOP_PATTERNS:
        idx = text.find(pattern, min_start)
        if idx != -1 and idx < cut:
            cut = idx
    text = text[:cut]
    text = _PREAMBLE_RE.sub('', text)
    return text.strip()
def _trim_incomplete(text: str) -> str:
    """Drop a trailing unfinished sentence from *text*.

    Text that already ends with '.', '!' or '?' (optionally followed by
    closing quotes, brackets or Markdown '*') is returned unchanged.
    Otherwise the text is cut right after the last such terminator that is
    followed by whitespace. Requiring the whitespace keeps decimals
    ("2.5 mg"), abbreviations glued to text and URLs intact. If no
    terminator qualifies, or the only one is at index 0, *text* is returned
    unchanged.
    """
    if not text:
        return text
    # Sentence punctuation plus any closing quote/bracket/emphasis after it.
    terminator = r"[.!?][\"')\]*]*"
    if re.search(terminator + r"\Z", text):
        return text
    last_end = -1
    for match in re.finditer(terminator + r"(?=\s)", text):
        if match.start() > 0:
            last_end = match.end()
    if last_end > 0:
        text = text[:last_end]
    return text
def _split_into_sections(text: str) -> list[tuple[str, str]]:
    """Split text into (header, body) pairs at labels matched by _SECTION_RE.

    Text before the first label becomes a ("", lead) pair, omitted when it is
    blank. Each label then yields a (header, body) pair whose body runs up to
    the next label and may be empty. Headers lose their trailing colon and
    spacing. Sentence punctuation in front of a label (e.g. the '.' of
    ". Side effects:") stays with the preceding body instead of being lost.
    """
    matches = list(_SECTION_RE.finditer(text))
    sections: list[tuple[str, str]] = []

    # Everything before the first label is lead-in text. It is never a
    # header itself, so it is always kept as a body.
    lead_end = matches[0].start(1) if matches else len(text)
    lead = text[:lead_end].strip()
    if lead:
        sections.append(("", lead))

    for i, match in enumerate(matches):
        # Group 1 is "<label>\s*:\s*"; keep just the label text.
        header = match.group(1).strip().rstrip(":").rstrip()
        # End the body where the next label's text begins (not where its
        # separator begins) so a trailing ". " keeps its period here.
        body_end = matches[i + 1].start(1) if i + 1 < len(matches) else len(text)
        body = text[match.end():body_end].strip()
        if header:
            sections.append((header, body))
    return sections
def _expand_inline_lists(text: str) -> str:
    """Convert 'such as X, Y, Z, and W' into bullet points.

    Each _INLINE_LIST_RE match with at least two comma-separated items is
    replaced by its trigger phrase, a blank line, and one '- ' bullet per
    item. Text after the match (e.g. the sentence's final period) is left in
    place. Matches with fewer than two items are left untouched.
    """
    def _replacer(m: re.Match) -> str:
        items = re.split(r',\s*(?:and\s+)?', m.group(1))
        items = [it.strip().rstrip(".").strip() for it in items if it.strip()]
        if len(items) < 2:
            return m.group(0)
        prefix = m.group(0)[:m.start(1) - m.start(0)]
        # Upper-case only the first character. str.capitalize() would also
        # lower-case the rest and mangle acronyms and names (MRI -> Mri,
        # COVID-19 -> Covid-19, Vitamin D -> Vitamin d).
        bullet_block = "\n".join(f"- {it[:1].upper()}{it[1:]}" for it in items)
        return f"{prefix.rstrip()}\n\n{bullet_block}"
    return _INLINE_LIST_RE.sub(_replacer, text)
def _sentences_to_bullets(text: str) -> str:
    """Turn a run of three or more sentences into '- ' bullet lines.

    Text with fewer than three sentences is returned unchanged. Trailing
    periods are dropped from each bullet. A sentence shaped like
    "Label: detail", where the colon is followed by whitespace and the label
    is under 40 characters, becomes "- **Label**: detail". Colons inside
    times ("10:30"), ratios ("1:1") or URLs ("https://") do not qualify.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) < 3:
        return text
    bullets: list[str] = []
    for sentence in sentences:
        sentence = sentence.strip().rstrip(".")
        if not sentence:
            continue
        colon = re.search(r':\s', sentence)
        if colon:
            label = sentence[:colon.start()]
            detail = sentence[colon.end():].strip()
            if label.strip() and len(label) < 40 and detail:
                bullets.append(f"- **{label.strip()}**: {detail}")
                continue
        bullets.append(f"- {sentence}")
    return "\n".join(bullets)
def _fix_numbered_list(text: str) -> str:
    """Put an inline numbered list ("1. A 2. B 3. C") on separate lines.

    Only standalone "<n>. " markers that continue the sequence 1, 2, 3, ...
    count as list items, and at least two are required. Any other number
    followed by a period, such as the end of "below 100. Consult ...", is
    left alone instead of being turned into a renumbered item. Text before
    "1." is kept as its own intro line. Returns *text* unchanged when no
    such list is present.
    """
    markers: list[re.Match] = []
    expected = 1
    # A marker is a number at the start of the text or after whitespace,
    # followed by a period and whitespace.
    for match in re.finditer(r'(?<!\S)(\d+)\.\s+', text):
        if int(match.group(1)) == expected:
            markers.append(match)
            expected += 1
    if len(markers) < 2:
        return text

    lines: list[str] = []
    intro = text[:markers[0].start()].strip()
    if intro:
        lines.append(intro)
    # Each item runs from the end of its marker to the start of the next one.
    ends = [m.start() for m in markers[1:]] + [len(text)]
    for number, (marker, end) in enumerate(zip(markers, ends), start=1):
        lines.append(f"{number}. {text[marker.end():end].strip()}")
    return "\n".join(lines)
def _format_lone_section(body: str) -> str:
    """Lay out output that forms a single section (its header is dropped).

    Tries, in order: an inline numbered list, an inline "such as ..."
    enumeration, and finally splitting four or more sentences into an intro
    sentence followed by bullets. Otherwise the stripped body is returned.
    """
    body = _fix_numbered_list(body)
    if "\n" in body and re.search(r'\d+\.\s', body):
        return body.strip()
    body = _expand_inline_lists(body)
    if "\n- " in body:
        return body.strip()
    intro, *rest = re.split(r'(?<=[.!?])\s+', body)
    if len(rest) < 3:
        return body.strip()
    bullets = _sentences_to_bullets(" ".join(rest))
    return f"{intro}\n\n{bullets}".strip()


def _render_sections(sections: list[tuple[str, str]]) -> str:
    """Render (header, body) pairs as bold headers over formatted bodies.

    Appends the standard educational-use disclaimer unless the rendered text
    already contains disclaimer-like wording.
    """
    blocks: list[str] = []
    for header, body in sections:
        if body and not header:
            # Header-less lead-in text is emitted as-is.
            blocks.append(body)
            continue
        body = _expand_inline_lists(_fix_numbered_list(body))
        if not body.startswith("-") and "\n" not in body:
            # A flat body with three or more sentences becomes a bullet list.
            if len(re.split(r'(?<=[.!?])\s+', body)) >= 3:
                body = _sentences_to_bullets(body)
        blocks.append(f"**{header}**\n\n{body}" if body else f"**{header}**")
    rendered = "\n\n".join(blocks)
    has_disclaimer = re.search(
        r'(?i)disclaimer|informational purposes|not.{0,20}replace.{0,30}medical advice',
        rendered,
    )
    if not has_disclaimer:
        rendered += (
            "\n\n⚠️ *This information is for educational purposes only "
            "and should not replace professional medical advice.*"
        )
    return rendered.strip()


def format_response(text: str) -> str:
    """Turn raw model output into structured, Markdown-style text.

    Pipeline:
      1. remove hallucinated tails / meta preambles and any trailing
         unfinished sentence;
      2. collapse 3+ newlines to a blank line and put run-together " - "
         items on their own lines;
      3. if the model already produced paragraphs with '- ' bullets, return
         that as-is;
      4. otherwise split at known section labels and format either the single
         section or every section, the latter also ensuring a disclaimer.
    Empty input (after step 1) is returned unchanged.
    """
    cleaned = _truncate_hallucination(text)
    if not cleaned:
        return cleaned
    cleaned = _trim_incomplete(cleaned)
    cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
    cleaned = re.sub(r'\s{2,}-\s+', '\n- ', cleaned)

    # The model already laid out paragraphs with bullets: trust it.
    if '\n\n' in cleaned and re.search(r'\n- ', cleaned):
        return cleaned.strip()

    sections = _split_into_sections(cleaned)
    if len(sections) == 1:
        return _format_lone_section(sections[0][1])
    return _render_sections(sections)
# vLLM AsyncLLMEngine; created by lifespan() on startup and cleared on shutdown.
engine = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the vLLM engine before serving requests
    and drop it again on shutdown."""
    global engine
    print(f"Loading model from {MODEL_PATH} via vLLM ...")
    args = AsyncEngineArgs(
        model=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=2048,           # prompt + completion token budget
        gpu_memory_utilization=0.40,  # share of GPU memory vLLM may claim
        enforce_eager=True,           # skip CUDA-graph capture
    )
    engine = AsyncLLMEngine.from_engine_args(args)
    print("vLLM engine ready")
    print(f"Dashboard available at http://{HOST}:{PORT}/")
    yield
    # Shutdown: release the engine reference, then let PyTorch's caching
    # allocator give back unoccupied cached CUDA memory.
    engine = None
    torch.cuda.empty_cache()
# ASGI application; lifespan() loads and unloads the model.
app = FastAPI(
    lifespan=lifespan,
    title="MedFound Medical Chatbot",
    version="2.0.0",
    description="Medical AI chatbot powered by MedFound-Llama3-8B (vLLM)",
)

# NOTE(review): wildcard origins together with allow_credentials=True let any
# website make credentialed cross-origin calls (Starlette reflects the
# request's Origin in that case — verify for the pinned version). Tighten
# allow_origins if this server is exposed beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class Message(BaseModel):
    # One chat turn. `role` must be exactly one of the three roles that
    # build_prompt() knows how to render (anchored regex).
    role: str = Field(default=..., pattern="^(user|assistant|system)$")
    # Turn text; inserted verbatim into the prompt.
    content: str
class ChatRequest(BaseModel):
    # Request body shared by /v1/chat and /v1/diagnose. Bounds are checked by
    # pydantic, so out-of-range values produce a 422 instead of an engine error.
    messages: list[Message]
    # Completion budget. NOTE(review): the upper bound equals the engine's
    # max_model_len (2048 in lifespan), which the prompt also consumes, so
    # the effective room for generation is smaller.
    max_new_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    # vLLM's SamplingParams only accepts top_p in (0, 1]; allowing 0 here
    # let such requests through to fail with a server error.
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)
    # When true, /v1/chat answers with a text/event-stream response.
    stream: bool = False
class ChatResponse(BaseModel):
    # Non-streaming reply of /v1/chat. Field order is kept as declared since
    # it determines the key order of the JSON response.
    id: str             # request id: "med-" + 12 hex characters
    content: str        # completion text after format_response()
    finish_reason: str  # vLLM finish reason, "stop" when none is reported
    usage: dict         # prompt_tokens / completion_tokens / total_tokens
def build_prompt(messages: list[Message]) -> str:
    """Render chat turns as the plain '### Role: text' transcript the model
    is prompted with.

    SYSTEM_PROMPT is prepended unless the caller supplied at least one system
    message. Each turn is followed by a blank line, and the transcript ends
    with an open '### Assistant:' turn for the model to complete.

    NOTE(review): turn text is inserted verbatim, so user content containing
    "### System:" or "### Assistant:" is indistinguishable from real turn
    markers (prompt injection).
    """
    caller_has_system = any(m.role == "system" for m in messages)
    rendered = [] if caller_has_system else [f"### System: {SYSTEM_PROMPT}\n"]
    for message in messages:
        role, text = message.role, message.content
        if role == "user":
            rendered.append(f"### User: {text}\n")
        elif role == "assistant":
            rendered.append(f"### Assistant: {text}\n")
        elif role == "system":
            rendered.append(f"### System: {text}\n")
    rendered.append("### Assistant:")
    return "\n".join(rendered)
@app.get("/health")
async def health():
    """Liveness probe: reports the model path, engine type and PyTorch's
    allocated CUDA memory in MiB (None when CUDA is unavailable).

    NOTE(review): torch.cuda.memory_allocated() only covers this process's
    allocator; if vLLM runs the model in a separate worker process the value
    may not reflect the model's usage — confirm for the deployed vLLM version.
    """
    gpu_mb = None
    if torch.cuda.is_available():
        gpu_mb = round(torch.cuda.memory_allocated() / 1024 / 1024, 1)
    return {
        "status": "ok",
        "model": MODEL_PATH,
        "engine": "vLLM",
        "gpu_memory_used_mb": gpu_mb,
    }
@app.post("/v1/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Generate an assistant reply for a chat transcript.

    Returns a ChatResponse holding the post-processed completion and token
    usage. When ``req.stream`` is set, it instead returns a text/event-stream
    StreamingResponse.

    Raises:
        HTTPException(400): there are no messages or the last one is not
            from the user.
        HTTPException(503): the vLLM engine is not loaded (outside lifespan).
        HTTPException(500): generation finished without producing output.
    """
    if not req.messages or req.messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    if engine is None:
        raise HTTPException(503, "Model engine is not loaded.")
    prompt = build_prompt(req.messages)
    request_id = f"med-{uuid.uuid4().hex[:12]}"
    sampling_params = SamplingParams(
        max_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stop=STOP_PATTERNS,
        repetition_penalty=1.15,
    )
    if req.stream:
        return _stream_response(request_id, prompt, sampling_params)

    # With vLLM's default cumulative outputs, the last RequestOutput holds the
    # complete text. Format it once after generation rather than re-running
    # format_response on the whole text at every step, which was O(n^2).
    final = None
    async for result in engine.generate(prompt, sampling_params, request_id):
        final = result
    if final is None or not final.outputs:
        raise HTTPException(500, "Model produced no output.")

    output = final.outputs[0]
    prompt_tokens = len(final.prompt_token_ids or [])
    completion_tokens = len(output.token_ids)
    return ChatResponse(
        id=request_id,
        content=format_response(output.text),
        finish_reason=output.finish_reason or "stop",
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
def _stream_response(request_id: str, prompt: str, sampling_params: SamplingParams):
    """Wrap an engine generation in a Server-Sent Events response.

    After each engine step the cumulative text is re-formatted. Whatever lies
    beyond the length already sent goes out as a ``data: ...`` event, and a
    final ``data: [DONE]`` event closes the stream.
    """

    async def _events():
        # NOTE(review): two known wire-format issues, left unchanged because
        # the dashboard client depends on the current format:
        #   * format_response() is not prefix-stable (it trims unfinished
        #     sentences, regroups sections and may append a disclaimer), so
        #     slicing the new text at the previously sent length can emit
        #     duplicated or garbled fragments;
        #   * a delta containing "\n" is written into a single `data:` line,
        #     which spec-compliant SSE parsers split, dropping the extra lines.
        sent_len = 0
        async for step in engine.generate(prompt, sampling_params, request_id):
            formatted = format_response(step.outputs[0].text)
            delta = formatted[sent_len:]
            if not delta:
                continue
            sent_len = len(formatted)
            yield f"data: {delta}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(_events(), media_type="text/event-stream")
@app.post("/v1/diagnose")
async def diagnose(req: ChatRequest):
    """Like /v1/chat, but asks for a diagnostic analysis of the last message.

    The instruction is appended in place to the caller's final user message
    before the request is delegated to chat().
    """
    messages = req.messages
    if not messages or messages[-1].role != "user":
        raise HTTPException(400, "Last message must be from the user.")
    latest = messages[-1]
    instruction = (
        "Please provide a detailed and comprehensive diagnostic analysis of this medical record."
    )
    latest.content = f"{latest.content}\n\n{instruction}"
    return await chat(req)
@app.get("/", include_in_schema=False)
async def serve_dashboard():
    """Serve the bundled single-page dashboard (DASHBOARD_DIR/index.html).

    Raises:
        HTTPException(404): the dashboard file is not present.
    """
    dashboard_path = os.path.join(DASHBOARD_DIR, "index.html")
    try:
        # Read as UTF-8 explicitly; the platform default encoding depends on
        # the locale and can fail on non-ASCII HTML.
        with open(dashboard_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError:
        raise HTTPException(404, "Dashboard not found.") from None
    return HTMLResponse(content=html)
if __name__ == "__main__":
    # Direct execution: serve the module-level app with uvicorn on HOST:PORT.
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")