| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import traceback |
| from typing import Any, Dict |
|
|
| import numpy as np |
| import torch |
| from fastapi import FastAPI |
| from fastapi.responses import JSONResponse |
| from PIL import Image |
| import uvicorn |
| import json_numpy |
| import cv2 |
|
|
| from transformers import PreTrainedModel |
| from .modeling_florence2 import Florence2ForConditionalGeneration |
| from .transformer import SoftPromptedTransformer |
| from .action_hub import build_action_space |
| from .configuration_xvla import XVLAConfig |
|
|
|
|
class XVLA(PreTrainedModel):
    """
    XVLA: HuggingFace-compatible Vision-Language-Action policy.

    Components:
      • Florence2 encoder-only backbone (vision-language)
      • SoftPromptedTransformer (temporal/action head)
      • Action space (pre/post-processing + loss)

    Also exposes a minimal FastAPI inference service (``_build_app`` / ``run``).
    """

    config_class = XVLAConfig
    base_model_prefix = "xvla"
    supports_gradient_checkpointing = True

    def __init__(self, config: XVLAConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

        # --- Action space ----------------------------------------------------
        self.num_actions: int = config.num_actions
        # Stored from the config; not read by any method visible in this class.
        self.use_proprio: bool = config.use_proprio
        self.action_mode: str = config.action_mode.lower()

        self.action_space = build_action_space(self.action_mode)
        dim_action = self.action_space.dim_action
        # Fall back to the action dimensionality when the action space does not
        # declare a separate proprioception size.
        dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)

        # --- Vision-language backbone (encoder only) ---------------------------
        self.vlm = Florence2ForConditionalGeneration(config.florence_config)
        if hasattr(self.vlm, "language_model"):
            lm = self.vlm.language_model
            # Only the text encoder is used (see forward_vlm), so the decoder and
            # LM head are dropped to save parameters/memory.
            if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
                del lm.model.decoder
            if hasattr(lm, "lm_head"):
                del lm.lm_head

        projection_dim = getattr(self.vlm.config, "projection_dim", None)
        if projection_dim is None:
            raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")

        # --- Action head -------------------------------------------------------
        self.transformer = SoftPromptedTransformer(
            hidden_size=config.hidden_size,
            multi_modal_input_size=projection_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            num_domains=config.num_domains,
            dim_action=dim_action,
            # NOTE(review): keyword is spelled `dim_propio`; presumably this
            # matches SoftPromptedTransformer's signature, so check before renaming.
            dim_propio=dim_proprio,
            len_soft_prompts=config.len_soft_prompts,
            dim_time=config.dim_time,
            max_len_seq=config.max_len_seq,
            use_hetero_proj=config.use_hetero_proj,
        )

        # Inference app, built lazily by _build_app().
        self.app: FastAPI | None = None

    def forward_vlm(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        image_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode text + multi-view images via the Florence2 encoder.

        Args:
            input_ids: instruction token ids; dim 0 is the batch (B).
            pixel_values: images shaped [B, V, ...] (batch, views, image dims).
            image_mask: per-view validity with B*V elements (presumably [B, V]).
                Invalid views are not encoded; their features are zeros.

        Returns:
            { "vlm_features": [B, T_enc, D], "aux_visual_inputs": [B, (V-1)*N, D] }
            View 0 is fused with the text inside the encoder; views 1..V-1 are
            returned as raw visual tokens.

        Raises:
            ValueError: if no view in the whole batch is valid.
        """
        B, V = pixel_values.shape[:2]
        # reshape (not view) so a non-contiguous mask is accepted as well.
        flat_mask = image_mask.reshape(-1).to(torch.bool)
        flat_images = pixel_values.flatten(0, 1)

        num_valid = int(flat_mask.sum().item())
        if num_valid == 0:
            raise ValueError("At least one image view must be valid per batch.")

        # Encode only valid views, then scatter them back into a dense
        # [B*V, N, D] tensor whose masked-out slots stay zero.
        valid_feats = self.vlm._encode_image(flat_images[flat_mask])
        N, D = valid_feats.shape[1:]
        image_features = valid_feats.new_zeros((B * V, N, D))
        image_features[flat_mask] = valid_feats
        image_features = image_features.view(B, V, N, D)

        inputs_embeds = self.vlm.get_input_embeddings()(input_ids)

        # Fuse the first (primary) view with the text embeddings.
        merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
            image_features[:, 0],
            inputs_embeds,
        )

        enc_out = self.vlm.language_model.model.encoder(
            attention_mask=attention_mask,
            inputs_embeds=merged_embeds,
        )[0]

        # The remaining views bypass the encoder, flattened into one token sequence.
        aux_visual_inputs = image_features[:, 1:].reshape(B, -1, D)
        return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}

    def forward(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        action: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Training pass; returns the result of ``action_space.compute_loss``.

        1) Encode multimodal inputs.
        2) Diffusion-style noisy mixture of actions: x_t = t*noise + (1-t)*gt.
        3) Space-specific preprocessing, prediction, and supervised loss
           (the prediction is compared against the un-noised ``action``).
        """
        enc = self.forward_vlm(input_ids, image_input, image_mask)

        B = input_ids.shape[0]
        device = input_ids.device
        # Stratified timesteps: one shared random offset plus evenly spaced i/B
        # shifts, wrapped into [0, 1 - 1e-5).
        t = (torch.rand(1, device=device) + torch.arange(B, device=device) / B) % (1 - 1e-5)

        t_b = t.view(-1, 1, 1)
        action_noisy = torch.randn_like(action) * t_b + action * (1 - t_b)
        proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)

        pred_action = self.transformer(
            domain_id=domain_id,
            action_with_noise=action_noisy_m,
            t=t,
            proprio=proprio_m,
            **enc,
        )
        return self.action_space.compute_loss(pred_action, action)

    @torch.no_grad()
    def generate_actions(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        steps: int = 10,
    ) -> torch.Tensor:
        """
        Iterative denoising with a linear schedule t = 1, (steps-1)/steps, ..., 1/steps.

        At each step the head's output becomes the new action estimate, which is
        re-mixed with a fixed noise sample for the next step. Applies
        action_space.postprocess at the end (e.g., sigmoid on gripper).

        Runs in eval mode, then restores the module's previous train/eval mode,
        so calling this from a training loop does not leave dropout disabled.

        Args:
            steps: number of denoising iterations (clamped to >= 1).

        Returns:
            Post-processed actions; the raw estimate is [B, num_actions, dim_action].
        """
        was_training = self.training
        self.eval()
        try:
            enc = self.forward_vlm(input_ids, image_input, image_mask)

            B = input_ids.shape[0]
            D = self.action_space.dim_action
            device, dtype = proprio.device, proprio.dtype

            x1 = torch.randn(B, self.num_actions, D, device=device, dtype=dtype)
            action = torch.zeros_like(x1)

            steps = max(1, int(steps))
            for i in range(steps, 0, -1):
                t = torch.full((B,), i / steps, device=device, dtype=dtype)
                t_b = t.view(-1, 1, 1)
                x_t = x1 * t_b + action * (1 - t_b)
                proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
                action = self.transformer(
                    domain_id=domain_id,
                    action_with_noise=x_t_m,
                    proprio=proprio_m,
                    t=t,
                    **enc,
                )
            return self.action_space.postprocess(action)
        finally:
            # Restore the caller's mode (no-op if it was already in eval mode).
            self.train(was_training)

    def _build_app(self, processor):
        """
        Build (once) a minimal FastAPI app for XVLA inference into ``self.app``.

        If an app already exists this is a no-op, so a different ``processor``
        passed on a later call is ignored.

        Args:
            processor: callable(images, text) -> Dict[str, torch.Tensor]
                expected keys: "input_ids", "image_input", "image_mask".
                Any additional keys it returns are ignored.

        ``POST /act`` reads these payload fields:
            image0 / image1 / image2: optional, at least one required. Each is
                json_numpy-encoded and may be a pixel array, a 1-D encoded image
                buffer (decoded with OpenCV), a nested list, or a server-side
                file path.
            language_instruction: text passed to ``processor``.
            proprio: json_numpy-encoded state; a batch dim of 1 is added.
            domain_id: integer domain id.
            steps: optional, default 10; number of denoising steps.

        It responds with ``{"action": [...]}``, or ``{"error": ...}`` with HTTP 400.
        """
        if self.app is not None:
            return

        app = FastAPI()
        required_inputs = ("input_ids", "image_input", "image_mask")
        required_fields = ("language_instruction", "proprio", "domain_id")

        @app.post("/act")
        def act(payload: Dict[str, Any]):
            try:
                self.eval()

                images = []
                for key in ("image0", "image1", "image2"):
                    if key not in payload:
                        continue
                    v = json_numpy.loads(payload[key])
                    if isinstance(v, np.ndarray):
                        if v.ndim == 1:
                            # 1-D buffer: treat as encoded image bytes.
                            # NOTE(review): cv2.imdecode yields BGR order while PIL
                            # assumes RGB; correct only if clients encode with the
                            # same convention (e.g. cv2.imencode), so confirm.
                            v = cv2.imdecode(v, cv2.IMREAD_COLOR)
                            if v is None:
                                return JSONResponse({"error": f"Could not decode {key}."}, status_code=400)
                        images.append(Image.fromarray(v))
                    elif isinstance(v, (list, tuple)):
                        images.append(Image.fromarray(np.array(v)))
                    elif isinstance(v, str):
                        # NOTE(security): opens an arbitrary server-side path chosen by
                        # the client. Restrict or remove this branch if the service is
                        # reachable by untrusted callers.
                        images.append(Image.open(v))
                if not images:
                    return JSONResponse({"error": "No valid images found."}, status_code=400)

                missing = [k for k in required_fields if k not in payload]
                if missing:
                    return JSONResponse({"error": f"Missing required fields: {missing}"}, status_code=400)

                inputs = processor(images, payload["language_instruction"])
                if not set(required_inputs).issubset(inputs):
                    return JSONResponse({"error": "Processor returned incomplete inputs."}, status_code=400)

                proprio = torch.as_tensor(np.asarray(json_numpy.loads(payload["proprio"])))
                domain_id = torch.tensor([int(payload["domain_id"])], dtype=torch.long)

                first_param = next(self.parameters())
                device, dtype = first_param.device, first_param.dtype

                def to_model(t: torch.Tensor) -> torch.Tensor:
                    # Move to the model's device; cast only floating tensors to the
                    # model dtype so integer ids and boolean masks keep their dtype.
                    if not isinstance(t, torch.Tensor):
                        t = torch.as_tensor(t)
                    return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)

                # Forward only the keys generate_actions accepts. Extra processor
                # outputs would otherwise raise TypeError on ** expansion.
                model_inputs = {k: to_model(inputs[k]) for k in required_inputs}
                model_inputs["proprio"] = to_model(proprio.unsqueeze(0))
                model_inputs["domain_id"] = domain_id.to(device)

                steps = int(payload.get("steps", 10))
                action = self.generate_actions(**model_inputs, steps=steps).squeeze(0).float().cpu().numpy()
                return JSONResponse({"action": action.tolist()})

            except Exception:
                # Service boundary: log the full traceback, return a generic error.
                logging.exception("XVLA /act request failed")
                return JSONResponse({"error": "Request failed"}, status_code=400)

        self.app = app

    def run(self, processor, host: str = "0.0.0.0", port: int = 8000):
        """
        Build the FastAPI service (if not already built) and serve it with uvicorn.

        The default host "0.0.0.0" listens on all network interfaces.
        """
        self._build_app(processor)
        assert self.app is not None
        uvicorn.run(self.app, host=host, port=port)
|
|