magiv3 / colab.py
Mateusz Mróz
Refactor Florence2Processor and utils.py for improved readability and maintainability; add colab setup script with necessary patches and dependencies for model execution.
bfabb11
# ==============================================================================
# 1) INSTALACJA PAKIET脫W
# ==============================================================================
from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
from IPython.display import display
from io import BytesIO
from PIL import Image, ImageDraw
import math
import json
import torch
import requests
import re
!pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning"
# ==============================================================================
# 2) IMPORTY
# ==============================================================================
# ==============================================================================
# 3) POBRANIE OBRAZU
# ==============================================================================
# def download_imgbb_image(page_url):
# print(f"Pobieranie obrazu ze strony: {page_url}")
# html = requests.get(page_url).text
# img_url = re.search(r'https://i\.ibb\.co/[A-Za-z0-9/_\-]+\.(?:png|jpg|jpeg|webp)', html).group(0)
# print(f"Znaleziono bezpo艣redni link: {img_url}")
# img_bytes = requests.get(img_url).content
# return Image.open(BytesIO(img_bytes)).convert("RGB")
# page_url = "https://ibb.co/cchLK038"
# pil_img = download_imgbb_image(page_url)
# print("Obraz zosta艂 pomy艣lnie pobrany.")
pil_img = Image.open("./1.png").convert("RGB")
# ==============================================================================
# 4) LOAD MODEL AND PROCESSOR
# ==============================================================================
# Use the GPU when available; half precision only makes sense on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# FIX: repaired mojibake (UTF-8 text mis-decoded as CJK) in the user-facing
# status strings below.
print(f"\nUżywane urządzenie: {device}, typ danych: {dtype}")
model_id = "MattyMroz/magiv3"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
# Force eager attention — presumably the remote-code model does not support
# SDPA; see also the _supports_sdpa fallback below. TODO confirm upstream.
config._attn_implementation = "eager"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
    torch_dtype=dtype
).to(device).eval()
# Some transformers versions probe this flag on custom models; default it off
# so attention-backend selection never crashes.
if not hasattr(model, "_supports_sdpa"):
    model._supports_sdpa = False
