Mateusz Mróz
Refactor Florence2Processor and utils.py for improved readability and maintainability; add colab setup script with necessary patches and dependencies for model execution.
bfabb11
| # ============================================================================== | |
| # 1) INSTALACJA PAKIET脫W | |
| # ============================================================================== | |
| from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig | |
| from IPython.display import display | |
| from io import BytesIO | |
| from PIL import Image, ImageDraw | |
| import math | |
| import json | |
| import torch | |
| import requests | |
| import re | |
| !pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning" | |
| # ============================================================================== | |
| # 2) IMPORTY | |
| # ============================================================================== | |
# ==============================================================================
# 3) IMAGE DOWNLOAD
# ==============================================================================
# Alternative input path: scrape the direct image URL from an imgbb page and
# download it over HTTP (kept here disabled for reference).
# def download_imgbb_image(page_url):
#     print(f"Pobieranie obrazu ze strony: {page_url}")
#     html = requests.get(page_url).text
#     img_url = re.search(r'https://i\.ibb\.co/[A-Za-z0-9/_\-]+\.(?:png|jpg|jpeg|webp)', html).group(0)
#     print(f"Znaleziono bezpo艣redni link: {img_url}")
#     img_bytes = requests.get(img_url).content
#     return Image.open(BytesIO(img_bytes)).convert("RGB")
# page_url = "https://ibb.co/cchLK038"
# pil_img = download_imgbb_image(page_url)
# print("Obraz zosta艂 pomy艣lnie pobrany.")

# Active input path: read the test image from local disk, normalized to RGB.
_source_image = Image.open("./1.png")
pil_img = _source_image.convert("RGB")
# ==============================================================================
# 4) LOAD MODEL AND PROCESSOR
# ==============================================================================
# Prefer GPU with half precision when available; otherwise CPU with float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
print(f"\nU偶ywane urz膮dzenie: {device}, typ danych: {dtype}")

model_id = "MattyMroz/magiv3"

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# NOTE(review): attention is forced to "eager" — presumably the remote-code
# model does not support SDPA; confirm against the model repository.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config._attn_implementation = "eager"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
    torch_dtype=dtype,
)
model = model.to(device).eval()

# Some transformers versions probe this private flag; make sure it exists.
if not hasattr(model, "_supports_sdpa"):
    model._supports_sdpa = False

print("Model i procesor za艂adowane pomy艣lnie.")
# ==============================================================================
# 5) ADVANCED VISUALIZATION AND PROCESSING
# ==============================================================================
def create_visualization(image, data, detailed_mode=False):
    """Draw detections and associations from *data* onto a copy of *image*.

    Args:
        image: Input PIL image. It is not modified; drawing happens on a copy.
        data: Result dict with "detections" and "associations" entries and,
            optionally, "ocr" and "grounding" entries (see ``process_image``).
        detailed_mode: If True, additionally draw OCR (orange) and grounding
            (cyan) boxes. If False (default), draw only detections and
            character/speaker associations.

    Returns:
        A new PIL image with the overlays drawn.
    """
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Color palette and line widths per element category.
    colors = {
        "panels": "green",
        "texts": "red",
        "characters": "blue",
        "tails": "purple",
        "cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
        "speaker_line": "magenta",
        "ocr": "orange",
        "grounding": "cyan",
    }
    line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1, "ocr": 2, "grounding": 2}

    def get_box_center(box):
        # Midpoint of an (x1, y1, x2, y2) box.
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2, (y1 + y2) / 2

    def draw_dashed_line(draw_obj, p1, p2, fill, width, dash_len=10):
        # Approximate a dashed line with alternating dash_len-long strokes.
        x1, y1 = p1
        x2, y2 = p2
        dx, dy = x2 - x1, y2 - y1
        dist = math.hypot(dx, dy)
        if dist == 0:
            return
        for i in range(0, int(dist / dash_len), 2):
            start = (x1 + (dx * i * dash_len) / dist,
                     y1 + (dy * i * dash_len) / dist)
            end = (x1 + (dx * (i + 1) * dash_len) / dist,
                   y1 + (dy * (i + 1) * dash_len) / dist)
            draw_obj.line([start, end], fill=fill, width=width)

    # Bounding boxes for every known detection category.
    for category, bboxes in data.get("detections", {}).items():
        if category in colors:
            for box in bboxes:
                draw.rectangle(
                    box, outline=colors[category], width=line_widths.get(category, 1))

    # Character clusters: connect centers of character boxes sharing a label.
    clusters = data.get("associations", {}).get("character_cluster_labels", [])
    characters = data.get("detections", {}).get("characters", [])
    if clusters and characters:
        for i, label in enumerate(sorted(set(clusters))):
            color = colors["cluster_colors"][i % len(colors["cluster_colors"])]
            indices = [j for j, lbl in enumerate(clusters) if lbl == label]
            # Consecutive pairs; a single-member cluster yields no segments.
            for a, b in zip(indices, indices[1:]):
                draw.line(
                    [get_box_center(characters[a]), get_box_center(characters[b])],
                    fill=color, width=2)

    # Speaker lines: dashed text-box -> character-box connections.
    texts = data.get("detections", {}).get("texts", [])
    speaker_associations = data.get("associations", {}).get(
        "text_character_associations", [])
    if speaker_associations and texts and characters:
        for text_idx, char_idx in speaker_associations:
            # Skip out-of-range indices defensively.
            if text_idx < len(texts) and char_idx < len(characters):
                p1 = get_box_center(texts[text_idx])
                p2 = get_box_center(characters[char_idx])
                draw_dashed_line(
                    draw, p1, p2, fill=colors["speaker_line"], width=1)

    # Detailed mode: also draw the raw OCR and grounding boxes from the JSON.
    if detailed_mode:
        for ocr_item in data.get("ocr", []):
            box = ocr_item.get("box")
            if box:
                draw.rectangle(box, outline=colors["ocr"], width=line_widths["ocr"])
        for grounding_item in data.get("grounding", []):
            for box in grounding_item.get("boxes", []):
                draw.rectangle(box, outline=colors["grounding"], width=line_widths["grounding"])

    return img_draw
