Buckets:
# Install the required dependencies before running this script:
#   pip install torch torchvision
#   pip install gradio==6.9.0
#   pip install transformers==5.3.0
#   pip install supervision==0.27.0.post2
import gradio as gr
import torch
import numpy as np
import supervision as sv
import json
import ast
import re
from PIL import Image
from threading import Thread
from transformers import (
    Qwen3_5ForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

# Run on GPU when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer bfloat16 on GPUs that support it (wider dynamic range than fp16);
# otherwise use float16.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
MODEL_NAME = "prithivMLmods/Polaris-VGA-4B-Post1.0e"

# Colors used by the supervision annotators when drawing point markers
# and their text labels.
BRIGHT_YELLOW = sv.Color(r=255, g=230, b=0)
DARK_OUTLINE = sv.Color(r=40, g=40, b=40)
BLACK = sv.Color(r=0, g=0, b=0)

print(f"Loading model: {MODEL_NAME} ...")
# Load the vision-language model once at import time; .eval() puts it in
# inference mode (disables dropout etc.).
qwen_model = Qwen3_5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = AutoProcessor.from_pretrained(MODEL_NAME)
print("Model loaded.")
def safe_parse_json(text: str):
    """Best-effort parse of model output that should contain JSON.

    Strips an optional Markdown code fence (```` ```json … ``` ````), then
    tries ``json.loads`` and, failing that, ``ast.literal_eval`` (which
    tolerates single-quoted, Python-style pseudo-JSON).  Returns ``{}``
    when neither parser succeeds.
    """
    cleaned = text.strip()
    cleaned = re.sub(r"^```(json)?", "", cleaned)
    cleaned = re.sub(r"```$", "", cleaned).strip()
    # Try the strict parser first, then the lenient one.
    for parser, failures in (
        (json.loads, (json.JSONDecodeError,)),
        (ast.literal_eval, (Exception,)),
    ):
        try:
            return parser(cleaned)
        except failures:
            continue
    return {}
def normalize_1000_coord(v):
    """Map a coordinate from the model's 0-1000 space to a 0.0-1.0 fraction.

    Non-numeric input normalizes to 0.0; values outside the range are
    clamped to [0.0, 1.0].
    """
    try:
        numeric = float(v)
    except Exception:
        # Anything that cannot be coerced to float collapses to the origin.
        return 0.0
    fraction = numeric / 1000.0
    if fraction < 0.0:
        return 0.0
    if fraction > 1.0:
        return 1.0
    return fraction
def parse_pointer_response(raw_text: str):
    """Convert the model's raw text reply into normalized point records.

    Returns ``{"points": [...]}`` where each point is a dict with ``label``
    and fractional ``x``/``y`` in [0, 1].  Accepts the preferred
    ``point_2d`` format and falls back to the center of a ``bbox_2d``.

    Fix: models occasionally emit a single JSON object instead of a list;
    previously such a reply was silently discarded because only ``list``
    was accepted.  A lone non-empty dict is now wrapped into a
    one-element list.
    """
    parsed = safe_parse_json(raw_text)
    # Tolerate a single object reply (empty dict is the parse-failure
    # sentinel from safe_parse_json, so leave it alone).
    if isinstance(parsed, dict) and parsed:
        parsed = [parsed]
    result = {"points": []}
    if not isinstance(parsed, list):
        return result
    for item in parsed:
        if not isinstance(item, dict):
            continue
        label = item.get("label", "")
        # Preferred format: explicit 2-D point in 0-1000 coordinates.
        if "point_2d" in item and isinstance(item["point_2d"], (list, tuple)) and len(item["point_2d"]) == 2:
            x, y = item["point_2d"]
            result["points"].append({
                "label": label,
                "x": normalize_1000_coord(x),
                "y": normalize_1000_coord(y),
            })
            continue
        # Fallback: reduce a bounding box to its center point.
        if "bbox_2d" in item and isinstance(item["bbox_2d"], (list, tuple)) and len(item["bbox_2d"]) == 4:
            x1, y1, x2, y2 = item["bbox_2d"]
            cx = (float(x1) + float(x2)) / 2.0
            cy = (float(y1) + float(y2)) / 2.0
            result["points"].append({
                "label": label,
                "x": normalize_1000_coord(cx),
                "y": normalize_1000_coord(cy),
            })
    return result
def annotate_point_image(image: Image.Image, result: dict):
    """Draw the points in ``result`` onto ``image`` and return a PIL image.

    ``result`` is expected in the shape produced by ``parse_pointer_response``:
    ``{"points": [{"label": str, "x": float, "y": float}, ...]}`` with x/y
    as fractions of the image size.  Returns the input image unchanged when
    the arguments have unexpected types or there is nothing to draw.
    """
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image
    image = image.convert("RGB")
    ow, oh = image.size
    if "points" not in result or not result["points"]:
        return image
    pts = []
    valid_points = []
    for p in result["points"]:
        # Skip malformed entries that lack coordinates.
        if "x" not in p or "y" not in p:
            continue
        # Scale fractional coordinates to pixels and clamp inside the image.
        px = int(float(p["x"]) * ow)
        py = int(float(p["y"]) * oh)
        px = max(0, min(ow - 1, px))
        py = max(0, min(oh - 1, py))
        pts.append([px, py])
        valid_points.append(p)
    if not pts:
        return image
    # reshape(1, -1, 2): supervision KeyPoints expects (objects, points, 2).
    kp = sv.KeyPoints(xy=np.array(pts, dtype=np.int32).reshape(1, -1, 2))
    scene = np.array(image.copy())
    # Two passes draw an outlined marker: a larger dark disc first, then a
    # smaller bright one centered on top of it.
    scene = sv.VertexAnnotator(radius=10, color=DARK_OUTLINE).annotate(
        scene=scene, key_points=kp
    )
    scene = sv.VertexAnnotator(radius=6, color=BRIGHT_YELLOW).annotate(
        scene=scene, key_points=kp
    )
    labels = [p.get("label", "") for p in valid_points]
    if any(labels):
        # Build a tiny 4x4 anchor box around each labeled point so the
        # LabelAnnotator can position its text relative to it.
        tb, vl = [], []
        for i, p in enumerate(valid_points):
            if labels[i]:
                cx, cy = pts[i]
                tb.append([cx - 2, cy - 2, cx + 2, cy + 2])
                vl.append(labels[i])
        if tb:
            scene = sv.LabelAnnotator(
                color=BRIGHT_YELLOW,
                text_color=BLACK,
                text_scale=0.5,
                text_thickness=1,
                text_padding=5,
                text_position=sv.Position.TOP_CENTER,
                color_lookup=sv.ColorLookup.INDEX,
            ).annotate(
                scene=scene,
                detections=sv.Detections(xyxy=np.array(tb)),
                labels=vl,
            )
    return Image.fromarray(scene)
def process_image_pointer(image, prompt):
    """Gradio handler: stream the model's pointing reply, then annotate.

    A generator that yields ``(image, text)`` pairs: the untouched original
    image with the partial model output while tokens stream in, and finally
    the annotated image together with the parsed result as pretty JSON.

    Raises:
        gr.Error: if no image was uploaded or the prompt is empty.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a point prompt.")
    original_image = image.convert("RGB")
    # The model sees a downscaled copy; it reports coordinates in a
    # normalized 0-1000 space, so points still map onto the full-size image.
    model_image = original_image.copy()
    model_image.thumbnail((512, 512))
    # Instruct the model to answer in the JSON shape that
    # parse_pointer_response expects.
    full_prompt = (
        f"Provide 2d point coordinates for {prompt}. "
        f"Report in JSON format as a list like "
        f'[{{"point_2d": [x, y], "label": "{prompt}"}}]. '
        f"If you cannot provide points, do not return bounding boxes."
    )
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": model_image},
            {"type": "text", "text": full_prompt},
        ],
    }]
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = qwen_processor(
        text=[text],
        images=[model_image],
        return_tensors="pt",
        padding=True,
    ).to(qwen_model.device)
    # Streams decoded text as it is generated; timeout guards the consumer
    # loop below against a stalled generation thread.
    streamer = TextIteratorStreamer(
        qwen_processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=120,
    )
    # generate() blocks, so it runs in a worker thread while this generator
    # consumes the streamer.
    # NOTE(review): temperature/min_p are passed without do_sample=True;
    # transformers ignores sampling parameters under the default greedy
    # decoding — confirm whether sampling was intended here.
    thread = Thread(
        target=qwen_model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            use_cache=True,
            temperature=1.5,
            min_p=0.1,
        ),
    )
    thread.start()
    full_text = ""
    # Stream partial text to the UI alongside the untouched original image.
    for tok in streamer:
        full_text += tok
        yield original_image, full_text
    thread.join()
    # Final update: points drawn on the image plus pretty-printed JSON.
    result = parse_pointer_response(full_text)
    annotated = annotate_point_image(original_image.copy(), result)
    yield annotated, json.dumps(result, indent=2)
# --- Gradio UI -------------------------------------------------------------
# Left column: inputs (image, prompt, run button); right column: outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Image Pointer")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Point Prompt",
                placeholder="e.g. left eye, car headlight, person hand"
            )
            button = gr.Button("Run Pointing")
        with gr.Column():
            output_image = gr.Image(label="Pointed Image")
            output_text = gr.Textbox(label="Point Output", lines=12)
    # process_image_pointer is a generator, so the outputs update live as
    # it yields partial results.
    button.click(
        fn=process_image_pointer,
        inputs=[image_input, prompt_input],
        outputs=[output_image, output_text],
    )

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)
Xet Storage Details
- Size:
- 7.39 kB
- Xet hash:
- 7939b76361f51bea4e6c96acd037978fd30bff265ae970070ee334d8202c09e6
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.