print("Model i procesor załadowane pomyślnie.")
# ==============================================================================
# 5) ZAAWANSOWANA WIZUALIZACJA I PRZETWARZANIE
# ==============================================================================
def create_visualization(image, data, detailed_mode=False):
    """Render detections and associations from *data* onto a copy of *image*.

    Args:
        image: Input PIL image (left untouched; drawing happens on a copy).
        data: Result dict with "detections", "associations" and, optionally,
            "ocr" and "grounding" entries.
        detailed_mode: When True, additionally draw OCR and grounding boxes;
            when False (default), draw only detections and associations.

    Returns:
        A new PIL image with the overlays drawn.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    # Palette and stroke widths per overlay category.
    colors = {
        "panels": "green",
        "texts": "red",
        "characters": "blue",
        "tails": "purple",
        "cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
        "speaker_line": "magenta",
        "ocr": "orange",
        "grounding": "cyan",
    }
    line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1, "ocr": 2, "grounding": 2}

    def _center(box):
        # Midpoint of an (x1, y1, x2, y2) box.
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2, (y1 + y2) / 2

    def _dashed_line(pen_obj, p1, p2, fill, width, dash_len=10):
        # Draw every other dash_len-long segment along p1 -> p2.
        x1, y1 = p1
        x2, y2 = p2
        dx, dy = x2 - x1, y2 - y1
        dist = math.sqrt(dx**2 + dy**2)
        if dist == 0:
            return
        for i in range(0, int(dist / dash_len), 2):
            start = (x1 + (dx * i * dash_len) / dist,
                     y1 + (dy * i * dash_len) / dist)
            end = (x1 + (dx * (i + 1) * dash_len) / dist,
                   y1 + (dy * (i + 1) * dash_len) / dist)
            pen_obj.line([start, end], fill=fill, width=width)

    # Bounding boxes for every detected category we know a color for.
    for category, bboxes in data.get("detections", {}).items():
        if category not in colors:
            continue
        stroke = line_widths.get(category, 1)
        for box in bboxes:
            pen.rectangle(box, outline=colors[category], width=stroke)

    # Character clusters: connect the centers of same-labelled character boxes.
    clusters = data.get("associations", {}).get("character_cluster_labels", [])
    characters = data.get("detections", {}).get("characters", [])
    if clusters and characters:
        palette = colors["cluster_colors"]
        for i, label in enumerate(sorted(set(clusters))):
            color = palette[i % len(palette)]
            members = [j for j, lab in enumerate(clusters) if lab == label]
            for a, b in zip(members, members[1:]):
                pen.line([_center(characters[a]), _center(characters[b])],
                         fill=color, width=2)

    # Speaker lines: dashed link from each text box to its associated character.
    texts = data.get("detections", {}).get("texts", [])
    links = data.get("associations", {}).get("text_character_associations", [])
    if links and texts and characters:
        for text_idx, char_idx in links:
            if text_idx < len(texts) and char_idx < len(characters):
                _dashed_line(pen, _center(texts[text_idx]),
                             _center(characters[char_idx]),
                             fill=colors["speaker_line"], width=1)

    # Detailed mode: additionally draw OCR and grounding boxes from the JSON.
    if detailed_mode:
        for ocr_item in data.get("ocr", []):
            box = ocr_item.get("box")
            if box:
                pen.rectangle(box, outline=colors["ocr"], width=line_widths["ocr"])
        for grounding_item in data.get("grounding", []):
            for box in grounding_item.get("boxes", []):
                pen.rectangle(box, outline=colors["grounding"],
                              width=line_widths["grounding"])
    return canvas
def process_image(image, caption_for_grounding="elf girl", detailed_mode=False):
    """Run OCR, detection/association and character grounding, then visualize.

    Args:
        image: Input PIL image.
        caption_for_grounding: Caption passed to the character-grounding step.
        detailed_mode: Forwarded to the visualization — True draws everything
            from the JSON (OCR, grounding); False (default) only detections
            and associations.

    Returns:
        Tuple of (result JSON dict, visualization PIL image).
    """
    print("\n--- Rozpoczynanie przetwarzania obrazu ---")
    batch = [image]

    print("1/3: Uruchamianie OCR...")
    ocr = model.predict_ocr(batch, processor)[0]

    print("2/3: Uruchamianie detekcji i asocjacji...")
    detections = model.predict_detections_and_associations(batch, processor)[0]

    print("3/3: Uruchamianie 'Character Grounding'...")
    grounding = model.predict_character_grounding(
        batch, [caption_for_grounding], processor)[0]

    caption_text = grounding.get("grounded_caption", "")
    final_json = {
        "ocr": [
            {"text": text, "box": box}
            for text, box in zip(ocr.get("ocr_texts", []), ocr.get("bboxes", []))
        ],
        "detections": {
            key: detections.get(key, [])
            for key in ("panels", "texts", "characters", "tails")
        },
        "associations": {
            key: detections.get(key, [])
            for key in (
                "character_cluster_labels",
                "text_character_associations",
                "text_tail_associations",
                "is_essential_text",
            )
        },
        "grounding": [
            {"phrase": caption_text[start:end], "boxes": boxes}
            for boxes, (start, end) in zip(
                grounding.get("bboxes", []),
                grounding.get("indices_of_bboxes_in_caption", []),
            )
        ],
    }

    mode_text = "wybredny (wszystkie elementy)" if detailed_mode else "domy艣lny (detekcje i asocjacje)"
    print(f"Tworzenie wizualizacji w trybie: {mode_text}")
    visualization = create_visualization(image, final_json, detailed_mode=detailed_mode)
    print("--- Zako艅czono przetwarzanie ---")
    return final_json, visualization
# ==============================================================================
# 6) RUN AND DISPLAY RESULTS
# ==============================================================================
# Visualization mode:
# detailed_mode=False (default)  - draws only detections and associations
# detailed_mode=True  (detailed) - draws everything from the JSON: OCR (orange),
#                                  grounding (cyan)
json_output, image_output = process_image(
    pil_img, caption_for_grounding="elf girl", detailed_mode=True)
print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
print(json.dumps(json_output, indent=2))
print("\n\n===== WIZUALIZACJA (przed filtrowaniem) =====")
display(image_output)