#!/usr/bin/env python
import argparse
import asyncio
import base64
import copy
import os
import sys
import time
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from PIL import Image
# Directory two levels above this file; it is put at the front of sys.path
# so that the ``lib.*`` imports below resolve when the script is run directly.
PROJECT_DIR = Path(__file__).resolve().parents[1]
_PROJECT_DIR_STR = str(PROJECT_DIR)
if _PROJECT_DIR_STR not in sys.path:
    sys.path.insert(0, _PROJECT_DIR_STR)
from lib.config.atctrack.config import cfg, update_config_from_file
import lib.models.atctrack.atctrack as atctrack_module
from lib.models.atctrack.atctrack import QwenTargetStateEncoder, TARGET_STATE_TOKEN
class UpdateRequest(BaseModel):
    # Request body of the /update endpoint.

    # Base64-encoded images; a leading "data:image...," URI header is accepted
    # and stripped by the service before decoding.
    template_image: str
    candidate_image: str

    # One box per image; the service reshapes each list to a (1, 4) tensor.
    template_bbox: List[float]
    candidate_bbox: List[float]

    # Optional text used when building the prompt. A missing or empty caption
    # is replaced by 'the target object' by the service.
    caption: Optional[str] = None
    object_name: Optional[str] = None
@dataclass
class PendingItem:
    """One queued request together with the future its caller is awaiting."""

    # Payload that will be handed to the service.
    req: UpdateRequest
    # Resolved by the batch worker with the result (or the exception).
    future: asyncio.Future
    # ``time.time()`` value recorded when the item was put on the queue.
    enqueue_time: float
