ATCTrack-VLM / scripts /qwen_state_server.py
SunXiang2025's picture
Update: two-stage training, per-channel FiLM gate, cosine scheduler, 9B config
b3f019f verified
#!/usr/bin/env python
import argparse
import base64
import copy
import os
import sys
from io import BytesIO
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from PIL import Image
# Directory one level above the folder that contains this script. It is put at
# the front of sys.path (once) so the project-local `lib` package imported just
# below can be resolved when this file is executed directly as a script.
PROJECT_DIR = Path(__file__).resolve().parents[1]
if str(PROJECT_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_DIR))
from lib.config.atctrack.config import cfg, update_config_from_file
import lib.models.atctrack.atctrack as atctrack_module
from lib.models.atctrack.atctrack import QwenTargetStateEncoder, TARGET_STATE_TOKEN
class UpdateRequest(BaseModel):
    """Request body of the ``POST /update`` endpoint."""

    # Base64-encoded image; an optional ``data:image...,`` URI prefix is
    # stripped by the service before decoding.
    template_image: str
    # Base64-encoded image, same encoding rules as ``template_image``.
    candidate_image: str
    # Four numbers; the service reshapes them to a (1, 4) float32 tensor.
    # NOTE(review): coordinate convention (xywh vs. xyxy, normalised or not)
    # is defined by the encoder, not visible here - confirm.
    template_bbox: List[float]
    # Four numbers, handled exactly like ``template_bbox``.
    candidate_bbox: List[float]
    # Optional text description; the service substitutes 'the target object'
    # when this is missing or empty.
    caption: Optional[str] = None
    # Optional object name, forwarded unchanged to the encoder.
    object_name: Optional[str] = None
