| |
| import argparse |
| import base64 |
| import copy |
| import os |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from PIL import Image |
|
|
# Repository root: the parent of the directory containing this file. It is put
# on sys.path so the ``lib.*`` imports below resolve when this file is run
# directly as a script.
PROJECT_DIR = Path(__file__).resolve().parents[1]
if sys.path.count(str(PROJECT_DIR)) == 0:
    # Prepend (not append) so this checkout is searched before other sys.path entries.
    sys.path.insert(0, str(PROJECT_DIR))
|
|
| from lib.config.atctrack.config import cfg, update_config_from_file |
| import lib.models.atctrack.atctrack as atctrack_module |
| from lib.models.atctrack.atctrack import QwenTargetStateEncoder, TARGET_STATE_TOKEN |
|
|
|
|
class UpdateRequest(BaseModel):
    # JSON body accepted by ``POST /update``.
    # Both images are base64 strings; a ``data:image...,<payload>`` prefix is
    # stripped by ``QwenStateService._decode_image`` before decoding.
    template_image: str
    candidate_image: str
    # Each box is reshaped to a (1, 4) float tensor, so it must hold exactly
    # four numbers; the coordinate convention is not defined here — TODO confirm
    # against the tracker that produces these boxes.
    template_bbox: List[float]
    candidate_bbox: List[float]
    # Optional text description of the target; ``QwenStateService.update``
    # substitutes 'the target object' when it is missing or empty.
    caption: Optional[str] = None
    # Optional object name, forwarded to the encoder as ``object_names``.
    object_name: Optional[str] = None