class QwenStateService:
    """Score candidate crops against a template with a Qwen target-state encoder.

    For every request the service builds a two-image chat prompt, appends the
    answer text for ``decision=True`` and for ``decision=False``, runs all
    rows through Qwen in one forward call and keeps the decision whose answer
    tokens received the higher summed log-probability. The target hidden
    state of the winning row is passed through the encoder's projector and
    returned as ``z_target``.
    """

    def __init__(self, config_path, checkpoint_path, model_path, device):
        """Build the encoder and restore its non-Qwen weights.

        Args:
            config_path: Experiment config passed to ``update_config_from_file``.
            checkpoint_path: Tracker checkpoint. Keys under ``net`` that start
                with ``target_state_encoder.`` (except the ``qwen.`` sub-tree)
                are loaded into the encoder; ``target_state_embedding``, when
                present, overwrites the target-token embedding row.
            model_path: Qwen model location, written to
                ``MODEL.TARGET_STATE.MODEL_PATH``.
            device: Device for the encoder head modules and request tensors.

        Environment:
            QWEN_STATE_DEVICE_MAP: ``device_map`` passed to ``from_pretrained``
                (default ``"auto"``; an empty value omits the argument).
            QWEN_STATE_MAX_MEMORY: Comma-separated per-device limits, mapped
                to device indices 0, 1, ... in order.

        Raises:
            RuntimeError: If non-Qwen encoder weights are missing from the
                checkpoint, unexpected keys are present, or the checkpoint's
                target token does not match the model's.
        """
        update_config_from_file(config_path)
        # Override the settings on a private copy of the config object.
        local_cfg = copy.deepcopy(cfg)
        local_cfg.MODEL.TARGET_STATE.ENABLE = True
        local_cfg.MODEL.TARGET_STATE.MODEL_PATH = model_path
        local_cfg.MODEL.TARGET_STATE.USE_LORA = False
        local_cfg.MODEL.TARGET_STATE.FREEZE_QWEN = True
        local_cfg.MODEL.TARGET_STATE.TEACHER_ENABLE = False
        self.device = torch.device(device)
        self.device_map = os.environ.get("QWEN_STATE_DEVICE_MAP", "auto")

        def _load_sharded_qwen(model_path):
            # Prefer the image-text-to-text auto class; fall back to the
            # causal-LM auto class when transformers does not provide it.
            try:
                from transformers import AutoModelForImageTextToText
                model_cls = AutoModelForImageTextToText
            except ImportError:
                from transformers import AutoModelForCausalLM
                model_cls = AutoModelForCausalLM
            kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
            # NOTE(review): with an empty QWEN_STATE_DEVICE_MAP no device_map
            # is passed and this loader never moves the model to self.device;
            # confirm the encoder takes care of placement in that case.
            if self.device_map:
                kwargs["device_map"] = self.device_map
            max_memory = os.environ.get("QWEN_STATE_MAX_MEMORY")
            if max_memory:
                limits = self._parse_max_memory(max_memory)
                if limits:
                    kwargs["max_memory"] = limits
            return model_cls.from_pretrained(model_path, **kwargs)

        # Temporarily swap the module-level loader so the encoder constructor
        # uses the loader above; the original is always restored.
        original_loader = atctrack_module._load_qwen_target_state_model
        atctrack_module._load_qwen_target_state_model = _load_sharded_qwen
        try:
            self.encoder = QwenTargetStateEncoder(local_cfg, tracker_dim=local_cfg.MODEL.HIDDEN_DIM)
        finally:
            atctrack_module._load_qwen_target_state_model = original_loader

        # Only the head modules are moved explicitly; Qwen placement is left
        # to the loader.
        self.encoder.projector.to(self.device)
        self.encoder.film_ln.to(self.device)
        self.encoder.film.to(self.device)
        self.encoder.film_gate.data = self.encoder.film_gate.data.to(self.device)
        self.encoder.eval()

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point this service at checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        encoder_state = self._extract_encoder_state(checkpoint.get('net', {}))

        if 'film_gate' in encoder_state and encoder_state['film_gate'].shape != self.encoder.film_gate.shape:
            # Checkpoints that stored a differently shaped gate are broadcast
            # to the current parameter shape.
            old_shape = encoder_state['film_gate'].shape
            encoder_state['film_gate'] = encoder_state['film_gate'].expand(self.encoder.film_gate.shape).clone()
            print(f' [compat] expanded film_gate from {list(old_shape)} to {list(self.encoder.film_gate.shape)}')

        missing, unexpected = self.encoder.load_state_dict(encoder_state, strict=False)
        # Qwen weights are deliberately absent from encoder_state, so only
        # missing keys outside the ``qwen.`` sub-tree are an error.
        missing_head = [k for k in missing if not k.startswith('qwen.')]
        if missing_head:
            raise RuntimeError(f'Missing non-Qwen encoder keys: {missing_head[:20]}')
        if unexpected:
            raise RuntimeError(f'Unexpected encoder keys: {unexpected[:20]}')

        row = checkpoint.get('target_state_embedding')
        if row is not None:
            token_id = int(row['token_id'])
            current_id = int(self.encoder.target_token_id)
            if token_id != current_id or row.get('token') != TARGET_STATE_TOKEN:
                raise RuntimeError(f'Target token mismatch checkpoint={row.get("token")}/{token_id}, model={TARGET_STATE_TOKEN}/{current_id}')
            with torch.no_grad():
                emb = self.encoder.qwen.get_input_embeddings().weight
                emb[current_id].copy_(row['weight'].to(device=emb.device, dtype=emb.dtype))
        print(f'Qwen state service loaded on {self.device}. Missing qwen/base keys ignored: {len(missing)}; non-Qwen restored: {len(encoder_state)}')

    @staticmethod
    def _parse_max_memory(spec):
        """Turn ``"20GiB, 22GiB"`` into ``{0: "20GiB", 1: "22GiB"}``.

        Entries are stripped of surrounding whitespace. Blank entries are
        skipped while the other entries keep their positional device index.
        """
        entries = [entry.strip() for entry in spec.split(",")]
        return {idx: entry for idx, entry in enumerate(entries) if entry}

    @staticmethod
    def _extract_encoder_state(net, prefix='target_state_encoder.'):
        """Return the encoder head weights from a checkpoint ``net`` mapping.

        Keeps keys that start with ``prefix``, strips the prefix, and drops
        everything in the ``qwen.`` sub-tree.
        """
        state = {}
        for key, value in net.items():
            if not key.startswith(prefix):
                continue
            stripped = key[len(prefix):]
            if stripped.startswith('qwen.'):
                continue
            state[stripped] = value
        return state

    @staticmethod
    def _decode_image(data):
        """Decode a base64 string (optionally a ``data:image`` URI) to RGB.

        Raises:
            ValueError: If a data URI has no ``,`` separating header and
                payload (``binascii.Error`` from invalid base64 is also a
                ``ValueError`` subclass).
        """
        if data.startswith('data:image'):
            _header, separator, payload = data.partition(',')
            if not separator:
                raise ValueError('Malformed data URI: missing "," before the base64 payload')
            data = payload
        return Image.open(BytesIO(base64.b64decode(data))).convert('RGB')

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert an H x W x 3 image to a normalised ``(1, 3, H, W)`` tensor.

        Pixel values are scaled to [0, 1] and then normalised per channel with
        the mean/std constants below.
        """
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
        mean = tensor.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = tensor.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (tensor - mean) / std

    def _prepare_item(self, req: "UpdateRequest"):
        """Turn one request into image tensors and ``(1, 4)`` box tensors.

        Returns:
            ``(template, candidate, template_box, candidate_box)`` on
            ``self.device``.

        Raises:
            ValueError: If either bbox does not contain exactly four values.
        """
        # Validate the cheap fields first so a bad request fails with a clear
        # message before any image decoding work is done.
        for field_name, bbox in (
            ('template_bbox', req.template_bbox),
            ('candidate_bbox', req.candidate_bbox),
        ):
            if len(bbox) != 4:
                raise ValueError(f'{field_name} must contain exactly 4 values, got {len(bbox)}')
        template = self._image_to_tensor(self._decode_image(req.template_image), self.device)
        candidate = self._image_to_tensor(self._decode_image(req.candidate_image), self.device)
        template_box = torch.tensor(req.template_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        candidate_box = torch.tensor(req.candidate_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        return template, candidate, template_box, candidate_box

    def _boxes_to_pil(self, image_tensors, box_tensors):
        """Run the encoder's tensor-to-PIL conversion for a list of requests.

        When every image tensor has the same shape the inputs are concatenated
        and converted in one call. Otherwise ``torch.cat`` would fail, so each
        request is converted as its own batch of one, which is the same call
        a single-request ``update`` makes.
        """
        if len({tuple(t.shape) for t in image_tensors}) == 1:
            return self.encoder._tensor_batch_to_pil(
                torch.cat(image_tensors, dim=0),
                torch.cat(box_tensors, dim=0),
            )
        pil_images = []
        for image_tensor, box_tensor in zip(image_tensors, box_tensors):
            pil_images.extend(self.encoder._tensor_batch_to_pil(image_tensor, box_tensor))
        return pil_images

    @torch.no_grad()
    def update(self, req: "UpdateRequest"):
        """Score a single request; equivalent to a batch of one."""
        return self.batch_update([req])[0]

    @torch.no_grad()
    def batch_update(self, reqs: List["UpdateRequest"]):
        """Score a list of requests in one Qwen forward call.

        Returns:
            One dict per request, in request order, with keys ``decision``
            (bool), ``output`` (answer text for the decision), ``yes_score``
            and ``no_score`` (summed answer-token log-probabilities) and
            ``z_target`` (projected target hidden state as a list of floats).
            An empty input returns an empty list.
        """
        if not reqs:
            return []

        prepared = [self._prepare_item(req) for req in reqs]
        templates, candidates, template_boxes, candidate_boxes = map(list, zip(*prepared))

        prompts = [
            self.encoder._build_prompt(req.caption or 'the target object', req.object_name)
            for req in reqs
        ]
        template_pils = self._boxes_to_pil(templates, template_boxes)
        candidate_pils = self._boxes_to_pil(candidates, candidate_boxes)

        # Every request contributes two rows: row 2*i carries the answer text
        # for decision=True, row 2*i+1 the one for decision=False.
        expanded_texts, expanded_images, expanded_decisions = [], [], []
        for prompt, template_pil, candidate_pil in zip(prompts, template_pils, candidate_pils):
            message = [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'image', 'image': template_pil},
                        {'type': 'image', 'image': candidate_pil},
                        {'type': 'text', 'text': prompt},
                    ],
                }
            ]
            text = self.encoder._apply_qwen_chat_template(message)
            image_pair = [template_pil, candidate_pil]
            for decision in (True, False):
                expanded_texts.append(text + self.encoder._target_state_answer_text(decision))
                expanded_images.append(image_pair)
                expanded_decisions.append(decision)

        tokenized = self.encoder.processor(
            text=expanded_texts,
            images=expanded_images,
            padding=True,
            return_tensors='pt',
        ).to(self.device)
        _, answer_token_positions, target_token_positions = self.encoder._build_forward_labels(
            tokenized.input_ids,
            expanded_decisions,
            [True] * len(expanded_decisions),
        )
        outputs = self.encoder._qwen_forward_with_target_embedding(tokenized, labels=None)

        # Only the logits that predict an answer token are needed, so the
        # softmax is taken per position instead of over the whole
        # (rows, seq, vocab) tensor. Scores are accumulated on the logits'
        # device so a sharded model does not cause a cross-device add.
        logits = outputs.logits
        score_by_row = torch.zeros(len(expanded_decisions), dtype=torch.float32, device=logits.device)
        for batch_idx, pos in answer_token_positions:
            batch_idx, pos = int(batch_idx), int(pos)
            if pos <= 0:
                # There is no preceding position to predict this token from.
                continue
            target_id = int(tokenized.input_ids[batch_idx, pos])
            step_log_probs = F.log_softmax(logits[batch_idx, pos - 1].float(), dim=-1)
            score_by_row[batch_idx] += step_log_probs[target_id]
        # One device-to-host transfer for all scores.
        scores = score_by_row.tolist()

        h_all = self.encoder._target_hidden_from_forward(
            outputs.hidden_states[-1], tokenized.input_ids, target_token_positions
        )

        results = []
        for i in range(len(reqs)):
            yes_row = 2 * i
            no_row = yes_row + 1
            yes_score, no_score = scores[yes_row], scores[no_row]
            # Ties resolve to True.
            decision = bool(yes_score >= no_score)
            chosen_row = yes_row if decision else no_row
            # The projector lives on self.device; ``.to`` is a no-op when the
            # hidden state is already there.
            hidden = h_all[chosen_row:chosen_row + 1].to(self.device)
            z_target = self.encoder.projector(hidden)
            results.append({
                'decision': decision,
                'output': self.encoder._target_state_answer_text(decision),
                'yes_score': float(yes_score),
                'no_score': float(no_score),
                'z_target': z_target[0].detach().float().cpu().tolist(),
            })
        return results
class BatchDispatcher:
    """Micro-batch concurrent requests in front of a ``QwenStateService``.

    Requests are queued by ``submit`` and consumed by a single worker task.
    The worker takes the first queued item, waits up to ``wait_s`` seconds
    for more (at most ``batch_size`` in total), runs the batch in a thread
    and resolves each caller's future with its result or with the exception
    raised by the batch.
    """

    def __init__(self, service: QwenStateService, batch_size: int, wait_ms: int, max_queue_size: int):
        """Store the settings and create the (bounded) request queue.

        Args:
            service: Object whose ``batch_update(list_of_requests)`` returns
                one result per request, in order.
            batch_size: Maximum requests per batch; values below 1 become 1.
            wait_ms: Maximum time in milliseconds to wait for a batch to fill
                after its first item arrived; negative values become 0.
            max_queue_size: ``maxsize`` of the underlying ``asyncio.Queue``.
        """
        self.service = service
        self.batch_size = max(1, int(batch_size))
        self.wait_s = max(0.0, float(wait_ms) / 1000.0)
        self.queue = asyncio.Queue(maxsize=max_queue_size)
        self.worker_task = None

    def start(self):
        """Start the worker task once; later calls are no-ops.

        Must be called while an event loop is running because it uses
        ``asyncio.create_task``.
        """
        if self.worker_task is None:
            self.worker_task = asyncio.create_task(self._worker())

    async def submit(self, req: UpdateRequest):
        """Queue ``req`` and wait for its result.

        Waits for free queue space when the queue is full. Re-raises the
        exception of the batch the request was part of, if that batch failed.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.queue.put(PendingItem(req=req, future=future, enqueue_time=time.time()))
        return await future

    async def _collect_batch(self):
        """Wait for one item, then gather more until full or timed out."""
        batch = [await self.queue.get()]
        # A monotonic clock keeps the wait window correct even when the
        # system wall clock is adjusted while the worker is waiting.
        deadline = time.monotonic() + self.wait_s
        while len(batch) < self.batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0.0:
                break
            try:
                batch.append(await asyncio.wait_for(self.queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch

    async def _worker(self):
        """Run forever: collect a batch, execute it, resolve the futures."""
        while True:
            batch = await self._collect_batch()
            # Skip requests whose caller has already gone away.
            active = [item for item in batch if not item.future.done()]
            if not active:
                continue
            try:
                # batch_update is blocking, so it runs in a worker thread to
                # keep the event loop responsive.
                results = await asyncio.to_thread(self.service.batch_update, [item.req for item in active])
                if len(results) != len(active):
                    # Without this check zip() would silently drop the
                    # surplus callers and leave them awaiting forever.
                    raise RuntimeError(
                        f'batch_update returned {len(results)} results for {len(active)} requests'
                    )
            except Exception as exc:
                for item in active:
                    if not item.future.done():
                        item.future.set_exception(exc)
                continue
            for item, result in zip(active, results):
                if not item.future.done():
                    item.future.set_result(result)
def build_app(args):
    """Create the FastAPI app that exposes the batched Qwen state service.

    Args:
        args: Namespace providing ``config``, ``checkpoint``, ``model_path``,
            ``device``, ``batch_size``, ``batch_wait_ms`` and
            ``max_queue_size``.

    Returns:
        A ``FastAPI`` instance with ``GET /health`` and ``POST /update``.
    """
    dispatcher = BatchDispatcher(
        QwenStateService(args.config, args.checkpoint, args.model_path, args.device),
        batch_size=args.batch_size,
        wait_ms=args.batch_wait_ms,
        max_queue_size=args.max_queue_size,
    )
    app = FastAPI()

    async def startup():
        # The worker is created here because asyncio.create_task needs the
        # server's running event loop.
        dispatcher.start()

    app.on_event('startup')(startup)

    def health():
        report = {'status': 'ok'}
        report['batch_size'] = dispatcher.batch_size
        report['batch_wait_ms'] = int(dispatcher.wait_s * 1000)
        report['queue_size'] = dispatcher.queue.qsize()
        return report

    app.get('/health')(health)

    async def update(req: UpdateRequest):
        result = await dispatcher.submit(req)
        return result

    app.post('/update')(update)

    return app
def parse_args(argv=None):
    """Parse the command-line options of the Qwen state service.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()``
            callers are unaffected; passing a list allows programmatic use.

    Returns:
        ``argparse.Namespace`` with ``config``, ``checkpoint``, ``model_path``,
        ``device``, ``host``, ``port``, ``batch_size``, ``batch_wait_ms`` and
        ``max_queue_size``.

    The three batching defaults are read from the ``QWEN_STATE_BATCH_SIZE``,
    ``QWEN_STATE_BATCH_WAIT_MS`` and ``QWEN_STATE_MAX_QUEUE_SIZE`` environment
    variables when set.
    """
    parser = argparse.ArgumentParser(description='Serve Qwen target-state decisions over HTTP.')
    parser.add_argument('--config', default=str(PROJECT_DIR / 'experiments/atctrack/atctrack_qwen_state.yaml'),
                        help='experiment config file')
    parser.add_argument('--checkpoint', default=str(PROJECT_DIR / 'checkpoints/ATCTrack_ep0015.pth.tar'),
                        help='tracker checkpoint holding the encoder head weights')
    parser.add_argument('--model-path', default='/media/data/WWZ/SX/Qwen/Qwen3.5-9B-track',
                        help='location of the Qwen model')
    parser.add_argument('--device', default='cuda:0',
                        help='torch device for the encoder head and request tensors')
    parser.add_argument('--host', default='0.0.0.0', help='address the server binds to')
    parser.add_argument('--port', type=int, default=8001, help='port the server listens on')
    parser.add_argument('--batch-size', type=int, default=int(os.environ.get('QWEN_STATE_BATCH_SIZE', '4')),
                        help='maximum number of requests per batch')
    parser.add_argument('--batch-wait-ms', type=int, default=int(os.environ.get('QWEN_STATE_BATCH_WAIT_MS', '30')),
                        help='maximum time in milliseconds to wait for a batch to fill')
    parser.add_argument('--max-queue-size', type=int, default=int(os.environ.get('QWEN_STATE_MAX_QUEUE_SIZE', '64')),
                        help='maximum number of queued requests')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the app, serve it.
    cli_args = parse_args()
    # uvicorn is imported only here, so importing this module does not need it.
    import uvicorn

    server_app = build_app(cli_args)
    uvicorn.run(server_app, host=cli_args.host, port=cli_args.port, log_level='info')
|