# Hugging Face Space launcher: NuMarkdown-8B-Thinking via vLLM + FastAPI proxy + Gradio UI.
# (A scraped "Spaces: Sleeping" status banner from the Space page was removed here —
# it was UI residue, not part of the program.)
import base64
import os
import subprocess
import sys
import time
from io import BytesIO

import gradio as gr
import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from openai import OpenAI, APIConnectionError
from starlette.background import BackgroundTask
# --- CONFIGURATION ---
MODEL_ID = "numind/NuMarkdown-8B-Thinking"  # HF model id served by vLLM and used by the UI client
GPU_UTILIZATION = 0.90   # fraction of GPU memory vLLM may claim
MAX_MODEL_LEN = 32768    # max context length passed to vLLM
VLLM_PORT = 8000         # internal port of the vLLM OpenAI-compatible API
EXPOSED_PORT = 7860      # public port (Hugging Face Spaces default)
# --- STEP 1: LAUNCH vLLM (Background) ---
def start_vllm():
    """Spawn the vLLM OpenAI-compatible server as a background subprocess.

    Idempotent within this process: a ``VLLM_PID`` environment marker
    prevents a second launch. Deliberately does NOT wait for the server
    to become ready, so the UI can come up while weights download.
    """
    if "VLLM_PID" in os.environ:
        return

    print(f"Starting vLLM server on port {VLLM_PORT}...")

    # JSON formatted limit string to fix parsing error
    limit_mm_config = '{"image": 1}'

    args = ["vllm", "serve", MODEL_ID]
    args += ["--host", "0.0.0.0"]
    args += ["--port", str(VLLM_PORT)]
    args += ["--trust-remote-code"]
    args += ["--gpu-memory-utilization", str(GPU_UTILIZATION)]
    args += ["--max-model-len", str(MAX_MODEL_LEN)]
    args += ["--dtype", "bfloat16"]
    args += ["--limit-mm-per-prompt", limit_mm_config]

    # Inherit stdout/stderr so model-download progress shows in the Space logs.
    server = subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr)
    os.environ["VLLM_PID"] = str(server.pid)

    # Intentionally non-blocking: vLLM loads in the background while the UI starts.
    print("vLLM started in background. Please wait for model download...")
start_vllm()  # kick off the model server at import time, before the web app mounts
# --- STEP 2: FASTAPI PROXY ---
app = FastAPI()


# FIX: the proxy function was never registered as a route, so /v1/* requests
# were never actually forwarded; register it for the verbs vLLM's API uses.
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_to_vllm(path: str, request: Request):
    """Forward a /v1/* request to the local vLLM server, streaming the reply.

    Returns a 503 JSON error while vLLM is not yet accepting connections
    (model still downloading/loading).
    """
    target_url = f"http://localhost:{VLLM_PORT}/v1/{path}"

    # FIX: the client must outlive this function. The old `async with`
    # closed it before StreamingResponse iterated r.aiter_raw(), killing
    # the stream mid-flight; instead we close it in a background task
    # after the response body has been fully sent.
    client = httpx.AsyncClient(timeout=300.0)
    try:
        proxy_req = client.build_request(
            request.method,
            target_url,
            headers=request.headers.raw,
            content=await request.body(),
        )
        upstream = await client.send(proxy_req, stream=True)
    except httpx.ConnectError:
        await client.aclose()
        # FIX: JSONResponse was referenced here but never imported (NameError);
        # it is now imported at the top of the file.
        return JSONResponse(status_code=503, content={"error": "Model is still loading. Please wait."})

    async def _cleanup():
        # Release the upstream response and the client once streaming finishes.
        await upstream.aclose()
        await client.aclose()

    return StreamingResponse(
        upstream.aiter_raw(),
        status_code=upstream.status_code,
        headers=upstream.headers,
        background=BackgroundTask(_cleanup),
    )
# --- STEP 3: GRADIO UI ---
def run_ui_test(image, prompt):
    """Run one inference round-trip against the local vLLM server.

    Args:
        image: PIL image of the document, or None if nothing was uploaded.
        prompt: instruction text; falls back to a default when empty.

    Returns:
        The model's markdown output, or a human-readable error string
        (never raises — all failures are reported as text for the UI).
    """
    if image is None:
        return "⚠️ Please upload an image first."

    # Encode the image as base64 JPEG for the OpenAI-style vision payload.
    try:
        buffered = BytesIO()
        # FIX: JPEG cannot store an alpha channel; uploads in RGBA/P mode
        # crashed save() with "cannot write mode RGBA as JPEG". Convert to
        # RGB first so any input mode works.
        image.convert("RGB").save(buffered, format="JPEG")
        b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        return f"Error processing image: {e}"

    if not prompt:
        prompt = "Convert to markdown."

    client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="EMPTY")
    try:
        completion = client.chat.completions.create(
            model=MODEL_ID,
            messages=[{"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ]}],
            max_tokens=4096,
        )
        return completion.choices[0].message.content
    except APIConnectionError:
        return "⏳ Model is still downloading/loading... Check the 'Logs' tab. This takes 2-3 minutes on a fresh GPU."
    except Exception as e:
        return f"Error: {str(e)}"
# Declarative Gradio layout: document image + prompt on the left, output on the right.
with gr.Blocks() as demo:
    gr.Markdown("# NuMarkdown L40S vLLM Server")
    gr.Markdown("Status: If you just started this Space, wait 3 minutes for weights to download.")
    with gr.Row():
        with gr.Column():
            doc_image = gr.Image(type="pil", label="Document")
            prompt_box = gr.Textbox(value="Convert to markdown.", label="Prompt")
            run_btn = gr.Button("Test Inference")
        with gr.Column():
            result_box = gr.Textbox(label="Output")
    # Wire the button to the inference helper with both inputs.
    run_btn.click(run_ui_test, inputs=[doc_image, prompt_box], outputs=[result_box])
# Mount the Gradio UI at "/" on the same FastAPI app that hosts the /v1 proxy,
# so a single public port serves both.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=EXPOSED_PORT)