class QwenStateService:
    """Inference-only wrapper around ``QwenTargetStateEncoder``.

    On construction the service loads the Qwen backbone (optionally sharded
    across devices via ``device_map``), restores the non-Qwen heads and the
    target-state token embedding from a tracker checkpoint, and switches the
    encoder to eval mode.  ``update`` then answers one template/candidate
    query at a time.

    Environment variables:
        QWEN_STATE_DEVICE_MAP: value passed as ``device_map`` to
            ``from_pretrained`` (default ``"auto"``; set to an empty string
            to omit the argument).
        QWEN_STATE_MAX_MEMORY: comma-separated per-device budgets, e.g.
            ``"20GiB,20GiB"``; entry *i* is assigned to device index *i*.
            Only consulted when a device map is in use.
    """

    def __init__(self, config_path, checkpoint_path, model_path, device):
        """Build the encoder and restore its trained weights.

        Args:
            config_path: Experiment config file passed to
                ``update_config_from_file``.
            checkpoint_path: Tracker checkpoint holding a ``'net'`` state dict
                and, optionally, a ``'target_state_embedding'`` record.
            model_path: Path of the Qwen model handed to ``from_pretrained``.
            device: Device (string or ``torch.device``) for the non-Qwen heads
                and for request tensors.

        Raises:
            RuntimeError: If non-Qwen encoder keys are missing from the
                checkpoint, if the checkpoint contains unexpected encoder
                keys, or if the stored target token does not match the model.
        """
        update_config_from_file(config_path)
        # Work on a private copy so the serving overrides do not leak into the
        # module-level config object.
        local_cfg = copy.deepcopy(cfg)
        target_state_cfg = local_cfg.MODEL.TARGET_STATE
        target_state_cfg.ENABLE = True
        target_state_cfg.MODEL_PATH = model_path
        target_state_cfg.USE_LORA = False
        target_state_cfg.FREEZE_QWEN = True
        target_state_cfg.TEACHER_ENABLE = False

        self.device = torch.device(device)
        self.device_map = os.environ.get("QWEN_STATE_DEVICE_MAP", "auto")

        # Temporarily swap the module-level loader so the encoder constructor
        # picks up the sharded loader; always restore it afterwards.
        original_loader = atctrack_module._load_qwen_target_state_model
        atctrack_module._load_qwen_target_state_model = self._load_sharded_qwen
        try:
            self.encoder = QwenTargetStateEncoder(local_cfg, tracker_dim=local_cfg.MODEL.HIDDEN_DIM)
        finally:
            atctrack_module._load_qwen_target_state_model = original_loader

        # Qwen is placed by from_pretrained; only the remaining heads are
        # moved to the serving device explicitly.
        self.encoder.projector.to(self.device)
        self.encoder.film.to(self.device)
        self.encoder.film_gate.data = self.encoder.film_gate.data.to(self.device)
        self.encoder.eval()

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only point this service at checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        encoder_state = self._extract_head_state(checkpoint.get('net', {}))
        missing, unexpected = self.encoder.load_state_dict(encoder_state, strict=False)
        # Qwen keys are deliberately absent from encoder_state, so only
        # missing non-Qwen keys indicate a real problem.
        missing_head = [k for k in missing if not k.startswith('qwen.')]
        if missing_head:
            raise RuntimeError(f'Missing non-Qwen encoder keys: {missing_head[:20]}')
        if unexpected:
            raise RuntimeError(f'Unexpected encoder keys: {unexpected[:20]}')

        self._restore_target_embedding(checkpoint.get('target_state_embedding'))
        print(f'Qwen state service loaded on {self.device}. Missing qwen/base keys ignored: {len(missing)}; non-Qwen restored: {len(encoder_state)}')

    def _load_sharded_qwen(self, model_path):
        """Load the Qwen model in bfloat16, honouring the device-map settings.

        Prefers ``AutoModelForImageTextToText`` and falls back to
        ``AutoModelForCausalLM`` when the former cannot be imported.
        """
        try:
            from transformers import AutoModelForImageTextToText as model_cls
        except ImportError:
            from transformers import AutoModelForCausalLM as model_cls
        kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
        if self.device_map:
            kwargs["device_map"] = self.device_map
            max_memory = self._parse_max_memory(os.environ.get("QWEN_STATE_MAX_MEMORY"))
            if max_memory:
                kwargs["max_memory"] = max_memory
        return model_cls.from_pretrained(model_path, **kwargs)

    @staticmethod
    def _parse_max_memory(spec):
        """Turn ``"20GiB, 20GiB"`` into ``{0: "20GiB", 1: "20GiB"}``.

        Entries are stripped of surrounding whitespace and blank entries
        (e.g. from a trailing comma) are dropped; every remaining entry keeps
        the device index given by its position in the list.  Returns an empty
        dict for ``None`` or an empty/blank string.
        """
        if not spec:
            return {}
        budgets = {}
        for device_index, raw_item in enumerate(spec.split(",")):
            item = raw_item.strip()
            if item:
                budgets[device_index] = item
        return budgets

    @staticmethod
    def _extract_head_state(net):
        """Select the non-Qwen ``target_state_encoder.*`` entries of ``net``.

        The ``target_state_encoder.`` prefix is removed from the returned keys.
        """
        prefix = 'target_state_encoder.'
        encoder_state = {}
        for key, value in net.items():
            if not key.startswith(prefix):
                continue
            stripped = key[len(prefix):]
            # The served model has LoRA merged into Qwen, so only restore the non-Qwen trainable heads.
            if stripped.startswith('qwen.'):
                continue
            encoder_state[stripped] = value
        return encoder_state

    def _restore_target_embedding(self, row):
        """Copy the checkpointed target-token embedding row into Qwen.

        Does nothing when ``row`` is ``None``.

        Raises:
            RuntimeError: If the stored token id or token string differs from
                the one the loaded model uses.
        """
        if row is None:
            return
        token_id = int(row['token_id'])
        current_id = int(self.encoder.target_token_id)
        if token_id != current_id or row.get('token') != TARGET_STATE_TOKEN:
            raise RuntimeError(f'Target token mismatch checkpoint={row.get("token")}/{token_id}, model={TARGET_STATE_TOKEN}/{current_id}')
        with torch.no_grad():
            emb = self.encoder.qwen.get_input_embeddings().weight
            # Match the embedding matrix's device and dtype before the in-place copy.
            emb[current_id].copy_(row['weight'].to(device=emb.device, dtype=emb.dtype))

    @staticmethod
    def _decode_image(data):
        """Decode a base64 string (optionally a ``data:image`` URI) to an RGB PIL image."""
        if data.startswith('data:image'):
            # Keep only the payload that follows the first comma of the data URI.
            data = data.split(',', 1)[1]
        image = Image.open(BytesIO(base64.b64decode(data))).convert('RGB')
        return image

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert an H x W x 3 image to a normalised (1, 3, H, W) float tensor.

        Pixel values are scaled to [0, 1] and then standardised per channel
        with the mean/std constants below (the values commonly used for
        ImageNet normalisation).
        """
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
        mean = tensor.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = tensor.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (tensor - mean) / std

    @staticmethod
    def _parse_answer(output):
        """Interpret generated text as a yes/no decision.

        An exact ``<answer>yes</answer>`` / ``<answer>no</answer>`` tag
        (case-insensitive) wins.  Otherwise the text is split into alphabetic
        words, so punctuation or markup attached to a word ("Yes.", "no,",
        "<answer>Yes.</answer>") does not hide it, and the result is True only
        when "yes" occurs and "no" does not.  Empty or missing output is False.
        """
        if not output:
            return False
        text = output.lower()
        if '<answer>yes</answer>' in text:
            return True
        if '<answer>no</answer>' in text:
            return False
        # Replace every non-letter with a space so words are isolated from
        # adjacent punctuation and tag brackets before the membership test.
        words = set(''.join(ch if ch.isalpha() else ' ' for ch in text).split())
        return 'yes' in words and 'no' not in words

    @torch.no_grad()
    def update(self, req: "UpdateRequest"):
        """Run the encoder on one template/candidate pair.

        Args:
            req: Request carrying the two base64 images, the two 4-element
                boxes and the optional caption / object name.

        Returns:
            A dict with ``'decision'`` (bool), ``'output'`` (first element of
            the encoder's output list, or ``None`` when that list is empty or
            missing) and ``'z_target'`` (first row of the returned target
            state as a nested list of floats).
        """
        template = self._image_to_tensor(self._decode_image(req.template_image), self.device)
        candidate = self._image_to_tensor(self._decode_image(req.candidate_image), self.device)
        template_box = torch.tensor(req.template_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        candidate_box = torch.tensor(req.candidate_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        # The encoder returns a 6-tuple; only the first, second and last
        # elements are used here.
        z_target, decisions, _, _, _, outputs = self.encoder(
            [req.caption or 'the target object'],
            template,
            candidate,
            template_box,
            candidate_box,
            self.device,
            object_names=[req.object_name],
            return_update_decision=True,
        )
        output = outputs[0] if outputs else None
        # Prefer the encoder's own decision; fall back to parsing the text
        # only when the encoder did not provide one.
        if decisions is not None:
            decision = bool(decisions[0].item())
        else:
            decision = self._parse_answer(output)
        return {
            'decision': decision,
            'output': output,
            'z_target': z_target[0].detach().float().cpu().tolist(),
        }
def build_app(args):
    """Create the FastAPI application backed by one ``QwenStateService``.

    The service is constructed once, here, and shared by every request
    handled by the returned app.

    Args:
        args: Object exposing ``config``, ``checkpoint``, ``model_path`` and
            ``device`` attributes (as produced by ``parse_args``).

    Returns:
        A ``FastAPI`` instance exposing ``GET /health`` and ``POST /update``.
    """
    service = QwenStateService(args.config, args.checkpoint, args.model_path, args.device)
    app = FastAPI()

    # Static liveness response; it does not call into the service.
    # (Handlers are documented with comments, not docstrings, because
    # FastAPI publishes an endpoint's docstring as its API description.)
    @app.get('/health')
    def health():
        return {'status': 'ok'}

    # Delegates the validated request body straight to the shared service.
    @app.post('/update')
    def update(req: UpdateRequest):
        return service.update(req)

    return app
def parse_args():
    """Parse the server's command-line options and return the namespace.

    Options: ``--config``, ``--checkpoint``, ``--model-path``, ``--device``,
    ``--host`` and ``--port`` (the only one parsed as ``int``).
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        ('--config', {'default': str(PROJECT_DIR / 'experiments/atctrack/atctrack_qwen_state.yaml')}),
        ('--checkpoint', {'default': str(PROJECT_DIR / 'checkpoints/ATCTrack_ep0015.pth.tar')}),
        ('--model-path', {'default': '/media/data/WWZ/SX/Qwen/Qwen3.5-9B-track'}),
        ('--device', {'default': 'cuda:0'}),
        ('--host', {'default': '0.0.0.0'}),
        ('--port', {'type': int, 'default': 8001}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == '__main__':
    # Parse the command line before importing the server runtime, so invalid
    # arguments are reported first.
    args = parse_args()
    import uvicorn

    # Building the app constructs the service used by every request.
    application = build_app(args)
    uvicorn.run(application, host=args.host, port=args.port, log_level='info')