#!/usr/bin/env python
"""HTTP service exposing ATCTrack's Qwen target-state encoder.

``POST /update`` takes a template image, a candidate image, one box for each,
and an optional caption / object name. It runs ``QwenTargetStateEncoder`` and
returns the update decision, the raw model output, and the target-state vector.
"""
import argparse
import base64
import copy
import os
import re
import sys
from io import BytesIO
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image

PROJECT_DIR = Path(__file__).resolve().parents[1]
if str(PROJECT_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_DIR))

# Project imports must come after PROJECT_DIR is put on sys.path.
from lib.config.atctrack.config import cfg, update_config_from_file  # noqa: E402
import lib.models.atctrack.atctrack as atctrack_module  # noqa: E402
from lib.models.atctrack.atctrack import QwenTargetStateEncoder, TARGET_STATE_TOKEN  # noqa: E402

# Whole-word "yes"/"no"; substring tests would match "eyes", "not", "know", ...
_ANSWER_RE = re.compile(r"\b(yes|no)\b")


def _parse_max_memory(spec):
    """Parse a comma-separated per-GPU memory spec into a ``max_memory`` dict.

    ``"20GiB, 22GiB"`` becomes ``{0: "20GiB", 1: "22GiB"}``. Entries are stripped
    of whitespace, and empty entries are dropped while the remaining entries
    keep their positional GPU index.
    """
    entries = (item.strip() for item in spec.split(','))
    return {idx: item for idx, item in enumerate(entries) if item}


class InvalidRequestError(ValueError):
    """Raised when an /update payload cannot be decoded or has malformed fields."""


class UpdateRequest(BaseModel):
    """Payload for ``POST /update``."""

    # Base64-encoded images; a ``data:image/...;base64,`` prefix is accepted.
    template_image: str
    candidate_image: str
    # Exactly four numbers per box, in whatever coordinate convention
    # QwenTargetStateEncoder expects (NOTE(review): confirm xywh vs xyxy).
    template_bbox: List[float]
    candidate_bbox: List[float]
    # caption falls back to 'the target object' when omitted.
    caption: Optional[str] = None
    object_name: Optional[str] = None


