# game-brain-ai / main.py
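"""
FastAPI service for the game-brain-ai Space: receives a base64-encoded game screenshot,
runs Microsoft's Florence-2-large on CPU (a detection pass plus OCR with regions), and
returns the on-screen coordinates of any digits it finds as a tap sequence ordered by value.
"""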
import io, base64, torch, re
from fastapi import FastAPI, Request
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# --- RECURSION-SAFE PATCH ---
# Florence-2's remote code declares optional imports (flash_attn, xformers) that are not
# installed here; strip them from the dynamic-module import check so loading works on CPU.
# The original get_imports is saved only once so re-running this module cannot recurse.
import transformers.dynamic_module_utils as dynamic_utils

if not hasattr(dynamic_utils, "_original_get_imports_fixed"):
    dynamic_utils._original_get_imports_fixed = dynamic_utils.get_imports

def patched_get_imports(filename):
    imports = dynamic_utils._original_get_imports_fixed(filename)
    return [imp for imp in imports if imp not in ["flash_attn", "xformers"]]

dynamic_utils.get_imports = patched_get_imports
DEVICE = "cpu"
MODEL_ID = "microsoft/Florence-2-large"
app = FastAPI()
model, processor = None, None
@app.on_event("startup")
async def load_model():
    global model, processor
    print("⏳ Loading Heavy Duty Brain...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, attn_implementation="eager"
    ).to(DEVICE).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Ready")
@app.post("/analyze/")
async def analyze(request: Request):
    try:
        data = await request.json()
        img_b64 = data.get("image")
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        real_w, real_h = image.size
        # Client display size; fall back to the real image size if the client omits it
        dw = data.get("width") or real_w
        dh = data.get("height") or real_h

        # Object detection (<OD>) is more robust than phrase grounding for UI elements.
        # Its output is kept for inspection; the digit coordinates come from the OCR pass below.
        prompt = "<OD>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
            )
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        results = processor.post_process_generation(decoded, task="<OD>", image_size=(real_w, real_h))
        # Also run an OCR pass with regions to pick up the digits specifically
        prompt_ocr = "<OCR_WITH_REGION>"
        inputs_ocr = processor(text=prompt_ocr, images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            gen_ocr = model.generate(
                input_ids=inputs_ocr["input_ids"],
                pixel_values=inputs_ocr["pixel_values"],
                max_new_tokens=512,
            )
        dec_ocr = processor.batch_decode(gen_ocr, skip_special_tokens=False)[0]
        res_ocr = processor.post_process_generation(dec_ocr, task="<OCR_WITH_REGION>", image_size=(real_w, real_h))
        found_digits = []
        # Process OCR results: OCR_WITH_REGION returns text labels plus quad boxes
        ocr_data = res_ocr.get("<OCR_WITH_REGION>", {})
        labels = ocr_data.get("labels", [])
        boxes = ocr_data.get("quad_boxes", [])  # use quad_boxes for OCR
        for i, label in enumerate(labels):
            digit_match = re.search(r"\d", label)  # first digit in the label, if any
            if digit_match:
                val = int(digit_match.group())
                # Quad box is [x1, y1, x2, y2, x3, y3, x4, y4]; take its centre,
                # then scale from image coordinates to the client display size
                box = boxes[i]
                cx = sum(box[0::2]) / 4
                cy = sum(box[1::2]) / 4
                found_digits.append({
                    "val": val,
                    "x": int(cx * (dw / real_w)),
                    "y": int(cy * (dh / real_h)),
                })
        # Final filtering: keep one entry per digit value, highest value first
        unique_digits = {d["val"]: d for d in found_digits}
        final_list = sorted(unique_digits.values(), key=lambda x: x["val"], reverse=True)
        return {
            "action": "sequence" if final_list else "wait",
            "sequence": final_list,
            "reason": f"Found {len(final_list)} digits: {sorted(unique_digits.keys(), reverse=True)}",
        }
    except Exception as e:
        return {"action": "wait", "reason": str(e)}