# Hugging Face Space app (extraction residue removed: "Spaces: Runtime error")
"""RoboFlamingo with LSTM Policy Head!"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from io import BytesIO
import sys

# Make the app directory importable (patched_factory is loaded from here).
sys.path.insert(0, '/home/user/app')

print("π€ Loading RoboFlamingo with Policy Head")

# NOTE(security): exec() of local file contents is a deliberate monkey-patch
# hook for the Flamingo stack; acceptable only because the path is fixed and
# the file ships with the app image — never point this at untrusted input.
# Use a context manager so the file handle is closed deterministically
# (the original open(...).read() leaked the handle to the GC).
with open('/home/user/app/patch_flamingo.py') as _patch_file:
    exec(_patch_file.read())

# Globals populated by the load step below; device falls back to CPU.
MODEL_LOADED = False
model = None
image_processor = None
tokenizer = None
device = "cpu"
try:
    # Prefer GPU when available; everything below also runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    # patched_factory lives in /home/user/app (inserted into sys.path above).
    from patched_factory import create_model_and_transforms
    print("π¨ Creating RoboFlamingo with LSTM policy head...")
    # checkpoint_path=True — semantics defined by patched_factory; presumably
    # "load the bundled checkpoint". TODO(review): confirm against
    # patch_flamingo.py / patched_factory.py.
    model, image_processor, tokenizer = create_model_and_transforms(checkpoint_path=True)
    # Inference-only: move to device and disable dropout/batch-norm updates.
    model.to(device).eval()
    MODEL_LOADED = True
    print("=" * 70)
    print("β ROBOFLAMINGO WITH POLICY HEAD READY!")
    print("=" * 70)
except Exception as e:
    # Keep the Space alive even when loading fails; the UI reports the state
    # via MODEL_LOADED, and the traceback goes to the Space logs.
    print(f"β {e}")
    import traceback
    traceback.print_exc()
def plot_traj(acts):
    """Render the cumulative 3-D end-effector trajectory as a PIL image.

    Parameters
    ----------
    acts : list[dict]
        Per-timestep action dicts containing 'delta_x'/'delta_y'/'delta_z'
        relative displacements; must be non-empty (start/end markers index
        the first and last points).

    Returns
    -------
    PIL.Image.Image
        The rendered trajectory plot.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    # Integrate per-step deltas into absolute positions.
    x = np.cumsum([a['delta_x'] for a in acts])
    y = np.cumsum([a['delta_y'] for a in acts])
    z = np.cumsum([a['delta_z'] for a in acts])
    ax.plot(x, y, z, 'b-', lw=2, marker='o', ms=6)
    ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
    ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('Trajectory')
    ax.legend()
    ax.grid()
    buf = BytesIO()
    # Save/close through the figure object rather than the pyplot "current
    # figure" globals: under concurrent Gradio requests plt.savefig()/
    # plt.close() could target another request's figure.
    fig.savefig(buf, format='png', dpi=100)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plot_grip(grip):
    """Render per-step gripper commands as a colored bar strip.

    Parameters
    ----------
    grip : list[int]
        One value per timestep: 0 = OPEN (green bar), 1 = CLOSE (red bar).

    Returns
    -------
    PIL.Image.Image
        The rendered bar chart.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    cols = ['green' if g == 0 else 'red' for g in grip]
    # Uniform-height bars; the color + label carry the information.
    ax.bar(range(len(grip)), [1] * len(grip), color=cols, alpha=0.7, ec='black')
    for i, g in enumerate(grip):
        ax.text(i, 0.5, 'OPEN' if g == 0 else 'CLOSE', ha='center', va='center', weight='bold')
    ax.set_xlabel('Step')
    ax.set_ylim(0, 1.2)
    ax.grid(alpha=0.3)
    buf = BytesIO()
    # Use the figure object, not pyplot globals, so concurrent requests
    # cannot save/close each other's figures (original used plt.savefig/
    # plt.close on the implicit current figure).
    fig.savefig(buf, format='png', dpi=100)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
| def predict(inst, img1, img2): | |
| if not MODEL_LOADED: | |
| return None, None, "", "β Model not loaded" | |
| if not inst or not inst.strip(): | |
| return None, None, "", "β Enter instruction" | |
| if img1 is None or img2 is None: | |
| return None, None, "", "β Upload both images" | |
| try: | |
| if isinstance(img1, np.ndarray): | |
| img1 = Image.fromarray(img1) | |
| if isinstance(img2, np.ndarray): | |
| img2 = Image.fromarray(img2) | |
| print(f"π€ {inst}") | |
| with torch.no_grad(): | |
| t1 = image_processor(img1).unsqueeze(0).to(device) | |
| t2 = image_processor(img2).unsqueeze(0).to(device) | |
| vis = torch.stack([t1, t2], dim=1).unsqueeze(2) | |
| tok = tokenizer(inst, return_tensors="pt", padding=True, | |
| truncation=True, max_length=512) | |
| lang_x = tok['input_ids'].to(device) | |
| attn_mask = tok.get('attention_mask') | |
| if attn_mask is not None: | |
| attn_mask = attn_mask.bool().to(device) | |
| out = model(vision_x=vis, lang_x=lang_x, attention_mask=attn_mask) | |
| print(f" Output keys: {out.keys() if isinstance(out, dict) else 'not dict'}") | |
| if not isinstance(out, dict) or 'actions' not in out: | |
| return None, None, "", f"β Unexpected output format: {type(out)}" | |
| actions = out['actions'] | |
| gripper = out['gripper'] | |
| print(f" Actions shape: {actions.shape}") | |
| print(f" Gripper shape: {gripper.shape}") | |
| actions_np = actions[0].cpu().numpy() | |
| gripper_np = gripper[0].cpu().numpy() | |
| acts = [] | |
| for t, ac in enumerate(actions_np): | |
| if len(ac) < 7: | |
| ac = np.pad(ac, (0, 7-len(ac))) | |
| acts.append({ | |
| 'timestep': t, | |
| 'delta_x': float(ac[0]), | |
| 'delta_y': float(ac[1]), | |
| 'delta_z': float(ac[2]), | |
| 'qw': float(ac[3]), | |
| 'qx': float(ac[4]), | |
| 'qy': float(ac[5]), | |
| 'qz': float(ac[6]) | |
| }) | |
| # Parse gripper with adaptive normalization | |
| gripper_np_flat = gripper_np.flatten() | |
| # Debug: Print gripper statistics | |
| print(f" π Gripper stats:") | |
| print(f" min={gripper_np_flat.min():.4f}, max={gripper_np_flat.max():.4f}") | |
| print(f" mean={gripper_np_flat.mean():.4f}, std={gripper_np_flat.std():.4f}") | |
| print(f" sample: {gripper_np_flat[:5]}") | |
| # Normalize gripper values if they have variation | |
| g_range = gripper_np_flat.max() - gripper_np_flat.min() | |
| if g_range > 0.1: | |
| # Significant variation: Normalize to [0, 1] range | |
| gripper_norm = (gripper_np_flat - gripper_np_flat.min()) / g_range | |
| grip = [int(g > 0.5) for g in gripper_norm] | |
| print(f" β Normalized (range={g_range:.4f})") | |
| elif g_range > 0.01: | |
| # Small variation: Use adaptive median threshold | |
| threshold = np.median(gripper_np_flat) | |
| grip = [int(g > threshold) for g in gripper_np_flat] | |
| print(f" β Adaptive threshold={threshold:.4f}") | |
| else: | |
| # Almost no variation | |
| mean_val = gripper_np_flat.mean() | |
| if mean_val > 0.7: | |
| grip = [int(g > 0.9) for g in gripper_np_flat] | |
| print(f" β High values (mean={mean_val:.4f}), threshold=0.9") | |
| elif mean_val < 0.3: | |
| grip = [int(g > 0.1) for g in gripper_np_flat] | |
| print(f" β Low values (mean={mean_val:.4f}), threshold=0.1") | |
| else: | |
| grip = [0] * len(gripper_np_flat) | |
| print(f" β No variation, defaulting to OPEN") | |
| print(f" π€ Result: {sum(grip)}/{len(grip)} CLOSED") | |
| traj_plot = plot_traj(acts) | |
| grip_plot = plot_grip(grip) | |
| table = "| T | Ξx | Ξy | Ξz | Grip |\n|--|--|--|--|--|\n" | |
| for i, a in enumerate(acts[:12]): | |
| g_state = "CLOSE" if i < len(grip) and grip[i] == 1 else "OPEN" | |
| table += f"| {a['timestep']:2d} | {a['delta_x']:6.3f} | " | |
| table += f"{a['delta_y']:6.3f} | {a['delta_z']:6.3f} | {g_state} |\n" | |
| status = f"β SUCCESS\n{inst}\n{len(acts)} timesteps" | |
| print(f" β Predicted {len(acts)} robot actions!") | |
| return traj_plot, grip_plot, table, status | |
| except Exception as e: | |
| print(f"β {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return None, None, "", f"β {str(e)[:300]}" | |
# Gradio UI: instruction + two camera images in, trajectory/gripper plots,
# action table and status text out.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Banner reflects whether the model-loading step above succeeded.
    status_text = "π’ WITH POLICY HEAD" if MODEL_LOADED else "π΄ NOT LOADED"
    gr.Markdown(f"""# π€ RoboFlamingo - {status_text}
{'β **LSTM Policy Head attached for robot action prediction!**' if MODEL_LOADED else 'β Model not loaded'}
""")
    with gr.Row():
        with gr.Column():
            # Inputs: natural-language instruction plus two camera views.
            instruction = gr.Textbox(label="Instruction",
                                     placeholder="pick up the red block", lines=3)
            with gr.Row():
                img_third = gr.Image(label="Third-Person", type="pil", height=250)
                img_grip = gr.Image(label="Gripper", type="pil", height=250)
            predict_btn = gr.Button("π€ Predict", variant="primary", size="lg")
            status_box = gr.Textbox(label="Status", lines=6, interactive=False)
        with gr.Column():
            # Outputs: rendered plots and the markdown action table.
            traj_output = gr.Image(label="7-DOF Trajectory", type="pil")
            grip_output = gr.Image(label="Gripper Commands", type="pil")
            table_output = gr.Markdown()
    # Output component order matches predict()'s 4-tuple return.
    predict_btn.click(predict, [instruction, img_third, img_grip],
                      [traj_output, grip_output, table_output, status_box])
    gr.Markdown("[Paper](https://arxiv.org/abs/2311.01378)")
demo.launch()