Sairesh committed on
Commit
a5b70d6
·
verified ·
1 Parent(s): 7dc7fec

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +99 -62
main.py CHANGED
@@ -1,53 +1,86 @@
 
 
 
1
  import io
 
 
 
2
  import torch
3
  from fastapi import FastAPI, UploadFile, File
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from transformers import AutoProcessor, AutoModelForCausalLM
6
- from smolagents import CodeAgent, InferenceClientModel
7
-
8
- # =========================
9
- # FORCE SAFE ATTENTION
10
- # =========================
11
- torch.backends.cuda.enable_flash_sdp(False)
12
- torch.backends.cuda.enable_mem_efficient_sdp(False)
13
- torch.backends.cuda.enable_math_sdp(False)
14
 
15
- # =========================
16
- # APP
17
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
18
  app = FastAPI()
19
 
20
  device = "cpu"
21
- MODEL_ID = "microsoft/Florence-2-large"
22
 
23
  print("⏳ Loading Florence-2 (SAFE MODE)...")
24
 
25
- vision_model = AutoModelForCausalLM.from_pretrained(
26
- MODEL_ID,
27
- trust_remote_code=True,
28
- attn_implementation="eager" # 🔥 THIS FIXES YOUR ERROR
29
- ).to(device)
30
-
31
- processor = AutoProcessor.from_pretrained(
32
- MODEL_ID,
33
- trust_remote_code=True
34
- )
35
 
