# aw1app β€” "Fix: Adaptive gripper threshold with debug logging" (commit 40494ef)
"""RoboFlamingo with LSTM Policy Head!"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from io import BytesIO
import sys
sys.path.insert(0, '/home/user/app')
print("πŸ€– Loading RoboFlamingo with Policy Head")
# Apply monkey-patches before importing the model factory.
# NOTE(review): exec of a local file is intentional here (Space's own code),
# but the original bare open() leaked the file handle β€” use a context manager.
with open('/home/user/app/patch_flamingo.py') as _patch_file:
    exec(_patch_file.read())

# Module-level state shared with predict() and the UI; defaults cover the
# "model failed to load" mode so the app can still launch.
MODEL_LOADED = False
model = None
image_processor = None
tokenizer = None
device = "cpu"
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    from patched_factory import create_model_and_transforms
    print("πŸ”¨ Creating RoboFlamingo with LSTM policy head...")
    model, image_processor, tokenizer = create_model_and_transforms(checkpoint_path=True)
    model.to(device).eval()
    MODEL_LOADED = True
    print("=" * 70)
    print("βœ… ROBOFLAMINGO WITH POLICY HEAD READY!")
    print("=" * 70)
except Exception as e:
    # Top-level boundary: log the full traceback but keep running so the UI
    # can display the "not loaded" state instead of crashing the Space.
    print(f"❌ {e}")
    import traceback
    traceback.print_exc()
def plot_traj(acts):
    """Render the cumulative end-effector XYZ path as a 3D line plot.

    Args:
        acts: non-empty list of per-timestep dicts containing 'delta_x',
            'delta_y', 'delta_z' (as produced by predict()).

    Returns:
        PIL.Image.Image containing the rendered PNG.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    # Integrate the per-step deltas into an absolute path.
    x = np.cumsum([a['delta_x'] for a in acts])
    y = np.cumsum([a['delta_y'] for a in acts])
    z = np.cumsum([a['delta_z'] for a in acts])
    ax.plot(x, y, z, 'b-', lw=2, marker='o', ms=6)
    ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
    ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
    ax.set_title('Trajectory'); ax.legend(); ax.grid()
    buf = BytesIO()
    try:
        # Use the figure's own savefig, not pyplot's implicit "current figure":
        # Gradio handlers run on worker threads where pyplot global state is
        # unreliable. Close this exact figure even if saving fails.
        fig.savefig(buf, format='png', dpi=100)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plot_grip(grip):
    """Render per-timestep gripper commands as a colored bar strip.

    Args:
        grip: list of 0/1 ints (0 = OPEN, drawn green; 1 = CLOSE, drawn red).

    Returns:
        PIL.Image.Image containing the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    cols = ['green' if g == 0 else 'red' for g in grip]
    ax.bar(range(len(grip)), [1] * len(grip), color=cols, alpha=0.7, ec='black')
    for i, g in enumerate(grip):
        ax.text(i, 0.5, 'OPEN' if g == 0 else 'CLOSE', ha='center', va='center', weight='bold')
    ax.set_xlabel('Step'); ax.set_ylim(0, 1.2); ax.grid(alpha=0.3)
    buf = BytesIO()
    try:
        # Save via the figure object and close that exact figure (the original
        # plt.savefig/plt.close pair depends on thread-unsafe pyplot state and
        # leaked the figure if savefig raised).
        fig.savefig(buf, format='png', dpi=100)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _gripper_states(values):
    """Map raw gripper activations to binary open/close commands.

    Adaptive strategy based on how much the raw values vary:
      * large spread (> 0.1): min/max normalize to [0, 1], threshold at 0.5;
      * small spread (> 0.01): threshold at the median;
      * near-constant: pick a threshold from the mean regime, or default to
        all-OPEN when the signal carries no information.

    Args:
        values: 1-D numpy array of raw gripper outputs.

    Returns:
        list[int]: one 0 (OPEN) / 1 (CLOSE) per timestep.
    """
    # Debug: print gripper statistics so threshold choices are auditable.
    print(f" πŸ“Š Gripper stats:")
    print(f" min={values.min():.4f}, max={values.max():.4f}")
    print(f" mean={values.mean():.4f}, std={values.std():.4f}")
    print(f" sample: {values[:5]}")
    g_range = values.max() - values.min()
    if g_range > 0.1:
        # Significant variation: normalize to [0, 1] and use a fixed midpoint.
        norm = (values - values.min()) / g_range
        grip = [int(g > 0.5) for g in norm]
        print(f" βœ“ Normalized (range={g_range:.4f})")
    elif g_range > 0.01:
        # Small variation: adaptive median threshold splits the two clusters.
        threshold = np.median(values)
        grip = [int(g > threshold) for g in values]
        print(f" βœ“ Adaptive threshold={threshold:.4f}")
    else:
        # Almost no variation: decide by which regime the mean sits in.
        mean_val = values.mean()
        if mean_val > 0.7:
            grip = [int(g > 0.9) for g in values]
            print(f" βœ“ High values (mean={mean_val:.4f}), threshold=0.9")
        elif mean_val < 0.3:
            grip = [int(g > 0.1) for g in values]
            print(f" βœ“ Low values (mean={mean_val:.4f}), threshold=0.1")
        else:
            grip = [0] * len(values)
            print(f" ⚠ No variation, defaulting to OPEN")
    print(f" πŸ€– Result: {sum(grip)}/{len(grip)} CLOSED")
    return grip


def predict(inst, img1, img2):
    """Run RoboFlamingo on one instruction and two camera views.

    Args:
        inst: natural-language task instruction.
        img1: third-person RGB frame (PIL.Image or numpy array from Gradio).
        img2: gripper-camera RGB frame (PIL.Image or numpy array).

    Returns:
        (trajectory_plot, gripper_plot, markdown_table, status_message);
        the first three are None when validation or inference fails.
    """
    # Guard clauses: report friendly errors instead of raising into Gradio.
    if not MODEL_LOADED:
        return None, None, "", "❌ Model not loaded"
    if not inst or not inst.strip():
        return None, None, "", "❌ Enter instruction"
    if img1 is None or img2 is None:
        return None, None, "", "❌ Upload both images"
    try:
        # Gradio may deliver numpy arrays; normalize both inputs to PIL.
        if isinstance(img1, np.ndarray):
            img1 = Image.fromarray(img1)
        if isinstance(img2, np.ndarray):
            img2 = Image.fromarray(img2)
        print(f"πŸ€– {inst}")
        with torch.no_grad():
            t1 = image_processor(img1).unsqueeze(0).to(device)
            t2 = image_processor(img2).unsqueeze(0).to(device)
            # Stack views then add a frame axis; presumably the Flamingo
            # wrapper expects (B, views, frames, C, H, W) β€” TODO confirm
            # against patched_factory.
            vis = torch.stack([t1, t2], dim=1).unsqueeze(2)
            tok = tokenizer(inst, return_tensors="pt", padding=True,
                            truncation=True, max_length=512)
            lang_x = tok['input_ids'].to(device)
            attn_mask = tok.get('attention_mask')
            if attn_mask is not None:
                attn_mask = attn_mask.bool().to(device)
            out = model(vision_x=vis, lang_x=lang_x, attention_mask=attn_mask)
        print(f" Output keys: {out.keys() if isinstance(out, dict) else 'not dict'}")
        # Fix: the original checked only 'actions' but also read out['gripper'],
        # which could raise an unhandled KeyError; require both keys up front.
        if not isinstance(out, dict) or 'actions' not in out or 'gripper' not in out:
            return None, None, "", f"❌ Unexpected output format: {type(out)}"
        actions = out['actions']
        gripper = out['gripper']
        print(f" Actions shape: {actions.shape}")
        print(f" Gripper shape: {gripper.shape}")
        actions_np = actions[0].cpu().numpy()
        gripper_np = gripper[0].cpu().numpy()
        # Build one dict per timestep: xyz deltas + quaternion components.
        acts = []
        for t, ac in enumerate(actions_np):
            if len(ac) < 7:
                ac = np.pad(ac, (0, 7 - len(ac)))  # zero-pad short vectors to 7-DOF
            acts.append({
                'timestep': t,
                'delta_x': float(ac[0]),
                'delta_y': float(ac[1]),
                'delta_z': float(ac[2]),
                'qw': float(ac[3]),
                'qx': float(ac[4]),
                'qy': float(ac[5]),
                'qz': float(ac[6])
            })
        # Parse gripper with adaptive normalization (see _gripper_states).
        grip = _gripper_states(gripper_np.flatten())
        traj_plot = plot_traj(acts)
        grip_plot = plot_grip(grip)
        # First 12 timesteps as a Markdown table for the UI.
        table = "| T | Ξ”x | Ξ”y | Ξ”z | Grip |\n|--|--|--|--|--|\n"
        for i, a in enumerate(acts[:12]):
            g_state = "CLOSE" if i < len(grip) and grip[i] == 1 else "OPEN"
            table += f"| {a['timestep']:2d} | {a['delta_x']:6.3f} | "
            table += f"{a['delta_y']:6.3f} | {a['delta_z']:6.3f} | {g_state} |\n"
        status = f"βœ… SUCCESS\n{inst}\n{len(acts)} timesteps"
        print(f" βœ… Predicted {len(acts)} robot actions!")
        return traj_plot, grip_plot, table, status
    except Exception as e:
        # Boundary handler: surface the error in the UI instead of crashing.
        print(f"❌ {e}")
        import traceback
        traceback.print_exc()
        return None, None, "", f"❌ {str(e)[:300]}"
# ---- Gradio UI (flat script; built at import time, launched at the end) ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header banner reflects whether model construction succeeded at startup.
    status_text = "🟒 WITH POLICY HEAD" if MODEL_LOADED else "πŸ”΄ NOT LOADED"
    gr.Markdown(f"""# πŸ€– RoboFlamingo - {status_text}
{'βœ… **LSTM Policy Head attached for robot action prediction!**' if MODEL_LOADED else '❌ Model not loaded'}
""")
    with gr.Row():
        with gr.Column():
            # Inputs: one language instruction plus two camera views.
            instruction = gr.Textbox(label="Instruction",
                placeholder="pick up the red block", lines=3)
            with gr.Row():
                img_third = gr.Image(label="Third-Person", type="pil", height=250)
                img_grip = gr.Image(label="Gripper", type="pil", height=250)
            predict_btn = gr.Button("πŸ€– Predict", variant="primary", size="lg")
            status_box = gr.Textbox(label="Status", lines=6, interactive=False)
        with gr.Column():
            # Outputs: 3D trajectory plot, gripper command strip, action table.
            traj_output = gr.Image(label="7-DOF Trajectory", type="pil")
            grip_output = gr.Image(label="Gripper Commands", type="pil")
            table_output = gr.Markdown()
    # Wire the button to predict(); output order matches predict()'s return.
    predict_btn.click(predict, [instruction, img_third, img_grip],
        [traj_output, grip_output, table_output, status_box])
    gr.Markdown("[Paper](https://arxiv.org/abs/2311.01378)")
# Blocking call: serves the app until the process exits.
demo.launch()