File size: 13,965 Bytes
b3f019f
 
ede4b32
b3f019f
 
 
 
ede4b32
 
b3f019f
 
 
 
 
 
ede4b32
b3f019f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ede4b32
 
 
 
 
 
 
b3f019f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ede4b32
b3f019f
 
 
 
 
 
 
 
 
 
 
 
 
 
ede4b32
 
 
 
 
b3f019f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ede4b32
b3f019f
 
 
 
 
 
 
 
 
ede4b32
b3f019f
 
 
 
ede4b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f019f
ede4b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f019f
 
 
 
ede4b32
 
 
 
 
 
b3f019f
 
ede4b32
 
 
 
b3f019f
 
ede4b32
 
 
 
 
 
b3f019f
 
ede4b32
 
b3f019f
 
 
 
 
 
 
 
 
 
 
 
ede4b32
 
 
b3f019f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#!/usr/bin/env python
import argparse
import asyncio
import base64
import copy
import os
import sys
import time
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from PIL import Image

# PROJECT_DIR is the parent of the directory that contains this file
# (``parents[0]`` is the file's own directory, ``parents[1]`` one level up).
# It is pushed to the front of sys.path, presumably so that the ``lib.*``
# project imports that follow resolve when this file is run as a script.
PROJECT_DIR = Path(__file__).resolve().parents[1]
if str(PROJECT_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_DIR))

from lib.config.atctrack.config import cfg, update_config_from_file
import lib.models.atctrack.atctrack as atctrack_module
from lib.models.atctrack.atctrack import QwenTargetStateEncoder, TARGET_STATE_TOKEN


# Request body accepted by the POST /update endpoint.
class UpdateRequest(BaseModel):
    # Base64-encoded images. A leading "data:image...," prefix is stripped by
    # QwenStateService._decode_image before decoding; images are converted to RGB.
    template_image: str
    candidate_image: str
    # Boxes are reshaped with .view(1, 4) in QwenStateService._prepare_item, so
    # each must hold exactly four numbers. The coordinate convention is defined
    # by the encoder helpers and is not visible in this file.
    template_bbox: List[float]
    candidate_bbox: List[float]
    # Optional text; an empty/None caption falls back to 'the target object'.
    caption: Optional[str] = None
    # Optional object name, forwarded unchanged to the encoder's prompt builder.
    object_name: Optional[str] = None


@dataclass
class PendingItem:
    """One queued request waiting to be processed by the batch worker."""

    # The validated request payload.
    req: UpdateRequest
    # Future awaited by the submitting coroutine; the worker resolves it with
    # the per-request result or with the exception raised by the batch.
    future: asyncio.Future
    # Wall-clock timestamp (time.time()) taken when the item was enqueued.
    enqueue_time: float