|
|
|
|
class QwenStateService:
    """Run ``QwenTargetStateEncoder`` inference for the HTTP update endpoint.

    The Qwen backbone is loaded from ``model_path`` (optionally sharded across
    devices through ``device_map``). The encoder's non-Qwen weights and the
    target-state token embedding are restored from an ATCTrack checkpoint.
    """

    def __init__(self, config_path, checkpoint_path, model_path, device):
        """Build the encoder and restore its trained non-Qwen weights.

        Args:
            config_path: YAML file applied via ``update_config_from_file``.
            checkpoint_path: training checkpoint. Its ``'net'`` entry holds
                ``target_state_encoder.*`` weights, and it may also carry a
                ``'target_state_embedding'`` entry.
            model_path: Qwen model location passed to ``from_pretrained``.
            device: torch device for the encoder's projector/FiLM modules and
                for request tensors.

        Environment variables:
            QWEN_STATE_DEVICE_MAP: ``device_map`` for the Qwen model (default
                ``"auto"``; an empty value disables it).
            QWEN_STATE_MAX_MEMORY: comma-separated memory limits assigned to
                device indices 0, 1, ... in order (e.g. ``"20GiB,20GiB"``).
                Only used when a device map is set.

        Raises:
            RuntimeError: if the checkpoint lacks non-Qwen encoder keys, has
                unexpected encoder keys, or its target token does not match
                the model's.
        """
        local_cfg = self._build_config(config_path, model_path)
        self.device = torch.device(device)
        self.device_map = os.environ.get("QWEN_STATE_DEVICE_MAP", "auto")
        self.encoder = self._build_encoder(local_cfg)

        # NOTE: weights_only=False unpickles arbitrary Python objects; only point
        # --checkpoint at trusted files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        missing, encoder_state = self._load_encoder_weights(checkpoint)
        self._restore_target_token_embedding(checkpoint)
        print(f'Qwen state service loaded on {self.device}. Missing qwen/base keys ignored: {len(missing)}; non-Qwen restored: {len(encoder_state)}')

    @staticmethod
    def _build_config(config_path, model_path):
        """Return a private copy of the tracker config with inference overrides."""
        # update_config_from_file presumably updates the shared ``cfg`` in place;
        # deep-copying keeps the overrides below out of that shared object.
        update_config_from_file(config_path)
        local_cfg = copy.deepcopy(cfg)
        # Target-state branch on, Qwen loaded from model_path, no LoRA,
        # Qwen frozen, teacher disabled.
        local_cfg.MODEL.TARGET_STATE.ENABLE = True
        local_cfg.MODEL.TARGET_STATE.MODEL_PATH = model_path
        local_cfg.MODEL.TARGET_STATE.USE_LORA = False
        local_cfg.MODEL.TARGET_STATE.FREEZE_QWEN = True
        local_cfg.MODEL.TARGET_STATE.TEACHER_ENABLE = False
        return local_cfg

    def _load_sharded_qwen(self, model_path):
        """Load the Qwen model in bfloat16, sharded according to ``self.device_map``.

        Substituted for ``atctrack._load_qwen_target_state_model`` while the
        encoder is being constructed.
        """
        try:
            from transformers import AutoModelForImageTextToText as model_cls
        except ImportError:
            # Fall back when this transformers version does not provide the
            # image-text-to-text auto class.
            from transformers import AutoModelForCausalLM as model_cls
        kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
        if self.device_map:
            kwargs["device_map"] = self.device_map
            max_memory = os.environ.get("QWEN_STATE_MAX_MEMORY")
            if max_memory:
                # Strip so "20GiB, 20GiB" works; empty entries are kept so that
                # positions still map onto device indices.
                kwargs["max_memory"] = {
                    idx: limit.strip() for idx, limit in enumerate(max_memory.split(","))
                }
        return model_cls.from_pretrained(model_path, **kwargs)

    def _build_encoder(self, local_cfg):
        """Construct the encoder with the sharded Qwen loader and move its head to ``self.device``."""
        # Temporarily replace the module-level loader, which the encoder
        # constructor is expected to look up at call time. The original is
        # restored even if construction fails.
        original_loader = atctrack_module._load_qwen_target_state_model
        atctrack_module._load_qwen_target_state_model = self._load_sharded_qwen
        try:
            encoder = QwenTargetStateEncoder(local_cfg, tracker_dim=local_cfg.MODEL.HIDDEN_DIM)
        finally:
            atctrack_module._load_qwen_target_state_model = original_loader
        # Qwen placement is left to device_map; only the encoder's own modules
        # are moved here.
        encoder.projector.to(self.device)
        encoder.film.to(self.device)
        # Reassigning .data keeps the same Parameter object registered on the module.
        encoder.film_gate.data = encoder.film_gate.data.to(self.device)
        encoder.eval()
        return encoder

    def _load_encoder_weights(self, checkpoint):
        """Load the checkpoint's non-Qwen ``target_state_encoder.*`` weights.

        Returns:
            ``(missing, encoder_state)``. ``missing`` is the list of keys that
            ``load_state_dict`` reported missing; after this returns, all of
            them start with ``qwen.``. ``encoder_state`` is the state dict that
            was loaded.

        Raises:
            RuntimeError: on missing non-Qwen keys or unexpected keys.
        """
        prefix = 'target_state_encoder.'
        encoder_state = {}
        for key, value in checkpoint.get('net', {}).items():
            if not key.startswith(prefix):
                continue
            stripped = key[len(prefix):]
            # Qwen weights come from model_path; only the head is taken from here.
            if stripped.startswith('qwen.'):
                continue
            encoder_state[stripped] = value
        missing, unexpected = self.encoder.load_state_dict(encoder_state, strict=False)
        missing_head = [k for k in missing if not k.startswith('qwen.')]
        if missing_head:
            raise RuntimeError(f'Missing non-Qwen encoder keys: {missing_head[:20]}')
        if unexpected:
            raise RuntimeError(f'Unexpected encoder keys: {unexpected[:20]}')
        return missing, encoder_state

    def _restore_target_token_embedding(self, checkpoint):
        """Copy the saved target-state token embedding row into the Qwen input embeddings.

        Raises:
            RuntimeError: if the saved token id/string differs from the model's.
        """
        row = checkpoint.get('target_state_embedding')
        if row is None:
            print('Warning: checkpoint has no target_state_embedding; '
                  'the target-state token embedding was not restored.')
            return
        token_id = int(row['token_id'])
        current_id = int(self.encoder.target_token_id)
        if token_id != current_id or row.get('token') != TARGET_STATE_TOKEN:
            raise RuntimeError(f'Target token mismatch checkpoint={row.get("token")}/{token_id}, model={TARGET_STATE_TOKEN}/{current_id}')
        # In-place write into a model parameter; keep autograd out of it.
        with torch.no_grad():
            emb = self.encoder.qwen.get_input_embeddings().weight
            emb[current_id].copy_(row['weight'].to(device=emb.device, dtype=emb.dtype))

    @staticmethod
    def _decode_image(data):
        """Decode a base64 image (optionally a ``data:image...,<payload>`` URI) to RGB.

        Raises:
            ValueError: for a data URI without a ``,`` separator or malformed
                base64 (``binascii.Error``). Pillow raises its own ``OSError``
                subclass for undecodable image bytes.
        """
        if data.startswith('data:image'):
            _, sep, payload = data.partition(',')
            if not sep:
                raise ValueError('Malformed data URI: missing "," before the base64 payload')
            data = payload
        # convert() loads the pixels into a new image, so the source can be closed.
        with Image.open(BytesIO(base64.b64decode(data))) as image:
            return image.convert('RGB')

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert an RGB image to a normalized ``(1, 3, H, W)`` float32 tensor on *device*.

        Pixels are scaled to [0, 1], then normalized with the ImageNet mean/std.
        """
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
        mean = tensor.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = tensor.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (tensor - mean) / std

    @staticmethod
    def _bbox_to_tensor(bbox, device):
        """Return *bbox* as a ``(1, 4)`` float32 tensor on *device*.

        Raises:
            ValueError: if *bbox* does not contain exactly four numbers.
        """
        if len(bbox) != 4:
            raise ValueError(f'Expected a bbox of 4 numbers, got {len(bbox)}')
        return torch.tensor(bbox, device=device, dtype=torch.float32).view(1, 4)

    @staticmethod
    def _parse_answer(output):
        """Interpret free-text model output as a yes/no decision.

        An ``<answer>yes</answer>`` tag wins, then ``<answer>no</answer>``.
        Tags are matched case-insensitively and may have whitespace inside.
        Otherwise the result is yes only if the word "yes" appears and the
        word "no" does not. Words are matched ignoring punctuation, so
        "Yes, ..." counts as yes. Empty or ``None`` output means no.
        """
        import re

        if not output:
            return False
        text = output.lower()
        tagged = set(re.findall(r'<answer>\s*(yes|no)\s*</answer>', text))
        if 'yes' in tagged:
            return True
        if 'no' in tagged:
            return False
        words = set(re.findall(r'[a-z]+', text))
        return 'yes' in words and 'no' not in words

    @torch.no_grad()
    def update(self, req: 'UpdateRequest'):
        """Run one target-state update for *req*.

        Returns:
            dict with ``'decision'`` (bool), ``'output'`` (first text output or
            ``None``) and ``'z_target'`` (``z_target[0]`` as a Python list).

        Raises:
            ValueError: for a bbox without exactly four numbers or a malformed
                image payload (see ``_decode_image``).
        """
        # Validate the cheap inputs before decoding images.
        template_box = self._bbox_to_tensor(req.template_bbox, self.device)
        candidate_box = self._bbox_to_tensor(req.candidate_bbox, self.device)
        template = self._image_to_tensor(self._decode_image(req.template_image), self.device)
        candidate = self._image_to_tensor(self._decode_image(req.candidate_image), self.device)
        z_target, decisions, _, _, _, outputs = self.encoder(
            [req.caption or 'the target object'],
            template,
            candidate,
            template_box,
            candidate_box,
            self.device,
            object_names=[req.object_name],
            return_update_decision=True,
        )
        output = outputs[0] if outputs else None
        if decisions is not None:
            decision = bool(decisions[0].item())
        else:
            # No decision tensor from the encoder: fall back to parsing the text.
            decision = self._parse_answer(output)
        return {
            'decision': decision,
            'output': output,
            'z_target': z_target[0].detach().float().cpu().tolist(),
        }
|
|
|
|
def build_app(args):
    """Create the FastAPI app exposing ``GET /health`` and ``POST /update``.

    The ``QwenStateService`` is constructed here, so model and checkpoint
    loading happen before the app is returned.

    Args:
        args: parsed CLI namespace providing ``config``, ``checkpoint``,
            ``model_path`` and ``device``.
    """
    from fastapi import HTTPException

    service = QwenStateService(args.config, args.checkpoint, args.model_path, args.device)
    app = FastAPI()

    @app.get('/health')
    def health():
        return {'status': 'ok'}

    @app.post('/update')
    def update(req: UpdateRequest):
        try:
            return service.update(req)
        except (ValueError, OSError) as exc:
            # Malformed request payloads report 400 instead of a generic 500:
            # bad base64 raises binascii.Error (a ValueError) and undecodable
            # image bytes raise a Pillow OSError subclass.
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    return app
|
|
|
|
def parse_args():
    """Parse the service's command-line flags.

    ``--config`` and ``--checkpoint`` default to files inside the project
    directory. ``--model-path`` is the Qwen model location and ``--device``
    is a torch device string. ``--host``/``--port`` are where the HTTP
    server listens.
    """
    flag_specs = [
        ('--config', dict(default=str(PROJECT_DIR / 'experiments/atctrack/atctrack_qwen_state.yaml'))),
        ('--checkpoint', dict(default=str(PROJECT_DIR / 'checkpoints/ATCTrack_ep0015.pth.tar'))),
        ('--model-path', dict(default='/media/data/WWZ/SX/Qwen/Qwen3.5-9B-track')),
        ('--device', dict(default='cuda:0')),
        ('--host', dict(default='0.0.0.0')),
        ('--port', dict(type=int, default=8001)),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
if __name__ == '__main__':
    args = parse_args()
    # uvicorn is imported only here, after argument parsing, so importing this
    # module or running --help does not require it.
    import uvicorn

    application = build_app(args)
    uvicorn.run(application, host=args.host, port=args.port, log_level='info')
|
|