# Source: SmokeScan / config/inference.py (KinetoLabs)
# Commit 706520f: Replace dual 8B with single 30B-A3B FP8 vision model
"""Model inference configuration parameters.
Configuration values aligned with official Qwen3-VL model recommendations
and FDAM Technical Spec requirements.
Pipeline uses:
- Vision: Qwen/Qwen3-VL-30B-A3B-Thinking-FP8 (single model, FP8 via vLLM)
- Embedding: Qwen/Qwen3-VL-Embedding-2B (2048-dim)
- Reranker: Qwen/Qwen3-VL-Reranker-2B
"""
from dataclasses import dataclass
@dataclass
class VisionInferenceConfig:
    """Sampling parameters for the 30B-A3B FP8 vision model.

    One model covers both the analysis step and the structured JSON
    output. It is served through vLLM with tensor parallelism across
    4 GPUs.

    Attributes:
        max_tokens: Upper bound on generated tokens. Named ``max_tokens``
            rather than ``max_new_tokens`` because that is the name vLLM
            uses.
        temperature: Sampling temperature, per the Qwen3-VL GitHub docs.
        top_p: Top-p sampling value.
        top_k: Top-k sampling value.
        repetition_penalty: Repetition penalty, per the Qwen3-VL docs.
    """

    max_tokens: int = 8192
    temperature: float = 0.6
    top_p: float = 0.95
    top_k: int = 20
    repetition_penalty: float = 1.0
@dataclass
class GenerationInferenceConfig:
    """Sampling parameters for document generation (SOW, sampling plans).

    Kept separate from the vision settings so that longer-form generation
    can be tuned on its own, per FDAM Technical Spec Section 3.

    Attributes:
        max_new_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature. The original note describes
            this value as "slightly higher for more varied text".
            NOTE(review): 0.2 is lower than VisionInferenceConfig's 0.6,
            so the baseline that note compares against is unclear --
            confirm the intended value.
        top_p: Top-p sampling value.
        do_sample: Whether sampling is enabled.
        repetition_penalty: Repetition penalty applied during generation.
    """

    max_new_tokens: int = 8192
    temperature: float = 0.2
    top_p: float = 0.95
    do_sample: bool = True
    repetition_penalty: float = 1.05
@dataclass
class EmbeddingConfig:
    """Settings for the embedding model.

    Attributes:
        embedding_dimension: Size of each embedding vector. Matches
            ``text_config.hidden_size = 2048`` in the
            Qwen3-VL-Embedding-2B ``config.json``.
        normalize: Whether embeddings are L2-normalized, as in the
            official implementation.
    """

    embedding_dimension: int = 2048
    normalize: bool = True
@dataclass
class RerankerConfig:
    """Settings for the reranker model.

    Attributes:
        top_k: Top-k value used by the reranker.
    """

    top_k: int = 5
@dataclass
class RAGConfig:
    """Settings for the RAG retrieval pipeline.

    Values follow FDAM Technical Spec Section 3.

    Attributes:
        top_k_retrieval: Number of candidates fetched by the initial
            retrieval.
        top_k_rerank: Number of results kept after reranking.
        similarity_threshold: Minimum similarity a result needs in order
            to be included.
    """

    top_k_retrieval: int = 10
    top_k_rerank: int = 5
    similarity_threshold: float = 0.7
# Ready-made module-level instances, one per configuration class, each
# constructed with that class's default field values.
vision_config: VisionInferenceConfig = VisionInferenceConfig()  # single 30B-A3B FP8 model
generation_config: GenerationInferenceConfig = GenerationInferenceConfig()
embedding_config: EmbeddingConfig = EmbeddingConfig()
reranker_config: RerankerConfig = RerankerConfig()
rag_config: RAGConfig = RAGConfig()