likhonhfai committed on
Commit
2aabcef
·
verified ·
1 Parent(s): 078f2ed

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. Dockerfile +3 -3
  2. app.py +99 -53
  3. requirements.txt +2 -1
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM nvidia/cuda:12.1.0-base-ubuntu22.04
2
 
3
  # Set up environment
4
  ENV DEBIAN_FRONTEND=noninteractive
@@ -16,7 +16,7 @@ ENV HOME=/home/user \
16
 
17
  WORKDIR $HOME/app
18
 
19
- # Install dependencies
20
  COPY --chown=user requirements.txt .
21
  RUN pip install --no-cache-dir -r requirements.txt
22
 
@@ -27,4 +27,4 @@ COPY --chown=user . .
27
  EXPOSE 7860
28
 
29
  # Run the application
30
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
+ FROM lmsysorg/sglang:latest
2
 
3
  # Set up environment
4
  ENV DEBIAN_FRONTEND=noninteractive
 
16
 
17
  WORKDIR $HOME/app
18
 
19
+ # Install additional dependencies for our proxy app
20
  COPY --chown=user requirements.txt .
21
  RUN pip install --no-cache-dir -r requirements.txt
22
 
 
27
  EXPOSE 7860
28
 
29
  # Run the application
30
+ CMD ["python3", "app.py"]
app.py CHANGED
@@ -1,40 +1,65 @@
1
  import os
2
- import torch
3
- from fastapi import FastAPI, HTTPException
 
 
4
  from pydantic import BaseModel
5
- from typing import List, Optional, Dict, Any
6
- from vllm import LLM, SamplingParams
7
- from PIL import Image
8
- import base64
9
- from io import BytesIO
10
 
11
- app = FastAPI(title="Fara-7B API")
12
 
13
- # Model configuration
14
  MODEL_ID = "microsoft/Fara-7B"
15
- llm = None
 
 
16
 
17
- def get_llm():
18
- global llm
19
- if llm is None:
20
- # Check for GPU availability
21
- if not torch.cuda.is_available():
22
- # For Spaces, we might want to log this or handle it gracefully
23
- print("CUDA is not available. This model requires a GPU.")
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # vLLM setup
26
- try:
27
- llm = LLM(
28
- model=MODEL_ID,
29
- trust_remote_code=True,
30
- dtype="auto",
31
- max_model_len=4096,
32
- tensor_parallel_size=1
33
- )
34
- except Exception as e:
35
- print(f"Error initializing vLLM: {e}")
36
- raise e
37
- return llm
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Request models
40
  class Message(BaseModel):
@@ -53,42 +78,63 @@ class MessageRequest(BaseModel):
53
 
54
  @app.get("/")
55
  async def root():
56
- return {"message": "Fara-7B API is running. Use /v1/responses or /v1/messages"}
57
 
58
  @app.get("/health")
59
  async def health():
60
- return {"status": "healthy"}
 
 
 
 
 
 
61
 
62
  @app.post("/v1/responses")
63
  async def generate_response(request: ResponseRequest):
64
  try:
65
- model = get_llm()
66
- sampling_params = SamplingParams(
67
- temperature=request.temperature,
68
- max_tokens=request.max_tokens
69
- )
70
- outputs = model.generate([request.prompt], sampling_params)
71
- return {"response": outputs[0].outputs[0].text}
 
 
 
 
72
  except Exception as e:
73
  raise HTTPException(status_code=500, detail=str(e))
74
 
75
  @app.post("/v1/messages")
76
  async def generate_message(request: MessageRequest):
77
  try:
78
- model = get_llm()
79
- sampling_params = SamplingParams(
80
- temperature=request.temperature,
81
- max_tokens=request.max_tokens
82
- )
83
-
84
- # Formatting for messages
85
- formatted_prompt = ""
86
- for msg in request.messages:
87
- formatted_prompt += f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
88
- formatted_prompt += "<|im_start|>assistant\n"
89
-
90
- outputs = model.generate([formatted_prompt], sampling_params)
91
- return {"message": outputs[0].outputs[0].text}
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
  raise HTTPException(status_code=500, detail=str(e))
94
 
 
import os
import subprocess
import time
import requests
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from typing import List, Optional, Any
import torch

app = FastAPI(title="Fara-7B SGLang API")

# Configuration
# Hugging Face model repo served by the SGLang backend subprocess.
MODEL_ID = "microsoft/Fara-7B"
# SGLang listens on loopback only; this FastAPI app is the public entry point
# (the container EXPOSEs 7860) and proxies to SGLANG_URL.
SGLANG_PORT = 30000
SGLANG_HOST = "127.0.0.1"
SGLANG_URL = f"http://{SGLANG_HOST}:{SGLANG_PORT}"

