Spaces:
Sleeping
Sleeping
| import io, base64, torch, re | |
| from fastapi import FastAPI, Request | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
# --- RECURSION-SAFE PATCH ---
import transformers.dynamic_module_utils as dynamic_utils

# Patch transformers' dynamic-module import scanner so that trust_remote_code
# models (Florence-2 here) do not hard-require flash_attn / xformers, which are
# presumably unavailable in this CPU environment.  The unpatched function is
# stashed on the module under a sentinel attribute, so re-executing this file
# applies the patch at most once and never wraps the wrapper (hence
# "recursion-safe").
if not hasattr(dynamic_utils, "_original_get_imports_fixed"):
    dynamic_utils._original_get_imports_fixed = dynamic_utils.get_imports

    def patched_get_imports(filename):
        """Return get_imports(filename) minus the optional GPU-only packages."""
        scanned = dynamic_utils._original_get_imports_fixed(filename)
        return [name for name in scanned if name not in ("flash_attn", "xformers")]

    dynamic_utils.get_imports = patched_get_imports
# Inference configuration: CPU-only; nothing below assumes CUDA.
DEVICE = "cpu"
MODEL_ID = "microsoft/Florence-2-large"

app = FastAPI()

# Populated by load_model(); both stay None until it has run.
# NOTE(review): a request handled before load_model() completes will hit
# `processor(...)` while processor is None, raise TypeError, and fall into
# analyze()'s except branch (returns the "wait" action).
model, processor = None, None
async def load_model():
    """Load the Florence-2 model and its processor into the module globals.

    Assigns `model` (in eval mode, on DEVICE) and `processor`.
    NOTE(review): no @app.on_event("startup") decorator is visible on this
    function — it may have been lost in the paste; confirm it is actually
    wired to FastAPI startup.
    """
    global model, processor
    print("⏳ Loading Heavy Duty Brain...")
    # trust_remote_code pulls the Florence-2 modeling code from the Hub;
    # "eager" attention avoids the flash_attn path (see patch above).
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model = loaded.to(DEVICE).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Ready")
def _extract_digits(labels, quad_boxes, real_size, display_size):
    """Map digits found in OCR labels to display-space click coordinates.

    labels       -- OCR text snippets from Florence-2 <OCR_WITH_REGION>.
    quad_boxes   -- one [x1,y1,x2,y2,x3,y3,x4,y4] quad per label, in
                    real-image pixel coordinates.
    real_size    -- (width, height) of the decoded image.
    display_size -- (width, height) of the client display to scale into.

    Returns a list of {"val": int, "x": int, "y": int}.  Only the first
    digit character of each label is used (labels are presumably single
    digits — TODO confirm against the client's screen content).
    """
    real_w, real_h = real_size
    disp_w, disp_h = display_size
    found = []
    for label, quad in zip(labels, quad_boxes):
        match = re.search(r"\d", label)
        if not match:
            continue
        # Centre of the quad: average the four x and the four y coordinates.
        cx = sum(quad[0::2]) / 4
        cy = sum(quad[1::2]) / 4
        found.append({
            "val": int(match.group()),
            "x": int(cx * (disp_w / real_w)),
            "y": int(cy * (disp_h / real_h)),
        })
    return found


async def analyze(request: Request):
    """Run Florence-2 OCR on a base64 screenshot and return digits to click.

    Request JSON: {"image": <base64>, "width": <display w>, "height": <display h>}.
    Response: {"action": "sequence"|"wait", "sequence": [...], "reason": str}.
    Any failure degrades to {"action": "wait", "reason": <error text>}.

    NOTE(review): no @app.post route decorator is visible on this function —
    it may have been lost in the paste; confirm the endpoint is registered.
    """
    try:
        data = await request.json()
        image = Image.open(io.BytesIO(base64.b64decode(data.get("image")))).convert("RGB")
        real_w, real_h = image.size
        # ROBUSTNESS: fall back to the real image size when the client omits
        # width/height (the original raised TypeError on None here).
        dw = data.get("width") or real_w
        dh = data.get("height") or real_h

        # Single OCR pass.  The previous "<DETECTION>" pass was removed: its
        # result was never read, "<DETECTION>" is not a documented Florence-2
        # task token, and running it doubled inference latency.
        prompt = "<OCR_WITH_REGION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
            )
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        results = processor.post_process_generation(
            decoded, task=prompt, image_size=(real_w, real_h)
        )

        ocr_data = results.get("<OCR_WITH_REGION>", {})
        found_digits = _extract_digits(
            ocr_data.get("labels", []),
            ocr_data.get("quad_boxes", []),  # OCR returns quads, not xyxy boxes
            (real_w, real_h),
            (dw, dh),
        )

        # Keep one entry per digit value (last occurrence wins), highest first.
        unique = {d["val"]: d for d in found_digits}
        final_list = sorted(unique.values(), key=lambda d: d["val"], reverse=True)
        return {
            "action": "sequence" if final_list else "wait",
            "sequence": final_list,
            # BUG FIX: the original called .keys() on a dict_values object,
            # raising AttributeError — every successful pass fell through to
            # the except branch and answered "wait".
            "reason": f"Found {len(final_list)} digits: {[d['val'] for d in final_list]}",
        }
    except Exception as e:  # API boundary: degrade to "wait" with the error text
        return {"action": "wait", "reason": str(e)}