Spaces: Running on Zero

Update app.py

app.py (CHANGED)
@@ -11,7 +11,7 @@ from datetime import datetime
 
 import gradio as gr
 import torch
-import spaces
+import spaces
 from dotenv import load_dotenv
 from e2b_desktop import Sandbox
 from gradio_modal import Modal
@@ -58,7 +58,12 @@ if not os.path.exists(TMP_DIR):
 
 print("Loading Fara Model... This may take a moment.")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-MODEL_ID_F = "microsoft/Fara-7B"
+# Using the Microsoft Fara model as requested
+MODEL_ID_F = "microsoft/Fara-7B"
+
+# Global model variables
+model_f = None
+processor_f = None
 
 try:
     processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
@@ -73,7 +78,6 @@ except Exception as e:
     print(f"Error loading Fara Model: {e}")
     print("Falling back to Qwen/Qwen2.5-VL-7B-Instruct for demonstration if Fara is unavailable...")
     try:
-        # Fallback to base Qwen-VL if Fara repo isn't public/accessible
        MODEL_ID_F = "Qwen/Qwen2.5-VL-7B-Instruct"
        processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
        model_f = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -85,17 +89,67 @@ except Exception as e:
         print(f"Fallback Model ({MODEL_ID_F}) loaded successfully.")
     except Exception as inner_e:
         print(f"Critical error loading model: {inner_e}")
-
-
+
+
+# -----------------------------------------------------------------------------
+# GPU ISOLATED INFERENCE FUNCTION
+# -----------------------------------------------------------------------------
+
+@spaces.GPU(duration=120)
+def run_model_inference(formatted_messages, max_tokens=1024, stop_sequences=None):
+    """
+    This function runs on the GPU worker.
+    It receives simple python objects (lists/dicts), not the complex Agent object.
+    """
+    global model_f, processor_f
+
+    if model_f is None:
+        raise ValueError("Model is not loaded.")
+
+    # Process Inputs (Tokenization happens here to ensure tensors are on correct device)
+    text = processor_f.apply_chat_template(
+        formatted_messages, tokenize=False, add_generation_prompt=True
+    )
+
+    image_inputs, video_inputs = process_vision_info(formatted_messages)
+
+    inputs = processor_f(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+
+    # Move inputs to the model's device (GPU)
+    inputs = inputs.to(model_f.device)
+
+    # Generate
+    with torch.no_grad():
+        generated_ids = model_f.generate(
+            **inputs,
+            max_new_tokens=max_tokens,
+            stop_strings=stop_sequences,
+            tokenizer=processor_f.tokenizer,
+        )
+
+    # Decode
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor_f.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
+    return output_text
+
 
 class FaraLocalModel(Model):
     """
     Wrapper for the local Fara (Qwen2.5-VL) model to work with SmolAgents.
     """
-    def __init__(self, model, processor, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.model = model
-        self.processor = processor
 
     def __call__(
         self,
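For context on the change above: on ZeroGPU Spaces, a function decorated with @spaces.GPU runs in a separate GPU worker process, so its arguments and return value must pickle cleanly. That is why inference moves out of the agent class into a module-level function that receives only plain lists, dicts, and PIL images. A minimal sketch of the same pattern (the model ID and prompt handling here are illustrative, not taken from app.py):

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative small model, not the app's
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

@spaces.GPU(duration=120)  # the GPU is attached only while this call runs
def generate(prompt: str, max_new_tokens: int = 256) -> str:
    # Arguments and the return value cross a process boundary on ZeroGPU,
    # so keep them to plain picklable types (str, int, lists, dicts, PIL images).
    model.to("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

Model loading stays at import time on the CPU; only the decorated call sees the GPU, which mirrors how run_model_inference reads the module-level globals model_f and processor_f.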
@@ -103,12 +157,11 @@ class FaraLocalModel(Model):
         stop_sequences: Optional[List[str]] = None,
         **kwargs,
     ) -> ChatMessage:
-        if self.model is None:
-            raise ValueError("Fara Model is not loaded.")
-
+
         formatted_messages = []
 
         # Convert SmolAgents messages to Qwen/Transformers format
+        # We perform this conversion here (CPU side) to create simple dicts/lists
         for msg in messages:
             role = msg["role"]
             content = msg["content"]
@@ -124,7 +177,7 @@ class FaraLocalModel(Model):
                 elif isinstance(item, dict):
                     if "type" in item:
                         if item["type"] == "image":
-                            # Handle path or url
+                            # Handle path or url - extract value to ensure serializability
                             val = item.get("image") or item.get("url") or item.get("path")
                             new_content.append({"type": "image", "image": val})
                         else:
@@ -132,39 +185,14 @@ class FaraLocalModel(Model):
 
             formatted_messages.append({"role": role, "content": new_content})
 
-        # Apply chat template
-        text = self.processor.apply_chat_template(
-            formatted_messages, tokenize=False, add_generation_prompt=True
+        # Call the decorated global function
+        # This crosses the boundary to the GPU worker safely because
+        # formatted_messages contains only standard Python types (str, list, dict, PIL.Image)
+        output_text = run_model_inference(
+            formatted_messages=formatted_messages,
+            max_tokens=kwargs.get("max_tokens", 1024),
+            stop_sequences=stop_sequences
         )
-
-        image_inputs, video_inputs = process_vision_info(formatted_messages)
-
-        inputs = self.processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-
-        inputs = inputs.to(self.model.device)
-
-        # Generate
-        with torch.no_grad():
-            generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=kwargs.get("max_tokens", 1024),
-                stop_strings=stop_sequences,
-                tokenizer=self.processor.tokenizer,
-            )
-
-        # Decode
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        output_text = self.processor.batch_decode(
-            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )[0]
 
         return ChatMessage(
             role=MessageRole.ASSISTANT,
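To make the wrapper's contract concrete, a rough usage sketch follows (the message content, image path, and stop sequence are hypothetical; the shapes follow the conversion loop above, and the return value is a smolagents ChatMessage):

model = FaraLocalModel()

reply = model(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the current screen."},
                {"type": "image", "image": "/tmp/screenshot.png"},  # hypothetical path
            ],
        }
    ],
    stop_sequences=["<end_action>"],  # hypothetical stop marker
)
print(reply.role, reply.content)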
@@ -615,11 +643,8 @@ def save_final_status(folder, status: str, summary, error_message=None) -> None:
         print(f"Failed to save metadata: {e}")
 
 def create_agent(data_dir, desktop):
-    # Ensure the model loaded before building the agent
-    if model_f is None:
-        raise RuntimeError("Fara model was not loaded successfully.")
-
-    model = FaraLocalModel(model=model_f, processor=processor_f)
+    # Initialize the wrapper that calls the global GPU function
+    model = FaraLocalModel()
 
     return E2BVisionAgent(
         model=model,
@@ -755,7 +780,7 @@ def initialize_session(interactive_mode, browser_uuid):
     return update_html(interactive_mode, browser_uuid), browser_uuid
 
 class EnrichedGradioUI(GradioUI):
-    @spaces.GPU
+    # REMOVED @spaces.GPU from here to prevent pickling the E2B Sandbox (which has locks)
     def interact_with_agent(
         self,
         task_input,
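The removed decorator is the crux of this commit: @spaces.GPU pickles whatever the decorated callable can reach, and a bound method drags self along, including the E2B Sandbox and its internal locks. The failure mode is easy to reproduce in isolation with a stand-in object (not the real Sandbox):

import pickle
import threading

class FakeSandbox:
    # Stand-in for e2b_desktop.Sandbox; real network clients hold locks and sockets.
    def __init__(self):
        self._lock = threading.Lock()

try:
    pickle.dumps(FakeSandbox())
except TypeError as exc:
    print(exc)  # cannot pickle '_thread.lock' object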
@@ -772,31 +797,34 @@ class EnrichedGradioUI(GradioUI):
         if not os.path.exists(data_dir):
             os.makedirs(data_dir)
 
-        # Create the agent for this session
-        session_state["agent"] = create_agent(data_dir=data_dir, desktop=desktop)
+        # Create fresh agent.
+        # Note: We do NOT store the full agent in session_state passed between Gradio events
+        # if possible, or if we do, we ensure this function isn't wrapped in @spaces.GPU
+        agent = create_agent(data_dir=data_dir, desktop=desktop)
+        session_state["agent"] = agent  # Storing in state is fine if this function runs on CPU
 
         try:
             stored_messages.append(gr.ChatMessage(role="user", content=task_input))
             yield stored_messages
 
-            screenshot_bytes = session_state["agent"].desktop.screenshot(format="bytes")
+            screenshot_bytes = agent.desktop.screenshot(format="bytes")
             initial_screenshot = Image.open(BytesIO(screenshot_bytes))
 
             for msg in stream_to_gradio(
-                session_state["agent"],
+                agent,
                 task=task_input,
                 task_images=[initial_screenshot],
                 reset_agent_memory=False,
             ):
                 if (
-                    hasattr(session_state["agent"], "last_marked_screenshot")
+                    hasattr(agent, "last_marked_screenshot")
                    and msg.content == "-----"
                ):
                    stored_messages.append(
                        gr.ChatMessage(
                            role="assistant",
                            content={
-                                "path": session_state["agent"].last_marked_screenshot.to_string(),
+                                "path": agent.last_marked_screenshot.to_string(),
                                "mime_type": "image/png",
                            },
                        )
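One detail in this hunk worth calling out: the sandbox screenshot arrives as raw PNG bytes, which PIL opens from an in-memory buffer rather than a file path. A self-contained equivalent that fabricates the bytes locally instead of calling E2B:

from io import BytesIO
from PIL import Image

# Stand-in for agent.desktop.screenshot(format="bytes"): produce PNG bytes in memory.
buffer = BytesIO()
Image.new("RGB", (64, 64), "white").save(buffer, format="PNG")
png_bytes = buffer.getvalue()

initial_screenshot = Image.open(BytesIO(png_bytes))
print(initial_screenshot.size)  # (64, 64)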
@@ -805,7 +833,7 @@ class EnrichedGradioUI(GradioUI):
                 yield stored_messages
 
             if consent_storage:
-                summary = get_agent_summary_erase_images(session_state["agent"])
+                summary = get_agent_summary_erase_images(agent)
                 save_final_status(data_dir, "completed", summary=summary)
             yield stored_messages
 
@@ -891,7 +919,7 @@ This agent uses **microsoft/Fara-7B** (running locally via ZeroGPU) and **smolagents**
     return update_html(True, session_uuid)
 
 def interrupt_agent(session_state):
-    if "agent" in session_state and not session_state["agent"].interrupt_switch:
+    if "agent" in session_state and hasattr(session_state["agent"], "interrupt_switch") and not session_state["agent"].interrupt_switch:
         session_state["agent"].interrupt()
         return "Stopped"
     return "Stop"