36
- print("✅ Florence-2 loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # =========================
39
- # STRATEGIST (CLOUD)
40
- # =========================
41
  strategist = CodeAgent(
42
  tools=[],
43
- model=InferenceClientModel(
44
- model_id="meta-llama/Llama-3.2-3B-Instruct"
45
- )
46
  )
47
 
48
- # =========================
49
- # ROUTES
50
- # =========================
51
  @app.get("/")
52
  def home():
53
  return {
@@ -58,42 +91,46 @@ def home():
58
 
59
  @app.post("/analyze/")
60
  async def analyze(file: UploadFile = File(...)):
61
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
62
- w, h = img.size
 
 
63
 
64
- inputs = processor(
65
- text="button, icon, start, attack, confirm, close, x, claim, menu",
66
- images=img,
67
- return_tensors="pt"
68
- ).to(device)
69
 
70
- with torch.no_grad():
71
- out = vision_model.generate(
72
- input_ids=inputs["input_ids"],
73
- pixel_values=inputs["pixel_values"],
74
- max_new_tokens=512
75
- )
76
 
77
- decoded = processor.batch_decode(out, skip_special_tokens=False)[0]
 
 
 
 
 
78
 
79
- vision = processor.post_process_generation(
80
- decoded,
81
- task="<CAPTION_TO_PHRASE_GROUNDING>",
82
- image_size=(w, h)
83
- )
 
 
84
 
85
- prompt = f"""
86
- You are a mobile bot.
87
- Resolution: 720x1600
88
- Detected UI: {vision}
89
 
90
- ONLY output:
91
- tap X Y
 
92
  """
93
 
94
- decision = strategist.run(prompt)
95
 
96
- return {
97
- "decision": str(decision).strip(),
98
- "vision": vision
99
- }
 
 
 
 
 
1
+ # main.py - Replace the whole file with this exact code
2
+
3
+ import os
4
  import io
5
+ import traceback
6
+
7
+ # --- Basic imports
8
  import torch
9
  from fastapi import FastAPI, UploadFile, File
10
  from PIL import Image
11
+ import transformers
12
+
13
# --- PATCH: keep optional accelerator packages out of the remote-code import check.
# transformers inspects the imports of remote modeling code before loading it and
# aborts when one of them is not installed; flash_attn / xformers are optional
# for this service, so they are filtered out of that list.
_OPTIONAL_REMOTE_IMPORTS = ("flash_attn", "xformers")


def filter_optional_imports(imports):
    """Return *imports* without entries that mention an optional accelerator.

    Args:
        imports: Iterable of module-name strings, as produced by
            ``transformers.dynamic_module_utils.get_imports``.

    Returns:
        A new list, in the original order, containing only the names that do
        not contain ``flash_attn`` or ``xformers`` as a substring.
    """
    return [
        name
        for name in imports
        if not any(optional in name for optional in _OPTIONAL_REMOTE_IMPORTS)
    ]


try:
    from transformers import dynamic_module_utils
except ImportError as exc:
    # Best effort: model loading further down will surface the real problem,
    # but say so here instead of failing silently.
    print(f"WARNING: could not patch transformers import check: {exc!r}")
else:
    _orig_get_imports = getattr(dynamic_module_utils, "get_imports", None)
    if _orig_get_imports is None:
        print("WARNING: transformers.dynamic_module_utils.get_imports not found; "
              "flash_attn import check left unpatched")
    else:
        def _patched_get_imports(filename):
            """Delegate to the original ``get_imports`` and drop optional packages."""
            return filter_optional_imports(_orig_get_imports(filename))

        dynamic_module_utils.get_imports = _patched_get_imports
29
+
30
+ # Now import model helpers from transformers (after patch above)
31
  from transformers import AutoProcessor, AutoModelForCausalLM
 
 
 
 
 
 
 
 
32
 
33
# --- Safety: switch off the fused SDPA kernels but keep the math fallback.
# Disabling flash, memory-efficient AND math at the same time leaves
# scaled_dot_product_attention with no backend to dispatch to, so the plain
# "math" implementation is explicitly left enabled.
_cuda_backend = getattr(torch.backends, "cuda", None)
for _toggle_name, _enabled in (
    ("enable_flash_sdp", False),
    ("enable_mem_efficient_sdp", False),
    ("enable_math_sdp", True),
):
    _toggle = getattr(_cuda_backend, _toggle_name, None)
    if _toggle is None:
        # Toggle not exposed by this torch build; nothing to configure.
        continue
    try:
        _toggle(_enabled)
    except Exception as exc:
        # Best effort only: a failed toggle must not stop the service from
        # starting, but it should be visible in the logs.
        print(f"WARNING: torch.backends.cuda.{_toggle_name}({_enabled}) failed: {exc!r}")
46
+
47
# --- App setup
app = FastAPI()

device = "cpu"
VISION_MODEL_ID = "microsoft/Florence-2-large"

print("⏳ Loading Florence-2 (SAFE MODE)...")

try:
    # Request the eager attention implementation, then move the weights to
    # the configured device as a separate step.
    vision_model = AutoModelForCausalLM.from_pretrained(
        VISION_MODEL_ID,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    vision_model = vision_model.to(device)

    processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
except Exception:
    # Log the full traceback, then re-raise so startup fails loudly instead
    # of serving requests without a model.
    print("❌ Failed loading Florence-2: ")
    traceback.print_exc()
    raise
else:
    print("✅ Florence-2 loaded")
75
+
76
# Cloud strategist: a tool-less CodeAgent backed by an InferenceClientModel.
from smolagents import CodeAgent, InferenceClientModel

_strategist_llm = InferenceClientModel(model_id="meta-llama/Llama-3.2-3B-Instruct")
strategist = CodeAgent(tools=[], model=_strategist_llm)
83
 
 
 
 
84
  @app.get("/")
85
  def home():
86
  return {
 
91
 
@app.post("/analyze/")
async def analyze(file: UploadFile = File(...)):
    """Detect UI elements in an uploaded screenshot and ask the strategist where to tap.

    The image is run through Florence-2 phrase grounding against a fixed list
    of UI phrases; the parsed detections are embedded in a prompt that is sent
    to the strategist agent.

    Args:
        file: Uploaded screenshot (multipart form field ``file``).

    Returns:
        ``{"status": "success", "decision": <str>, "debug": <detections>}`` on
        success. Any exception is caught, its traceback printed, and
        ``{"status": "error", "detail": <exception text>}`` is returned as a
        normal response body instead of being raised.
    """
    task = "<CAPTION_TO_PHRASE_GROUNDING>"
    phrases = "button, icon, start, attack, confirm, close, x, claim, menu"

    try:
        img_bytes = await file.read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        width, height = image.size

        # Florence-2 picks its task from a leading task token in the text
        # prompt. Without it the phrases are fed as a bare prompt and the
        # output is not grounding output, even though it is post-processed
        # as such below.
        inputs = processor(
            text=task + phrases,
            images=image,
            return_tensors="pt",
        ).to(device)

        # NOTE(review): generate() and strategist.run() are synchronous calls
        # inside an async handler, so they presumably block the event loop
        # while they run. Offloading them needs care because the strategist
        # agent is shared between requests.
        with torch.no_grad():
            generated_ids = vision_model.generate(
                # Index (not .get) so a missing tensor fails with a clear KeyError.
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
            )

        # Special tokens are kept because the location tokens are needed by
        # the post-processing step.
        prediction = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

        vision_data = processor.post_process_generation(
            prediction,
            task=task,
            image_size=(width, height),
        )

        prompt = (
            "\n"
            "You are a mobile game bot for a Redmi 9i (720x1600).\n"
            f"Visual Data: {vision_data}\n"
            "\n"
            "Task: Pick the best element to click to progress.\n"
            "Rule: You must ONLY output: tap X Y\n"
            "No other text.\n"
        )

        decision = strategist.run(prompt)

        return {
            "status": "success",
            "decision": str(decision).strip(),
            "debug": vision_data,
        }
    except Exception as exc:
        # Top-level boundary: log the traceback and report the failure in the
        # response body so the caller always receives JSON.
        traceback.print_exc()
        return {"status": "error", "detail": str(exc)}