class QwenStateService:
    """Inference-only service built around ``QwenTargetStateEncoder``.

    The constructor builds the encoder with a custom Qwen loader and restores
    the non-Qwen encoder weights from a tracker checkpoint. ``batch_update``
    then scores a yes/no decision for every (template, candidate) request and
    returns the projected target-state vector of the chosen answer.
    """

    def __init__(self, config_path, checkpoint_path, model_path, device):
        """Build the encoder and restore its non-Qwen weights.

        Args:
            config_path: Config file passed to ``update_config_from_file``.
            checkpoint_path: Checkpoint read with ``torch.load``; its ``'net'``
                entry supplies the ``target_state_encoder.*`` weights.
            model_path: Stored in the config as ``MODEL.TARGET_STATE.MODEL_PATH``.
            device: Anything accepted by ``torch.device``; the projector/FiLM
                modules and the per-request tensors are placed on it.

        Raises:
            RuntimeError: If non-Qwen encoder keys are missing, unexpected keys
                are present, or the checkpoint's target token does not match
                the encoder's.
        """
        update_config_from_file(config_path)
        # Work on a private copy so the overrides below do not touch the global cfg.
        local_cfg = copy.deepcopy(cfg)
        local_cfg.MODEL.TARGET_STATE.ENABLE = True
        local_cfg.MODEL.TARGET_STATE.MODEL_PATH = model_path
        local_cfg.MODEL.TARGET_STATE.USE_LORA = False
        local_cfg.MODEL.TARGET_STATE.FREEZE_QWEN = True
        local_cfg.MODEL.TARGET_STATE.TEACHER_ENABLE = False
        self.device = torch.device(device)
        # An empty QWEN_STATE_DEVICE_MAP disables passing device_map altogether.
        self.device_map = os.environ.get("QWEN_STATE_DEVICE_MAP", "auto")

        def _load_sharded_qwen(model_path):
            """Load the Qwen model in bfloat16, optionally with a device map.

            Prefers ``AutoModelForImageTextToText`` and falls back to
            ``AutoModelForCausalLM`` when the former cannot be imported.
            ``QWEN_STATE_MAX_MEMORY`` is a comma-separated list whose entries
            are keyed by their position (presumably the GPU index - confirm).
            """
            try:
                from transformers import AutoModelForImageTextToText
                model_cls = AutoModelForImageTextToText
            except ImportError:
                from transformers import AutoModelForCausalLM
                model_cls = AutoModelForCausalLM
            kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
            if self.device_map:
                kwargs["device_map"] = self.device_map
            max_memory = os.environ.get("QWEN_STATE_MAX_MEMORY")
            if max_memory:
                kwargs["max_memory"] = {idx: item for idx, item in enumerate(max_memory.split(","))}
            return model_cls.from_pretrained(model_path, **kwargs)

        # Temporarily swap the module-level loader so the encoder constructor
        # uses the loader above; the original is restored even on failure.
        original_loader = atctrack_module._load_qwen_target_state_model
        atctrack_module._load_qwen_target_state_model = _load_sharded_qwen
        try:
            self.encoder = QwenTargetStateEncoder(local_cfg, tracker_dim=local_cfg.MODEL.HIDDEN_DIM)
        finally:
            atctrack_module._load_qwen_target_state_model = original_loader
        # Only the non-Qwen heads are moved explicitly; the Qwen model itself
        # is placed by from_pretrained (device_map) in the loader above.
        self.encoder.projector.to(self.device)
        self.encoder.film_ln.to(self.device)
        self.encoder.film.to(self.device)
        self.encoder.film_gate.data = self.encoder.film_gate.data.to(self.device)
        self.encoder.eval()
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects - only point this at trusted checkpoint files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        net = checkpoint.get('net', {})
        # Keep only 'target_state_encoder.*' entries, with the prefix removed,
        # and skip anything that belongs to the Qwen sub-module.
        encoder_state = {}
        prefix = 'target_state_encoder.'
        for key, value in net.items():
            if not key.startswith(prefix):
                continue
            stripped = key[len(prefix):]
            if stripped.startswith('qwen.'):
                continue
            encoder_state[stripped] = value
        # Compatibility: a checkpoint film_gate with a different shape is
        # broadcast (expand) to the current parameter shape.
        if 'film_gate' in encoder_state and encoder_state['film_gate'].shape != self.encoder.film_gate.shape:
            old_shape = encoder_state['film_gate'].shape
            encoder_state['film_gate'] = encoder_state['film_gate'].expand(self.encoder.film_gate.shape).clone()
            print(f'  [compat] expanded film_gate from {list(old_shape)} to {list(self.encoder.film_gate.shape)}')

        # strict=False because the Qwen weights are intentionally absent from
        # encoder_state; every other missing or unexpected key is an error.
        missing, unexpected = self.encoder.load_state_dict(encoder_state, strict=False)
        missing_head = [k for k in missing if not k.startswith('qwen.')]
        if missing_head:
            raise RuntimeError(f'Missing non-Qwen encoder keys: {missing_head[:20]}')
        if unexpected:
            raise RuntimeError(f'Unexpected encoder keys: {unexpected[:20]}')
        # Optional single embedding row for the target-state token; it is only
        # copied in when both the token id and token string match the encoder.
        row = checkpoint.get('target_state_embedding')
        if row is not None:
            token_id = int(row['token_id'])
            current_id = int(self.encoder.target_token_id)
            if token_id != current_id or row.get('token') != TARGET_STATE_TOKEN:
                raise RuntimeError(f'Target token mismatch checkpoint={row.get("token")}/{token_id}, model={TARGET_STATE_TOKEN}/{current_id}')
            with torch.no_grad():
                emb = self.encoder.qwen.get_input_embeddings().weight
                emb[current_id].copy_(row['weight'].to(device=emb.device, dtype=emb.dtype))
        print(f'Qwen state service loaded on {self.device}. Missing qwen/base keys ignored: {len(missing)}; non-Qwen restored: {len(encoder_state)}')

    @staticmethod
    def _decode_image(data):
        """Decode a base64 string (optionally a ``data:image`` URL) to an RGB PIL image."""
        if data.startswith('data:image'):
            # Drop everything up to and including the first comma of the data URL.
            data = data.split(',', 1)[1]
        return Image.open(BytesIO(base64.b64decode(data))).convert('RGB')

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert an image to a normalized float tensor of shape (1, 3, H, W).

        Pixel values are scaled to [0, 1], the channel axis is moved first, and
        the result is normalized per channel with the mean/std constants below
        (the values commonly used for ImageNet-pretrained models).
        """
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
        mean = tensor.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = tensor.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (tensor - mean) / std

    def _prepare_item(self, req: UpdateRequest):
        """Turn one request into (template, candidate, template_box, candidate_box) tensors.

        Images become (1, 3, H, W) tensors and boxes (1, 4) float32 tensors,
        all on ``self.device``. ``.view(1, 4)`` raises if a box does not hold
        exactly four values.
        """
        template = self._image_to_tensor(self._decode_image(req.template_image), self.device)
        candidate = self._image_to_tensor(self._decode_image(req.candidate_image), self.device)
        template_box = torch.tensor(req.template_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        candidate_box = torch.tensor(req.candidate_bbox, device=self.device, dtype=torch.float32).view(1, 4)
        return template, candidate, template_box, candidate_box

    @torch.no_grad()
    def update(self, req: UpdateRequest):
        """Process a single request by running it as a batch of one."""
        return self.batch_update([req])[0]

    @torch.no_grad()
    def batch_update(self, reqs: List[UpdateRequest]):
        """Score every request and return one result dict per request, in order.

        Each request is expanded into two rows, one with the "True" answer text
        appended and one with the "False" answer text. Both rows go through the
        model in a single forward pass; the answer whose tokens receive the
        higher summed log-probability wins (ties resolve to True).

        Returns:
            A list of dicts with keys ``'decision'``, ``'output'``,
            ``'yes_score'``, ``'no_score'`` and ``'z_target'`` (a list of
            floats). An empty input yields an empty list.
        """
        if not reqs:
            return []
        templates, candidates, template_boxes, candidate_boxes = [], [], [], []
        for req in reqs:
            template, candidate, template_box, candidate_box = self._prepare_item(req)
            templates.append(template)
            candidates.append(candidate)
            template_boxes.append(template_box)
            candidate_boxes.append(candidate_box)
        # torch.cat along dim 0 requires all images of a kind to share the same
        # spatial size; differently sized images in one batch raise here.
        template_tensor = torch.cat(templates, dim=0)
        candidate_tensor = torch.cat(candidates, dim=0)
        template_box_tensor = torch.cat(template_boxes, dim=0)
        candidate_box_tensor = torch.cat(candidate_boxes, dim=0)

        captions = [req.caption or 'the target object' for req in reqs]
        object_names = [req.object_name for req in reqs]
        prompts = [self.encoder._build_prompt(caption, object_name) for caption, object_name in zip(captions, object_names)]
        # The encoder helper turns the tensors and boxes back into PIL images;
        # what it does with the boxes is defined in the encoder, not here.
        template_pils = self.encoder._tensor_batch_to_pil(template_tensor, template_box_tensor)
        candidate_pils = self.encoder._tensor_batch_to_pil(candidate_tensor, candidate_box_tensor)

        # One single-turn user message per request: template image, candidate
        # image, then the text prompt.
        messages = []
        for prompt, template_pil, candidate_pil in zip(prompts, template_pils, candidate_pils):
            messages.append([
                {
                    'role': 'user',
                    'content': [
                        {'type': 'image', 'image': template_pil},
                        {'type': 'image', 'image': candidate_pil},
                        {'type': 'text', 'text': prompt},
                    ],
                }
            ])
        texts = [self.encoder._apply_qwen_chat_template(message) for message in messages]
        images = [[template_pil, candidate_pil] for template_pil, candidate_pil in zip(template_pils, candidate_pils)]

        # Duplicate every request: row 2*i carries the True answer text and
        # row 2*i + 1 the False answer text.
        expanded_texts, expanded_images, expanded_decisions = [], [], []
        for text, image_pair in zip(texts, images):
            for decision in (True, False):
                expanded_texts.append(text + self.encoder._target_state_answer_text(decision))
                expanded_images.append(image_pair)
                expanded_decisions.append(decision)

        tokenized = self.encoder.processor(
            text=expanded_texts,
            images=expanded_images,
            padding=True,
            return_tensors='pt',
        ).to(self.device)
        # Only the position outputs are used; the first return value (labels,
        # judging by the helper's name) is discarded.
        _, answer_token_positions, target_token_positions = self.encoder._build_forward_labels(
            tokenized.input_ids,
            expanded_decisions,
            [True] * len(expanded_decisions),
        )
        outputs = self.encoder._qwen_forward_with_target_embedding(tokenized, labels=None)
        log_probs = F.log_softmax(outputs.logits.float(), dim=-1)

        # Per-row float32 accumulator for the summed answer-token log-probabilities.
        score_by_row = tokenized.input_ids.new_zeros((len(expanded_decisions),), dtype=torch.float32).to(self.device)
        for batch_idx, pos in answer_token_positions:
            if pos > 0:
                # The token at `pos` is scored with the distribution produced at
                # `pos - 1` (presumably the usual next-token shift of a causal LM).
                target_id = int(tokenized.input_ids[batch_idx, pos].item())
                score_by_row[batch_idx] += log_probs[batch_idx, pos - 1, target_id]

        h_all = self.encoder._target_hidden_from_forward(
            outputs.hidden_states[-1], tokenized.input_ids, target_token_positions
        )
        results = []
        for i in range(len(reqs)):
            yes_row = 2 * i
            no_row = yes_row + 1
            # Ties go to True because of the >= comparison.
            decision = bool(score_by_row[yes_row].item() >= score_by_row[no_row].item())
            chosen_row = yes_row if decision else no_row
            # Project only the hidden state of the winning row.
            z_target = self.encoder.projector(h_all[chosen_row:chosen_row + 1])
            output = self.encoder._target_state_answer_text(decision)
            results.append({
                'decision': decision,
                'output': output,
                'yes_score': float(score_by_row[yes_row].item()),
                'no_score': float(score_by_row[no_row].item()),
                'z_target': z_target[0].detach().float().cpu().tolist(),
            })
        return results


class BatchDispatcher:
    """Group concurrent update requests into small batches for the service.

    Callers ``await submit(req)``; a single background worker task pulls items
    from a bounded queue, waits up to ``wait_ms`` for more items (at most
    ``batch_size`` per batch), runs ``service.batch_update`` in a worker thread
    and resolves each caller's future with its own result.
    """

    def __init__(self, service: QwenStateService, batch_size: int, wait_ms: int, max_queue_size: int):
        """Store the batching parameters and create the bounded queue.

        Args:
            service: Object whose ``batch_update(list_of_requests)`` returns one
                result per request, in order.
            batch_size: Maximum requests per batch; values below 1 become 1.
            wait_ms: How long to wait for further requests after the first one
                arrives; negative values become 0.
            max_queue_size: ``maxsize`` of the queue; ``submit`` waits while
                the queue is full.
        """
        self.service = service
        self.batch_size = max(1, int(batch_size))
        self.wait_s = max(0.0, float(wait_ms) / 1000.0)
        self.queue = asyncio.Queue(maxsize=max_queue_size)
        self.worker_task = None

    def start(self):
        """Start the worker task; must be called from a running event loop.

        Calling it again while the worker is alive is a no-op. A worker that
        has already finished (for example after being cancelled) is replaced.
        """
        if self.worker_task is None or self.worker_task.done():
            self.worker_task = asyncio.create_task(self._worker())

    async def submit(self, req: UpdateRequest):
        """Enqueue ``req`` and wait for its result.

        Returns the result produced for this request. Raises whatever
        exception the batch containing it raised.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.queue.put(PendingItem(req=req, future=future, enqueue_time=time.time()))
        return await future

    async def _worker(self):
        """Run forever: collect a batch, execute it, resolve the futures."""
        while True:
            first = await self.queue.get()
            batch = [first]
            # A monotonic clock keeps the batching window immune to wall-clock
            # adjustments (NTP steps, manual changes).
            deadline = time.monotonic() + self.wait_s
            while len(batch) < self.batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0.0:
                    break
                try:
                    item = await asyncio.wait_for(self.queue.get(), timeout=remaining)
                except asyncio.TimeoutError:
                    break
                batch.append(item)

            # Skip requests whose caller already went away (future cancelled).
            active = [item for item in batch if not item.future.done()]
            if not active:
                continue
            try:
                # The model call blocks, so it runs in a thread to keep the
                # event loop responsive.
                results = await asyncio.to_thread(self.service.batch_update, [item.req for item in active])
                # zip() would silently drop unmatched items and leave their
                # futures pending forever, so a mismatch fails the whole batch.
                if len(results) != len(active):
                    raise RuntimeError(
                        f'batch_update returned {len(results)} results for {len(active)} requests'
                    )
            except Exception as exc:
                for item in active:
                    if not item.future.done():
                        item.future.set_exception(exc)
            else:
                for item, result in zip(active, results):
                    # A caller may have been cancelled while the batch ran.
                    if not item.future.done():
                        item.future.set_result(result)

