prithivMLmods committed
Commit 6308c59 · verified · Parent: 79a20f7

Update app.py

Files changed (1):
  1. app.py +86 -58
app.py CHANGED
@@ -11,7 +11,7 @@ from datetime import datetime

 import gradio as gr
 import torch
- import spaces  # <--- Added Spaces support
+ import spaces
 from dotenv import load_dotenv
 from e2b_desktop import Sandbox
 from gradio_modal import Modal
@@ -58,7 +58,12 @@ if not os.path.exists(TMP_DIR):

 print("Loading Fara Model... This may take a moment.")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- MODEL_ID_F = "microsoft/Fara-7B"  # Ensure this repository exists and you have access
+ # Using the Microsoft Fara model as requested
+ MODEL_ID_F = "microsoft/Fara-7B"
+
+ # Global model variables
+ model_f = None
+ processor_f = None

 try:
     processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
@@ -73,7 +78,6 @@ except Exception as e:
     print(f"Error loading Fara Model: {e}")
     print("Falling back to Qwen/Qwen2.5-VL-7B-Instruct for demonstration if Fara is unavailable...")
     try:
-         # Fallback to base Qwen-VL if the Fara repo isn't public/accessible
         MODEL_ID_F = "Qwen/Qwen2.5-VL-7B-Instruct"
         processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
         model_f = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -85,17 +89,67 @@ except Exception as e:
         print(f"Fallback Model ({MODEL_ID_F}) loaded successfully.")
     except Exception as inner_e:
         print(f"Critical error loading model: {inner_e}")
-         model_f = None
-         processor_f = None
+
+
+ # -----------------------------------------------------------------------------
+ # GPU-ISOLATED INFERENCE FUNCTION
+ # -----------------------------------------------------------------------------
+
+ @spaces.GPU(duration=120)
+ def run_model_inference(formatted_messages, max_tokens=1024, stop_sequences=None):
+     """
+     This function runs on the GPU worker.
+     It receives simple Python objects (lists/dicts), not the complex Agent object.
+     """
+     global model_f, processor_f
+
+     if model_f is None:
+         raise ValueError("Model is not loaded.")
+
+     # Process inputs (tokenization happens here so tensors end up on the correct device)
+     text = processor_f.apply_chat_template(
+         formatted_messages, tokenize=False, add_generation_prompt=True
+     )
+
+     image_inputs, video_inputs = process_vision_info(formatted_messages)
+
+     inputs = processor_f(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+
+     # Move inputs to the model's device (GPU)
+     inputs = inputs.to(model_f.device)
+
+     # Generate
+     with torch.no_grad():
+         generated_ids = model_f.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             stop_strings=stop_sequences,
+             tokenizer=processor_f.tokenizer,
+         )
+
+     # Decode only the newly generated tokens
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor_f.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     return output_text
+

 class FaraLocalModel(Model):
     """
     Wrapper for the local Fara (Qwen2.5-VL) model to work with SmolAgents.
     """
-     def __init__(self, model, processor, **kwargs):
+     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-         self.model = model
-         self.processor = processor

     def __call__(
         self,
@@ -103,12 +157,11 @@ class FaraLocalModel(Model):
         stop_sequences: Optional[List[str]] = None,
         **kwargs,
     ) -> ChatMessage:
-         if self.model is None:
-             raise ValueError("Fara Model is not loaded.")
-
+
         formatted_messages = []

         # Convert SmolAgents messages to Qwen/Transformers format
+         # We perform this conversion here (CPU side) to create simple dicts/lists
         for msg in messages:
             role = msg["role"]
             content = msg["content"]
@@ -124,7 +177,7 @@ class FaraLocalModel(Model):
                 elif isinstance(item, dict):
                     if "type" in item:
                         if item["type"] == "image":
-                             # Handle path or URL
+                             # Handle path or URL - extract the value to keep it serializable
                             val = item.get("image") or item.get("url") or item.get("path")
                             new_content.append({"type": "image", "image": val})
                         else:
@@ -132,39 +185,14 @@ class FaraLocalModel(Model):

             formatted_messages.append({"role": role, "content": new_content})