def process_image(image, caption_for_grounding="elf girl", detailed_mode=False):
    """Run OCR, detection/association and character grounding on *image*.

    Args:
        image: Input PIL image.
        caption_for_grounding: Caption passed to character grounding.
        detailed_mode: Forwarded to ``create_visualization``; when True the
            visualization also shows OCR and grounding boxes.

    Returns:
        Tuple of (result dict, visualization PIL image).
    """
    print("\n--- Rozpoczynanie przetwarzania obrazu ---")
    batch = [image]
    batch_captions = [caption_for_grounding]

    print("1/3: Uruchamianie OCR...")
    ocr = model.predict_ocr(batch, processor)[0]

    print("2/3: Uruchamianie detekcji i asocjacji...")
    det = model.predict_detections_and_associations(batch, processor)[0]

    print("3/3: Uruchamianie 'Character Grounding'...")
    grounding = model.predict_character_grounding(batch, batch_captions, processor)[0]

    # Assemble one JSON-serializable dict out of the three model outputs.
    grounded_caption = grounding.get("grounded_caption", "")
    final_json = {
        "ocr": [
            {"text": text, "box": box}
            for text, box in zip(ocr.get("ocr_texts", []), ocr.get("bboxes", []))
        ],
        "detections": {
            "panels": det.get("panels", []),
            "texts": det.get("texts", []),
            "characters": det.get("characters", []),
            "tails": det.get("tails", []),
        },
        "associations": {
            "character_cluster_labels": det.get("character_cluster_labels", []),
            "text_character_associations": det.get("text_character_associations", []),
            "text_tail_associations": det.get("text_tail_associations", []),
            "is_essential_text": det.get("is_essential_text", []),
        },
        "grounding": [
            {"phrase": grounded_caption[start:end], "boxes": boxes}
            for boxes, (start, end) in zip(
                grounding.get("bboxes", []),
                grounding.get("indices_of_bboxes_in_caption", []),
            )
        ],
    }

    if detailed_mode:
        mode_text = "wybredny (wszystkie elementy)"
    else:
        mode_text = "domy艣lny (detekcje i asocjacje)"
    print(f"Tworzenie wizualizacji w trybie: {mode_text}")
    visualization_image = create_visualization(image, final_json, detailed_mode=detailed_mode)
    print("--- Zako艅czono przetwarzanie ---")
    return final_json, visualization_image
# ==============================================================================
# 6) RUN AND DISPLAY RESULTS
# ==============================================================================
# Visualization modes:
#   detailed_mode=False (default)  - detections and associations only
#   detailed_mode=True  (detailed) - everything from the JSON: OCR (orange),
#                                    grounding (cyan)
json_output, image_output = process_image(
    pil_img,
    caption_for_grounding="elf girl",
    detailed_mode=True,
)

print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
print(json.dumps(json_output, indent=2))

print("\n\n===== WIZUALIZACJA (przed filtrowaniem) =====")
display(image_output)