# Global process for SGLang server (Popen handle; None until start_sglang runs).
sglang_process = None
21
def start_sglang():
    """Launch the SGLang server as a subprocess and block until it answers.

    Idempotent: returns immediately if the server process was already started.

    Raises:
        RuntimeError: if the subprocess exits early, or the server does not
            become ready within the polling window (60 attempts x 10 s).
    """
    global sglang_process
    if sglang_process is not None:
        return  # already launched

    print(f"Starting SGLang server for {MODEL_ID}...")

    # Command to start SGLang server.
    # Using --chat-template qwen2-vl as Fara-7B is based on Qwen2.5-VL.
    cmd = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", MODEL_ID,
        "--host", SGLANG_HOST,
        "--port", str(SGLANG_PORT),
        "--chat-template", "qwen2-vl",
        "--trust-remote-code",
    ]

    # Shard across all visible GPUs when more than one is available.
    if torch.cuda.device_count() > 1:
        cmd.extend(["--tp", str(torch.cuda.device_count())])

    sglang_process = subprocess.Popen(cmd)

    # Poll the OpenAI-compatible /v1/models endpoint until the server is up.
    max_retries = 60
    for i in range(max_retries):
        # Fail fast instead of waiting out the full timeout if the server
        # process has already died (e.g. OOM or bad model path).
        if sglang_process.poll() is not None:
            raise RuntimeError(
                f"SGLang server exited early with code {sglang_process.returncode}."
            )
        try:
            response = requests.get(f"{SGLANG_URL}/v1/models", timeout=5)
            if response.status_code == 200:
                print("SGLang server is ready!")
                return
        except requests.RequestException:
            # Server not accepting connections yet — keep polling.
            pass
        print(f"Waiting for SGLang server... ({i+1}/{max_retries})")
        time.sleep(10)

    raise RuntimeError("SGLang server failed to start within timeout.")
57
+
58
@app.on_event("startup")
async def startup_event():
    """Launch the SGLang backend in a daemon thread so the API serves at once."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers; kept as-is to avoid changing startup semantics.
    import threading

    launcher = threading.Thread(target=start_sglang, daemon=True)
    launcher.start()
63
 
64
  # Request models
65
  class Message(BaseModel):
 
78
 
79
@app.get("/")
async def root():
    """Landing endpoint: confirms the proxy is up and points at the main routes."""
    return {
        "message": "Fara-7B SGLang API is running. Use /v1/responses or /v1/messages"
    }
82
 
83
@app.get("/health")
async def health():
    """Liveness probe: 'healthy' once the SGLang backend answers, else 'starting'.

    Returns:
        dict: {"status": "healthy"|"starting", "backend": "sglang"}.
    """
    try:
        resp = requests.get(f"{SGLANG_URL}/v1/models", timeout=2)
        if resp.status_code == 200:
            return {"status": "healthy", "backend": "sglang"}
    except requests.RequestException:
        # Backend not reachable yet (model still loading) — fall through.
        pass
    return {"status": "starting", "backend": "sglang"}
92
 
93
@app.post("/v1/responses")
async def generate_response(request: ResponseRequest):
    """Plain completion: forwards the prompt to SGLang's /v1/completions.

    Returns:
        dict: {"response": <generated text>}.

    Raises:
        HTTPException: 500 wrapping any backend or transport failure.
    """
    try:
        # Map /v1/responses to SGLang's completions endpoint.
        payload = {
            "model": MODEL_ID,
            "prompt": request.prompt,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
        }
        # Explicit timeout: without one a hung backend would block this
        # request forever. Generation can be slow on cold start, so be generous.
        resp = requests.post(
            f"{SGLANG_URL}/v1/completions", json=payload, timeout=300
        )
        resp.raise_for_status()
        data = resp.json()
        return {"response": data["choices"][0]["text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
109
 
110
@app.post("/v1/messages")
async def generate_message(request: MessageRequest):
    """Chat completion: forwards the message list to SGLang's /v1/chat/completions.

    Returns:
        dict: {"message": <assistant reply text>}.

    Raises:
        HTTPException: 500 wrapping any backend or transport failure.
    """
    try:
        # NOTE(review): .dict() is the pydantic v1 spelling (model_dump() in
        # v2); kept as-is for compatibility with whichever pydantic is pinned.
        payload = {
            "model": MODEL_ID,
            "messages": [m.dict() for m in request.messages],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
        }
        # Explicit timeout so a hung backend cannot block this request forever.
        resp = requests.post(
            f"{SGLANG_URL}/v1/chat/completions", json=payload, timeout=300
        )
        resp.raise_for_status()
        data = resp.json()
        return {"message": data["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
126
+
127
# Proxy other OpenAI compatible requests to SGLang if needed
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy_openai(path: str, request: Request):
    """Transparent passthrough for any other OpenAI-compatible route to SGLang.

    Relays the backend's body, status code and content type unchanged, so a
    backend error (e.g. 404) is not silently rewrapped as a 200, and non-JSON
    responses no longer crash the proxy.

    Raises:
        HTTPException: 500 when the backend is unreachable or times out.
    """
    # Local import keeps this handler self-contained.
    from fastapi import Response

    url = f"{SGLANG_URL}/v1/{path}"
    method = request.method
    # Drop the Host header so requests sets one matching the backend address.
    headers = {k: v for k, v in request.headers.items() if k.lower() != "host"}
    body = await request.body()

    try:
        resp = requests.request(method, url, headers=headers, data=body, timeout=300)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return Response(
        content=resp.content,
        status_code=resp.status_code,
        media_type=resp.headers.get("content-type"),
    )
140
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  fastapi
2
  uvicorn
3
- vllm
4
  huggingface_hub
5
  python-multipart
6
  pydantic
@@ -8,3 +8,4 @@ pillow
8
  torch
9
  transformers
10
  accelerate
 
 
1
  fastapi
2
  uvicorn
3
+ sglang
4
  huggingface_hub
5
  python-multipart
6
  pydantic
 
8
  torch
9
  transformers
10
  accelerate
11
+ requests