-         # Process inputs
-         text = self.processor.apply_chat_template(
-             formatted_messages, tokenize=False, add_generation_prompt=True
-         )
-
-         image_inputs, video_inputs = process_vision_info(formatted_messages)
-
-         inputs = self.processor(
-             text=[text],
-             images=image_inputs,
-             videos=video_inputs,
-             padding=True,
-             return_tensors="pt",
-         )
-
-         inputs = inputs.to(self.model.device)
-
-         # Generate
-         with torch.no_grad():
-             generated_ids = self.model.generate(
-                 **inputs,
-                 max_new_tokens=kwargs.get("max_tokens", 1024),
-                 stop_strings=stop_sequences,
-                 tokenizer=self.processor.tokenizer,
-             )
-
-         # Decode
-         generated_ids_trimmed = [
-             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-         ]
-         output_text = self.processor.batch_decode(
-             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-         )[0]
+         # Call the decorated global function.
+         # This crosses the boundary to the GPU worker safely because
+         # formatted_messages contains only standard Python types (str, list, dict, PIL.Image).
+         output_text = run_model_inference(
+             formatted_messages=formatted_messages,
+             max_tokens=kwargs.get("max_tokens", 1024),
+             stop_sequences=stop_sequences,
+         )

         return ChatMessage(
             role=MessageRole.ASSISTANT,
@@ -615,11 +643,8 @@ def save_final_status(folder, status: str, summary, error_message=None) -> None:
         print(f"Failed to save metadata: {e}")

 def create_agent(data_dir, desktop):
-     # Instantiate the local model wrapper
-     if model_f is None:
-         raise RuntimeError("Fara model was not loaded successfully.")
-
-     model = FaraLocalModel(model=model_f, processor=processor_f)
+     # Initialize the wrapper that calls the global GPU function
+     model = FaraLocalModel()

     return E2BVisionAgent(
         model=model,
@@ -755,7 +780,7 @@ def initialize_session(interactive_mode, browser_uuid):
     return update_html(interactive_mode, browser_uuid), browser_uuid

 class EnrichedGradioUI(GradioUI):
-     @spaces.GPU(duration=180)  # Allocate GPU for 3 minutes per interaction cycle
+     # REMOVED @spaces.GPU from here to prevent pickling the E2B Sandbox (which holds locks)
     def interact_with_agent(
         self,
         task_input,
@@ -772,31 +797,34 @@ class EnrichedGradioUI(GradioUI):
         if not os.path.exists(data_dir):
             os.makedirs(data_dir)

-         # Re-create agent to ensure fresh context with the Fara model
-         session_state["agent"] = create_agent(data_dir=data_dir, desktop=desktop)
+         # Create a fresh agent.
+         # Note: we do NOT store the full agent in session_state passed between Gradio
+         # events if we can avoid it; if we do, this function must not be wrapped in @spaces.GPU.
+         agent = create_agent(data_dir=data_dir, desktop=desktop)
+         session_state["agent"] = agent  # Storing it in state is fine while this function runs on the CPU

         try:
             stored_messages.append(gr.ChatMessage(role="user", content=task_input))
             yield stored_messages

-             screenshot_bytes = session_state["agent"].desktop.screenshot(format="bytes")
+             screenshot_bytes = agent.desktop.screenshot(format="bytes")
             initial_screenshot = Image.open(BytesIO(screenshot_bytes))

             for msg in stream_to_gradio(
-                 session_state["agent"],
+                 agent,
                 task=task_input,
                 task_images=[initial_screenshot],
                 reset_agent_memory=False,
             ):
                 if (
-                     hasattr(session_state["agent"], "last_marked_screenshot")
+                     hasattr(agent, "last_marked_screenshot")
                     and msg.content == "-----"
                 ):
                     stored_messages.append(
                         gr.ChatMessage(
                             role="assistant",
                             content={
-                                 "path": session_state["agent"].last_marked_screenshot.to_string(),
+                                 "path": agent.last_marked_screenshot.to_string(),
                                 "mime_type": "image/png",
                             },
                         )
@@ -805,7 +833,7 @@ class EnrichedGradioUI(GradioUI):
             yield stored_messages

             if consent_storage:
-                 summary = get_agent_summary_erase_images(session_state["agent"])
+                 summary = get_agent_summary_erase_images(agent)
                 save_final_status(data_dir, "completed", summary=summary)
                 yield stored_messages
@@ -891,7 +919,7 @@ This agent uses **microsoft/Fara-7B** (running locally via ZeroGPU) and **smolagents**
     return update_html(True, session_uuid)

 def interrupt_agent(session_state):
-     if "agent" in session_state and not session_state["agent"].interrupt_switch:
+     if "agent" in session_state and hasattr(session_state["agent"], "interrupt_switch") and not session_state["agent"].interrupt_switch:
         session_state["agent"].interrupt()
         return "Stopped"
     return "Stop"
 