jeanbaptdzd committed
Commit 9c71bb7
Parent: dc80161

Migrate from vLLM to Transformers library


- Removed vLLM dependency (doesn't support Qwen3ForCausalLM yet)
- Switched to Transformers library with native Qwen3 support
- Updated Dockerfile: removed vLLM, added transformers + accelerate
- Rewrote app/providers/vllm.py to use Transformers
- Implemented streaming with TextIteratorStreamer
- Updated all documentation and configuration
- Removed vllm_base_url from config
- Updated tests to match new config structure

This provides better compatibility with Qwen3 models while we wait for vLLM support.
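
The core of the change is the streaming path: the synchronous vLLM `LLM.generate()` call could only fake streaming by chunking a finished response, while `TextIteratorStreamer` yields text as it is produced. A condensed sketch of the pattern the new `app/providers/vllm.py` uses (full diff below; the example prompt is illustrative):

```python
# Condensed sketch of the TextIteratorStreamer pattern adopted in this commit;
# see the app/providers/vllm.py diff below for the real implementation.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "DragonLLM/qwen3-8b-fin-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

# Build the prompt from chat messages via the tokenizer's chat template
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarize Basel III in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# generate() blocks, so it runs in a worker thread while we consume the streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 128}).start()

for text in streamer:
    print(text, end="", flush=True)  # arrives incrementally as tokens decode
```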

Files changed (7)
  1. Dockerfile +11 -11
  2. README.md +4 -8
  3. app/config.py +0 -1
  4. app/main.py +4 -4
  5. app/providers/vllm.py +131 -154
  6. requirements.txt +1 -3
  7. tests/test_config.py +1 -4
Dockerfile CHANGED
@@ -24,15 +24,18 @@ RUN python3 -m pip install --upgrade pip
 # Set working directory
 WORKDIR /app
 
-# Install PyTorch with CUDA 12.4 support FIRST (critical for vLLM compatibility)
-# Updated to PyTorch 2.5+ for better vLLM 0.9.x compatibility
+# Install PyTorch with CUDA 12.4 support
 RUN pip install --no-cache-dir \
     torch>=2.5.0 \
+    torchvision \
+    torchaudio \
     --index-url https://download.pytorch.org/whl/cu124
 
-# Install vLLM 0.11.0 (latest, supports Qwen3ForCausalLM - requires 0.8.4+)
-# vLLM 0.11.0 - includes Qwen3 support and latest optimizations
-RUN pip install --no-cache-dir vllm==0.11.0
+# Install Transformers and accelerate for optimized inference
+RUN pip install --no-cache-dir \
+    transformers>=4.40.0 \
+    accelerate>=0.30.0 \
+    bitsandbytes  # Optional: for quantization support
 
 # Install application dependencies
 RUN pip install --no-cache-dir \
@@ -56,17 +59,14 @@ RUN useradd -m -u 1000 user && \
 
 USER user
 
-# Set environment variables for optimal vLLM performance
+# Set environment variables for optimal Transformers performance
 ENV HF_HOME=/tmp/huggingface
 ENV TORCHINDUCTOR_CACHE_DIR=/tmp/torch/inductor
-ENV TRITON_CACHE_DIR=/tmp/triton
-ENV TORCH_COMPILE_DEBUG=0
 ENV CUDA_VISIBLE_DEVICES=0
 # Optimize CUDA memory allocation
 ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
-# vLLM 0.9.x uses v1 engine by default (more efficient)
-# VLLM_USE_V1=0 can be set if needed for compatibility, but v1 is recommended
-# ENV VLLM_USE_V1=0  # Commented out - v1 engine is default and preferred in 0.9.x
+# Enable Transformers optimizations
+ENV TRANSFORMERS_CACHE=/tmp/huggingface
 
 # Expose port
 EXPOSE 7860
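
A quick way to confirm the rebuilt image has a CUDA-enabled PyTorch and a Transformers version matching the new pins is a smoke test inside the container (a hypothetical check, not part of this commit):

```python
# Hypothetical smoke test for the rebuilt image; run inside the container.
import torch
import transformers
from packaging.version import Version  # packaging ships as a transformers dependency

print("torch:", torch.__version__)                   # expect a 2.5+ build with +cu124 tag
print("cuda available:", torch.cuda.is_available())  # expect True on the L4 Space
print("transformers:", transformers.__version__)

# Mirrors the >=4.40.0 pin the Dockerfile uses for Qwen3 support
assert Version(transformers.__version__) >= Version("4.40.0")
```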
README.md CHANGED
@@ -11,7 +11,7 @@ suggested_hardware: l4x1
 
 # Open Finance LLM 8B
 
-OpenAI-compatible API powered by `DragonLLM/qwen3-8b-fin-v1.0` via vLLM.
+OpenAI-compatible API powered by `DragonLLM/qwen3-8b-fin-v1.0` via Transformers.
 
 ## 🚀 Quick Start
 
@@ -63,14 +63,9 @@ The service uses these environment variables:
 - **Important**: You must accept the model's terms at https://huggingface.co/DragonLLM/qwen3-8b-fin-v1.0 before the token will work
 
 ### Optional Configuration
-- `VLLM_BASE_URL`: vLLM server endpoint (default: `http://localhost:8000/v1`)
 - `MODEL`: Model name (default: `DragonLLM/qwen3-8b-fin-v1.0`)
 - `SERVICE_API_KEY`: Optional API key for authentication (set via `x-api-key` header)
 - `LOG_LEVEL`: Logging level (default: `info`)
-- `VLLM_USE_EAGER`: Control optimization mode (default: `auto`)
-  - `auto`: Try optimized mode (CUDA graphs), fallback to eager if needed (recommended)
-  - `false`: Force optimized mode (CUDA graphs enabled, may fail if unsupported)
-  - `true`: Force eager mode (slower but more stable)
 
 ### Setting Up HF_TOKEN_LC2 in Hugging Face Spaces
 
@@ -145,9 +140,10 @@ MIT License - see LICENSE file for details.
 
 ---
 
-**Note**: This service runs vLLM 0.11.0 (latest stable) with `DragonLLM/qwen3-8b-fin-v1.0` model. The service initializes the model automatically on startup. For production use, ensure proper GPU resources (L4 or better) are available.
+**Note**: This service runs with `DragonLLM/qwen3-8b-fin-v1.0` using the Transformers library. The service initializes the model automatically on startup. For production use, ensure proper GPU resources (L4 or better) are available.
 
 ### Version Information
-- **vLLM:** 0.11.0 (supports Qwen3ForCausalLM - requires 0.8.4+)
+- **Transformers:** 4.40.0+ (supports Qwen3ForCausalLM)
 - **PyTorch:** 2.5.0+ (CUDA 12.4)
 - **CUDA:** 12.4
+- **Accelerate:** 0.30.0+ (for optimized inference)
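
The README's client instructions are unaffected by the backend swap, since the HTTP surface stays OpenAI-compatible. A minimal call sketch, assuming the service is reachable on the Dockerfile's exposed port 7860 and returns the standard OpenAI response shape:

```python
# Minimal client sketch; localhost:7860 and the x-api-key value are assumptions --
# adjust to your Space URL and configured SERVICE_API_KEY.
import requests

resp = requests.post(
    "http://localhost:7860/v1/chat/completions",
    headers={"x-api-key": "your-service-api-key"},  # omit if SERVICE_API_KEY is unset
    json={
        "model": "DragonLLM/qwen3-8b-fin-v1.0",
        "messages": [{"role": "user", "content": "What is Tier 1 capital?"}],
        "max_tokens": 200,
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```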
app/config.py CHANGED
@@ -2,7 +2,6 @@ from pydantic_settings import BaseSettings
 
 
 class Settings(BaseSettings):
-    vllm_base_url: str = "http://localhost:8000/v1"
     model: str = "DragonLLM/qwen3-8b-fin-v1.0"
     service_api_key: str | None = None
     log_level: str = "info"
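
Dropping the `vllm_base_url` field is all it takes to retire `VLLM_BASE_URL`: pydantic-settings maps each remaining field to the environment variable of the same name, case-insensitively. For illustration:

```python
# Illustration of the pydantic-settings env mapping used by Settings.
import os

from app.config import Settings

os.environ["MODEL"] = "custom-model"  # field name "model", matched case-insensitively
print(Settings().model)               # -> "custom-model"

del os.environ["MODEL"]
print(Settings().model)               # -> "DragonLLM/qwen3-8b-fin-v1.0" (the default)
```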
app/main.py CHANGED
@@ -7,7 +7,7 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="LLM Pro Finance API (vLLM)")
+app = FastAPI(title="LLM Pro Finance API (Transformers)")
 
 # Mount routers
 app.include_router(openai_api.router, prefix="/v1")
@@ -23,8 +23,8 @@ async def startup_event():
     logger.info("Initializing model in background thread...")
 
     def load_model():
-        from app.providers.vllm import initialize_vllm
-        initialize_vllm()
+        from app.providers.vllm import initialize_model
+        initialize_model()
 
     # Start model loading in background thread
     thread = threading.Thread(target=load_model, daemon=True)
@@ -38,7 +38,7 @@ async def root():
         "service": "Qwen Open Finance R 8B Inference",
         "version": "1.0.0",
         "model": "DragonLLM/qwen3-8b-fin-v1.0",
-        "backend": "vLLM"
+        "backend": "Transformers"
     }
 
 @app.get("/health")
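
Because startup only launches a daemon thread, the server answers requests before the weights finish loading; an early chat request simply triggers the provider's lazy `initialize_model()` and waits. A client can instead poll readiness, assuming `/health` (whose handler body is not shown in this diff) returns HTTP 200 once the service is up:

```python
# Hypothetical readiness poll; assumes /health returns HTTP 200 when ready,
# which this diff does not show explicitly.
import time

import requests

def wait_until_ready(base_url: str = "http://localhost:7860", timeout: float = 600.0) -> None:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if requests.get(f"{base_url}/health", timeout=5).status_code == 200:
                return  # service (and, ideally, the model) is up
        except requests.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(5)
    raise TimeoutError("service did not become healthy in time")
```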
app/providers/vllm.py CHANGED
@@ -1,28 +1,32 @@
 import os
 import time
+import torch
 from typing import Dict, Any, AsyncIterator, Union
-from vllm import LLM, SamplingParams
 import asyncio
 from huggingface_hub import login
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
 
-# Model configuration - back to working DragonLLM model
+# Model configuration
 model_name = "DragonLLM/qwen3-8b-fin-v1.0"
-llm_engine = None
+model = None
+tokenizer = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def initialize_vllm():
-    """Initialize vLLM engine with the model
+def initialize_model():
+    """Initialize Transformers model with Qwen3
 
     Handles authentication with Hugging Face Hub for accessing DragonLLM models.
     Prioritizes HF_TOKEN_LC2 (DragonLLM access) over HF_TOKEN_LC.
     """
-    global llm_engine
+    global model, tokenizer
 
-    if llm_engine is None:
+    if model is None:
         import logging
         logger = logging.getLogger(__name__)
 
-        logger.info(f"Initializing vLLM with model: {model_name}")
-        print(f"Initializing vLLM with model: {model_name}")
+        logger.info(f"Initializing Transformers with model: {model_name}")
+        print(f"Initializing Transformers with model: {model_name}")
 
         # Get HF token from environment (Hugging Face Space secret)
         # Priority: HF_TOKEN_LC2 (for DragonLLM access) > HF_TOKEN_LC > HF_TOKEN
@@ -56,99 +60,55 @@ def initialize_vllm():
             logger.warning(f"⚠️ Warning: Failed to authenticate with HF Hub: {e}")
             print(f"⚠️ Warning: Failed to authenticate with HF Hub: {e}")
 
-        # Set all possible environment variables that vLLM/huggingface_hub might check
-        # This ensures compatibility across different versions
+        # Set all possible environment variables
         os.environ["HF_TOKEN"] = hf_token
         os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
-        # Some tools check for these variants too
        os.environ["HF_API_TOKEN"] = hf_token
 
        logger.info("✅ Hugging Face token environment variables set")
    else:
        logger.warning("⚠️ WARNING: No HF token found in environment!")
-        logger.warning(f"   Checked: HF_TOKEN_LC2, HF_TOKEN_LC, HF_TOKEN, HUGGING_FACE_HUB_TOKEN")
-        logger.warning(f"   Available env vars: {[k for k in os.environ.keys() if 'TOKEN' in k or 'HF' in k]}")
        print("⚠️ WARNING: No HF token found in environment!")
        print(f"   Checked: HF_TOKEN_LC2, HF_TOKEN_LC, HF_TOKEN, HUGGING_FACE_HUB_TOKEN")
-        print(f"   Available env vars with 'TOKEN' or 'HF': {[k for k in os.environ.keys() if 'TOKEN' in k or 'HF' in k]}")
        print("   ⚠️ Model download may fail if DragonLLM/qwen3-8b-fin-v1.0 is gated!")
 
    try:
-        # Initialize vLLM engine
-        # Note: vLLM 0.11.0 supports Qwen3ForCausalLM (requires 0.8.4+)
-        logger.info(f"Attempting to load model: {model_name}")
-        print(f"Attempting to load model: {model_name}")
-        print(f"Model type: DragonLLM Qwen3 8B (bfloat16)")
-        print(f"vLLM version: 0.11.0 (Qwen3ForCausalLM support)")
-        print(f"Download directory: /tmp/huggingface")
+        logger.info(f"Loading model: {model_name}")
+        print(f"Loading model: {model_name}")
+        print(f"Model type: DragonLLM Qwen3 8B")
+        print(f"Device: {device}")
        print(f"Trust remote code: True")
-        print(f"L4 GPU: 24GB VRAM available")
 
-        # Try optimized mode first (CUDA graphs enabled)
-        # Falls back to eager mode if CUDA graphs fail
-        use_optimized = os.getenv("VLLM_USE_EAGER", "auto").lower()
-        if use_optimized == "true":
-            enforce_eager = True
-            mode_desc = "Eager mode (forced)"
-        elif use_optimized == "false":
-            enforce_eager = False
-            mode_desc = "Optimized mode (CUDA graphs enabled)"
-        else:  # "auto" - try optimized, fallback to eager
-            enforce_eager = False
-            mode_desc = "Optimized mode (auto, fallback to eager if needed)"
+        # Load tokenizer
+        print("📥 Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            token=hf_token,
+            trust_remote_code=True,
+            cache_dir="/tmp/huggingface"
+        )
+        logger.info("✅ Tokenizer loaded")
+        print("✅ Tokenizer loaded")
 
-        print(f"Mode: {mode_desc}")
-        print(f"GPU memory utilization: 0.85")
-        print(f"vLLM: v0.9.2 (Latest stable, improved Qwen3 support)")
-        print(f"PyTorch: 2.5.0+ (CUDA 12.4 binary)")
+        # Load model with optimizations
+        print("📥 Loading model (this may take a few minutes)...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            token=hf_token,
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            cache_dir="/tmp/huggingface"
+        )
 
-        # Common initialization parameters
-        init_params = {
-            "model": model_name,
-            "trust_remote_code": True,
-            "dtype": "bfloat16",  # Use bfloat16 for Qwen3 (required)
-            "max_model_len": 4096,  # Reduced for L4 KV cache constraints
-            "gpu_memory_utilization": 0.85,  # Can use more with stable v0 engine
-            "tensor_parallel_size": 1,  # Single L4 GPU
-            "download_dir": "/tmp/huggingface",
-            "tokenizer_mode": "auto",
-            "disable_log_stats": False,  # Enable logging for debugging
-        }
+        # Set to eval mode for inference
+        model.eval()
 
-        # Try optimized mode first (unless explicitly disabled)
-        if use_optimized == "auto" or use_optimized == "false":
-            try:
-                print(f"🚀 Attempting optimized mode with CUDA graphs...")
-                logger.info("Attempting optimized mode (enforce_eager=False)")
-                init_params["enforce_eager"] = False
-                llm_engine = LLM(**init_params)
-                print(f"✅ vLLM engine initialized successfully in OPTIMIZED mode!")
-                logger.info("✅ vLLM engine initialized in optimized mode (CUDA graphs enabled)")
-            except Exception as opt_error:
-                error_msg = str(opt_error).lower()
-                # Check if error is CUDA graph related
-                if "cuda graph" in error_msg or "graph" in error_msg or use_optimized == "auto":
-                    logger.warning(f"⚠️ Optimized mode failed, falling back to eager mode: {opt_error}")
-                    print(f"⚠️ Optimized mode failed: {opt_error}")
-                    print(f"🔄 Falling back to eager mode for stability...")
-                    init_params["enforce_eager"] = True
-                    llm_engine = LLM(**init_params)
-                    print(f"✅ vLLM engine initialized successfully in EAGER mode (fallback)")
-                    logger.info("✅ vLLM engine initialized in eager mode (fallback after optimized mode failure)")
-                else:
-                    # Re-raise if it's not a CUDA graph issue or if optimized is forced
-                    raise
-        else:
-            # Eager mode explicitly requested
-            print(f"⚙️ Using eager mode (explicitly requested)")
-            logger.info("Using eager mode (VLLM_USE_EAGER=true)")
-            init_params["enforce_eager"] = True
-            llm_engine = LLM(**init_params)
-            print(f"✅ vLLM engine initialized successfully in EAGER mode!")
-            logger.info("✅ vLLM engine initialized in eager mode")
+        print(f"✅ Model loaded successfully!")
+        logger.info("✅ Model initialized successfully")
 
    except Exception as e:
-        error_msg = f"❌ Error initializing vLLM: {e}"
+        error_msg = f"❌ Error initializing model: {e}"
        logger.error(error_msg, exc_info=True)
        print(error_msg)
 
@@ -167,7 +127,7 @@ def initialize_vllm():
            raise
 
 
-class VLLMProvider:
+class TransformersProvider:
    def __init__(self):
        # Don't initialize at import time
        pass
@@ -193,44 +153,61 @@ class VLLMProvider:
        logger = logging.getLogger(__name__)
 
        try:
-            # Initialize vLLM on first use
-            if llm_engine is None:
-                logger.info("vLLM engine not initialized, initializing now...")
-                initialize_vllm()
-                logger.info("vLLM engine initialized successfully")
+            # Initialize model on first use
+            if model is None:
+                logger.info("Model not initialized, initializing now...")
+                initialize_model()
+                logger.info("Model initialized successfully")
 
            messages = payload.get("messages", [])
            temperature = payload.get("temperature", 0.7)
            max_tokens = payload.get("max_tokens", 1000)
            top_p = payload.get("top_p", 1.0)
 
-            # Convert messages to prompt
-            prompt = self._messages_to_prompt(messages)
+            # Convert messages to prompt using tokenizer's chat template
+            if hasattr(tokenizer, "apply_chat_template"):
+                prompt = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True
+                )
+            else:
+                # Fallback to simple prompt format
+                prompt = self._messages_to_prompt(messages)
+
            logger.info(f"Generating response for prompt: {prompt[:100]}...")
 
-            # Set up sampling parameters
-            sampling_params = SamplingParams(
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_tokens,
-            )
+            # Tokenize
+            inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
            # Handle streaming vs non-streaming
            if stream:
-                return self._chat_stream(prompt, sampling_params, payload.get("model", model_name))
+                return self._chat_stream(inputs, temperature, top_p, max_tokens, payload.get("model", model_name))
 
-            # Generate response using vLLM (non-streaming)
-            outputs = llm_engine.generate([prompt], sampling_params)
+            # Generate response (non-streaming)
+            with torch.no_grad():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=temperature > 0,
+                    pad_token_id=tokenizer.eos_token_id
+                )
+
+            # Decode response
+            generated_ids = outputs[0][inputs.input_ids.shape[1]:]
+            generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
-            # Extract the generated text
-            generated_text = outputs[0].outputs[0].text
            logger.info(f"Generated text: {generated_text[:100]}...")
 
+            # Calculate tokens (approximate)
+            prompt_tokens = inputs.input_ids.shape[1]
+            completion_tokens = len(generated_ids)
+
            # Build OpenAI-compatible response
            completion_id = f"chatcmpl-{os.urandom(12).hex()}"
            created = int(time.time())
-            prompt_tokens = len(outputs[0].prompt_token_ids)
-            completion_tokens = len(outputs[0].outputs[0].token_ids)
 
            return {
                "id": completion_id,
@@ -257,72 +234,72 @@ class VLLMProvider:
            logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
            raise
 
-    async def _chat_stream(self, prompt: str, sampling_params: SamplingParams, model: str) -> AsyncIterator[str]:
-        """Stream chat completions using vLLM
-
-        Note: vLLM 0.6.5 with synchronous LLM doesn't support true streaming.
-        This implementation generates the full response and yields it in chunks
-        for OpenAI API compatibility. For true streaming, use AsyncLLMEngine.
-        """
+    async def _chat_stream(self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str) -> AsyncIterator[str]:
+        """Stream chat completions using Transformers TextIteratorStreamer"""
        import logging
        logger = logging.getLogger(__name__)
 
        completion_id = f"chatcmpl-{os.urandom(12).hex()}"
        created = int(time.time())
 
-        # Generate response (non-streaming backend, but we'll chunk it)
-        # Run in thread pool to avoid blocking
-        loop = asyncio.get_event_loop()
-        outputs = await loop.run_in_executor(
-            None,
-            lambda: llm_engine.generate([prompt], sampling_params)
-        )
+        # Create streamer
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        generated_text = outputs[0].outputs[0].text
-        finish_reason = outputs[0].outputs[0].finish_reason or "stop"
+        # Generation parameters
+        generation_kwargs = {
+            "max_new_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "do_sample": temperature > 0,
+            "pad_token_id": tokenizer.eos_token_id,
+            "streamer": streamer
+        }
 
-        # Yield text in chunks (simulate streaming)
-        # Split into reasonable chunks (words or characters)
-        chunk_size = 10  # words per chunk
-        words = generated_text.split()
+        # Run generation in a separate thread
+        def generate():
+            with torch.no_grad():
+                model.generate(**inputs, **generation_kwargs)
 
-        for i in range(0, len(words), chunk_size):
-            chunk_words = words[i:i + chunk_size]
-            delta_text = " ".join(chunk_words)
-            if i + chunk_size < len(words):
-                delta_text += " "
-
-            # Format as OpenAI SSE stream chunk
-            chunk = {
-                "id": completion_id,
-                "object": "chat.completion.chunk",
-                "created": created,
-                "model": model,
-                "choices": [
-                    {
-                        "index": 0,
-                        "delta": {
-                            "content": delta_text
-                        },
-                        "finish_reason": None
-                    }
-                ]
-            }
-
-            yield f"data: {self._json_dumps(chunk)}\n\n"
-            await asyncio.sleep(0)  # Yield control
+        generation_thread = Thread(target=generate)
+        generation_thread.start()
+
+        # Stream tokens as they're generated
+        try:
+            for token in streamer:
+                # Yield chunks
+                chunk = {
+                    "id": completion_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": model_id,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "content": token
+                            },
+                            "finish_reason": None
+                        }
+                    ]
+                }
+
+                yield f"data: {self._json_dumps(chunk)}\n\n"
+                await asyncio.sleep(0)  # Yield control
+        finally:
+            # Wait for generation to complete
+            generation_thread.join()
 
-        # Send final chunk with finish_reason
+        # Send final chunk
        final_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
-            "model": model,
+            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "delta": {},
-                    "finish_reason": finish_reason
+                    "finish_reason": "stop"
                }
            ]
        }
@@ -335,7 +312,7 @@ class VLLMProvider:
        return json.dumps(obj, ensure_ascii=False)
 
    def _messages_to_prompt(self, messages: list) -> str:
-        """Convert OpenAI messages format to prompt"""
+        """Convert OpenAI messages format to prompt (fallback)"""
        prompt = ""
        for message in messages:
            role = message["role"]
@@ -351,7 +328,7 @@ class VLLMProvider:
 
 
 # Module-level provider instance for backward compatibility
-_provider = VLLMProvider()
+_provider = TransformersProvider()
 
 
 # Module-level functions for direct import
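
On the wire, `_chat_stream` emits standard OpenAI-style SSE chunks, so existing streaming clients keep working. A sketch of consuming the stream (same localhost:7860 assumption as above; whether a terminating `[DONE]` sentinel is sent is not visible in this diff):

```python
# Sketch of a client consuming the SSE chunks emitted by _chat_stream.
import json

import requests

with requests.post(
    "http://localhost:7860/v1/chat/completions",
    json={
        "model": "DragonLLM/qwen3-8b-fin-v1.0",
        "messages": [{"role": "user", "content": "Define duration risk."}],
        "stream": True,
    },
    stream=True,
    timeout=300,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data.strip() == "[DONE]":  # defensive: some OpenAI-style servers send this
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
```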
requirements.txt CHANGED
@@ -1,7 +1,5 @@
 # Core dependencies for OpenAI-compatible API service
-# Note: vLLM and PyTorch are installed separately in Dockerfile for CUDA support
-# vllm==0.6.5  # Installed in Dockerfile
-# torch==2.4.0  # Installed in Dockerfile
+# Note: PyTorch and Transformers are installed separately in Dockerfile for CUDA support
 
 fastapi>=0.115.0
 uvicorn[standard]>=0.30.0
tests/test_config.py CHANGED
@@ -9,7 +9,6 @@ from app.config import Settings
 def test_settings_defaults():
     """Test that settings have correct default values."""
     settings = Settings()
-    assert settings.vllm_base_url == "http://localhost:8000/v1"
     assert settings.model == "DragonLLM/qwen3-8b-fin-v1.0"
     assert settings.service_api_key is None
     assert settings.log_level == "info"
@@ -18,13 +17,11 @@ def test_settings_defaults():
 def test_settings_from_env():
     """Test that settings can be loaded from environment variables."""
     with patch.dict(os.environ, {
-        "VLLM_BASE_URL": "http://remote:8000/v1",
         "MODEL": "custom-model",
         "SERVICE_API_KEY": "secret-key",
         "LOG_LEVEL": "debug"
     }):
         settings = Settings()
-        assert settings.vllm_base_url == "http://remote:8000/v1"
         assert settings.model == "custom-model"
         assert settings.service_api_key == "secret-key"
         assert settings.log_level == "debug"
@@ -36,4 +33,4 @@ def test_settings_env_file():
     # In practice, you'd create a test .env file or mock the file reading
     settings = Settings()
     # Verify that the settings object can be instantiated
-    assert isinstance(settings.vllm_base_url, str)
+    assert isinstance(settings.model, str)