class QwenStateService:
    """Wraps a QwenTargetStateEncoder with a merged (LoRA-free) Qwen backbone.

    The Qwen backbone is loaded via transformers with an optional device map.
    The small trainable heads (projector / FiLM) live on ``device`` and are
    restored from an ATCTrack training checkpoint.
    """

    def __init__(self, config_path, checkpoint_path, model_path, device):
        """Load config, build the encoder, and restore checkpoint weights.

        Args:
            config_path: ATCTrack YAML config, applied via update_config_from_file.
            checkpoint_path: training checkpoint containing a ``net`` state dict
                and optionally a ``target_state_embedding`` row.
            model_path: path of the (LoRA-merged) Qwen model to load.
            device: torch device string for the non-Qwen heads and inputs.

        Raises:
            RuntimeError: if the checkpoint's head keys or target token do not
                match the constructed encoder.
        """
        update_config_from_file(config_path)
        local_cfg = copy.deepcopy(cfg)
        target_cfg = local_cfg.MODEL.TARGET_STATE
        target_cfg.ENABLE = True
        target_cfg.MODEL_PATH = model_path
        # The served model already has LoRA merged in, so no adapters and no training.
        target_cfg.USE_LORA = False
        target_cfg.FREEZE_QWEN = True
        target_cfg.TEACHER_ENABLE = False

        self.device = torch.device(device)
        # An empty QWEN_STATE_DEVICE_MAP disables device_map entirely.
        self.device_map = os.environ.get("QWEN_STATE_DEVICE_MAP", "auto")

        self.encoder = self._build_encoder(local_cfg)
        missing_count, restored_count = self._load_checkpoint(checkpoint_path)
        print(
            f'Qwen state service loaded on {self.device}. '
            f'Missing qwen/base keys ignored: {missing_count}; non-Qwen restored: {restored_count}'
        )

    def _build_encoder(self, local_cfg):
        """Construct the encoder using a sharded Qwen loader and move its heads to self.device."""

        def _load_sharded_qwen(model_path):
            # Prefer the image-text-to-text auto class; older transformers lack it.
            try:
                from transformers import AutoModelForImageTextToText
                model_cls = AutoModelForImageTextToText
            except ImportError:
                from transformers import AutoModelForCausalLM
                model_cls = AutoModelForCausalLM
            kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
            if self.device_map:
                kwargs["device_map"] = self.device_map
            max_memory = _parse_max_memory(os.environ.get("QWEN_STATE_MAX_MEMORY", ""))
            if max_memory:
                kwargs["max_memory"] = max_memory
            return model_cls.from_pretrained(model_path, **kwargs)

        # Temporarily swap the module-level loader that QwenTargetStateEncoder
        # calls, and always restore it, even if construction fails.
        original_loader = atctrack_module._load_qwen_target_state_model
        atctrack_module._load_qwen_target_state_model = _load_sharded_qwen
        try:
            encoder = QwenTargetStateEncoder(local_cfg, tracker_dim=local_cfg.MODEL.HIDDEN_DIM)
        finally:
            atctrack_module._load_qwen_target_state_model = original_loader

        # Only the small heads are placed explicitly; Qwen placement follows device_map.
        encoder.projector.to(self.device)
        encoder.film.to(self.device)
        encoder.film_gate.data = encoder.film_gate.data.to(self.device)
        encoder.eval()
        return encoder

    def _load_checkpoint(self, checkpoint_path):
        """Restore non-Qwen head weights and the target-token embedding row.

        Returns:
            (number of missing keys ignored, number of head tensors restored).
        """
        # weights_only=False unpickles arbitrary Python objects: only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        prefix = 'target_state_encoder.'
        encoder_state = {}
        for key, value in checkpoint.get('net', {}).items():
            if not key.startswith(prefix):
                continue
            stripped = key[len(prefix):]
            # The served model has LoRA merged into Qwen, so only restore the
            # non-Qwen trainable heads.
            if stripped.startswith('qwen.'):
                continue
            encoder_state[stripped] = value

        missing, unexpected = self.encoder.load_state_dict(encoder_state, strict=False)
        missing_head = [k for k in missing if not k.startswith('qwen.')]
        if missing_head:
            raise RuntimeError(f'Missing non-Qwen encoder keys: {missing_head[:20]}')
        if unexpected:
            raise RuntimeError(f'Unexpected encoder keys: {unexpected[:20]}')

        self._restore_target_token(checkpoint.get('target_state_embedding'))
        return len(missing), len(encoder_state)

    def _restore_target_token(self, row):
        """Copy the trained embedding of the target-state token into Qwen's input embeddings.

        ``row`` is the checkpoint's ``target_state_embedding`` entry (or None to skip).
        """
        if row is None:
            return
        token_id = int(row['token_id'])
        current_id = int(self.encoder.target_token_id)
        if token_id != current_id or row.get('token') != TARGET_STATE_TOKEN:
            raise RuntimeError(
                f'Target token mismatch checkpoint={row.get("token")}/{token_id}, '
                f'model={TARGET_STATE_TOKEN}/{current_id}'
            )
        with torch.no_grad():
            emb = self.encoder.qwen.get_input_embeddings().weight
            emb[current_id].copy_(row['weight'].to(device=emb.device, dtype=emb.dtype))

    @staticmethod
    def _decode_image(data):
        """Decode a base64 string (optionally a data URI) into an RGB PIL image.

        Raises:
            InvalidRequestError: if the payload is not valid base64 image data.
        """
        if data.startswith('data:image'):
            # partition() instead of split()[1]: a URI without a comma must not
            # raise IndexError; it simply decodes to nothing and fails below.
            data = data.partition(',')[2]
        try:
            return Image.open(BytesIO(base64.b64decode(data))).convert('RGB')
        except (ValueError, OSError) as err:
            # binascii.Error is a ValueError; PIL decode failures are OSErrors.
            raise InvalidRequestError(f'Could not decode image: {err}') from err

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert a PIL RGB image to a (1, 3, H, W) ImageNet-normalised float tensor."""
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
        mean = tensor.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = tensor.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (tensor - mean) / std

    @staticmethod
    def _bbox_tensor(values, device, field):
        """Validate a 4-number box and return it as a (1, 4) float32 tensor."""
        if len(values) != 4:
            raise InvalidRequestError(f'{field} must contain exactly 4 numbers, got {len(values)}')
        return torch.tensor(values, device=device, dtype=torch.float32).view(1, 4)

    @staticmethod
    def _parse_answer(output):
        """Interpret free-text model output as a yes/no decision.

        The first whole-word "yes" or "no" (case-insensitive) decides. Empty
        output, or output containing neither word, yields False.
        """
        if not output:
            return False
        match = _ANSWER_RE.search(output.lower())
        return match is not None and match.group(1) == 'yes'

    @torch.no_grad()
    def update(self, req: UpdateRequest):
        """Run the encoder on one template/candidate pair.

        Returns:
            dict with ``decision`` (bool), ``output`` (first generated text or
            None), and ``z_target`` (target-state vector as a list of floats).

        Raises:
            InvalidRequestError: for undecodable images or malformed boxes.
        """
        template = self._image_to_tensor(self._decode_image(req.template_image), self.device)
        candidate = self._image_to_tensor(self._decode_image(req.candidate_image), self.device)
        template_box = self._bbox_tensor(req.template_bbox, self.device, 'template_bbox')
        candidate_box = self._bbox_tensor(req.candidate_bbox, self.device, 'candidate_bbox')
        z_target, decisions, _, _, _, outputs = self.encoder(
            [req.caption or 'the target object'],
            template,
            candidate,
            template_box,
            candidate_box,
            self.device,
            object_names=[req.object_name],
            return_update_decision=True,
        )
        output = outputs[0] if outputs else None
        # Prefer the encoder's own decision; fall back to parsing the text.
        if decisions is not None:
            decision = bool(decisions[0].item())
        else:
            decision = self._parse_answer(output)
        return {
            'decision': decision,
            'output': output,
            'z_target': z_target[0].detach().float().cpu().tolist(),
        }


def build_app(args):
    """Create the FastAPI app with ``/health`` and ``/update`` routes."""
    service = QwenStateService(args.config, args.checkpoint, args.model_path, args.device)
    app = FastAPI()

    @app.get('/health')
    def health():
        return {'status': 'ok'}

    @app.post('/update')
    def update(req: UpdateRequest):
        try:
            return service.update(req)
        except InvalidRequestError as err:
            # Client-side payload problems are a 400, not an internal error.
            raise HTTPException(status_code=400, detail=str(err)) from err

    return app


def parse_args():
    """Parse command-line options for the service."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=str(PROJECT_DIR / 'experiments/atctrack/atctrack_qwen_state.yaml'))
    parser.add_argument('--checkpoint', default=str(PROJECT_DIR / 'checkpoints/ATCTrack_ep0015.pth.tar'))
    parser.add_argument('--model-path', default='/media/data/WWZ/SX/Qwen/Qwen3.5-9B-track')
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--host', default='0.0.0.0')
    parser.add_argument('--port', type=int, default=8001)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    import uvicorn

    uvicorn.run(build_app(args), host=args.host, port=args.port, log_level='info')