def build_app(args):
    """Create the FastAPI application that serves the Qwen state model.

    ``args`` supplies ``config``, ``checkpoint``, ``model_path``, ``device``,
    ``batch_size``, ``batch_wait_ms`` and ``max_queue_size``. The model is
    loaded right here, before the app is returned.
    """
    state_service = QwenStateService(args.config, args.checkpoint, args.model_path, args.device)
    batcher = BatchDispatcher(
        state_service,
        batch_size=args.batch_size,
        wait_ms=args.batch_wait_ms,
        max_queue_size=args.max_queue_size,
    )
    api = FastAPI()

    @api.on_event('startup')
    async def startup():
        # The worker task can only be created once the event loop is running.
        batcher.start()

    @api.get('/health')
    def health():
        # Report the effective batching settings and the current backlog.
        wait_ms = int(batcher.wait_s * 1000)
        backlog = batcher.queue.qsize()
        return {
            'status': 'ok',
            'batch_size': batcher.batch_size,
            'batch_wait_ms': wait_ms,
            'queue_size': backlog,
        }

    @api.post('/update')
    async def update(req: UpdateRequest):
        # Hand the request to the batcher and wait for its individual result.
        result = await batcher.submit(req)
        return result

    return api


def parse_args(argv=None):
    """Parse the command-line options of the Qwen state service.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``.

    The batching defaults come from the ``QWEN_STATE_BATCH_SIZE``,
    ``QWEN_STATE_BATCH_WAIT_MS`` and ``QWEN_STATE_MAX_QUEUE_SIZE`` environment
    variables. A value that is not an integer ends the program through
    ``parser.error`` with a message naming the variable, instead of an
    unexplained ``ValueError`` traceback.
    """
    parser = argparse.ArgumentParser()

    def env_int(name, fallback):
        # Read an integer default from the environment, failing with a clear message.
        raw = os.environ.get(name)
        if raw is None:
            return fallback
        try:
            return int(raw)
        except ValueError:
            parser.error(f'environment variable {name} must be an integer, got {raw!r}')

    parser.add_argument('--config', default=str(PROJECT_DIR / 'experiments/atctrack/atctrack_qwen_state.yaml'))
    parser.add_argument('--checkpoint', default=str(PROJECT_DIR / 'checkpoints/ATCTrack_ep0015.pth.tar'))
    parser.add_argument('--model-path', default='/media/data/WWZ/SX/Qwen/Qwen3.5-9B-track')
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--host', default='0.0.0.0')
    parser.add_argument('--port', type=int, default=8001)
    parser.add_argument('--batch-size', type=int, default=env_int('QWEN_STATE_BATCH_SIZE', 4))
    parser.add_argument('--batch-wait-ms', type=int, default=env_int('QWEN_STATE_BATCH_WAIT_MS', 30))
    parser.add_argument('--max-queue-size', type=int, default=env_int('QWEN_STATE_MAX_QUEUE_SIZE', 64))
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Parse options first so that `--help` works even without uvicorn installed.
    args = parse_args()
    import uvicorn

    # Building the app loads the model; only then is the server started.
    app = build_app(args)
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')