# app.py — Manga Whisperer: Streamlit demo for the MAGI model (CPU-only).
import streamlit as st
import torch
import numpy as np
import urllib.request
from PIL import Image
from transformers import AutoModel
# ===============================
# GLOBAL CONFIGURATION (CPU ONLY)
# ===============================
# Keep intra-op parallelism low so the app behaves on small CPU-only hosts.
torch.set_num_threads(2)
# Inference-only app: autograd bookkeeping is never needed.
torch.set_grad_enabled(False)

st.set_page_config(page_title="Manga Whisperer", layout="wide")
# ===============================
# MODEL LOADING (CACHED)
# ===============================
@st.cache_resource
def load_model():
    """Download and cache the MAGI model once per server process.

    ``trust_remote_code`` is required because MAGI ships its own modeling
    code on the Hub. The model is switched to eval mode since this app
    only ever runs inference.
    """
    magi = AutoModel.from_pretrained(
        "ragavsachdeva/magi",
        trust_remote_code=True,
    )
    magi.eval()
    return magi


model = load_model()
# ===============================
# IMAGE LOADER
# ===============================
@st.cache_data(show_spinner=False)
def read_image_as_np_array(image_input):
    """Load an image from a URL or file-like object as an RGB numpy array.

    The image is converted to grayscale ("L") and then back to RGB: this
    drops color information while keeping the 3-channel layout the model
    expects.

    Args:
        image_input: Either an http(s) URL string or anything
            ``PIL.Image.open`` accepts (path, file object, UploadedFile).

    Returns:
        ``np.ndarray`` with shape (H, W, 3).
    """
    if isinstance(image_input, str) and image_input.startswith("http"):
        # Fix: close the HTTP response explicitly (the original leaked the
        # socket). ``.convert()`` forces PIL to read all pixel data before
        # the response object goes away.
        with urllib.request.urlopen(image_input) as response:
            image = Image.open(response).convert("L").convert("RGB")
    else:
        image = Image.open(image_input).convert("L").convert("RGB")
    return np.array(image)
# ===============================
# DETECTION STAGE
# ===============================
@st.cache_data(show_spinner=True)
def run_detection(image_input, params):
    """Run MAGI detection + association on a single image.

    Args:
        image_input: URL string or uploaded file, forwarded to
            ``read_image_as_np_array``.
        params: Keyword arguments (thresholds) forwarded to
            ``model.predict_detections_and_associations``.

    Returns:
        Tuple ``(image, result)`` — the RGB numpy image and the model's
        per-image prediction dict.
    """
    image = read_image_as_np_array(image_input)
    with torch.inference_mode():
        batch_results = model.predict_detections_and_associations(
            [image], **params
        )
    return image, batch_results[0]
# ===============================
# OCR STAGE
# ===============================
@st.cache_data(show_spinner=True)
def run_ocr(image, detection_result):
    """Run OCR on the text regions found by the detection stage.

    Returns the per-image OCR result, or ``None`` when detection found
    no text boxes at all.
    """
    text_boxes = detection_result["texts"]
    if not text_boxes:
        return None
    with torch.inference_mode():
        ocr_results = model.predict_ocr([image], [text_boxes])
    return ocr_results[0]
# ===============================
# TRANSCRIPT STAGE
# ===============================
def generate_transcript(detection_result, ocr_result):
    """Build a readable transcript from detection + OCR output.

    Falls back to a fixed notice ("no text detected", in Portuguese)
    when OCR produced nothing.
    """
    if ocr_result is None:
        return "Nenhum texto detectado."
    return model.generate_transcript_for_single_image(
        detection_result, ocr_result
    )
def generate_structural_dialogue(detection_result, threshold=0.4):
    """Summarize text-to-character associations without running OCR.

    For each detected text box, pick the character with the highest
    matching score. Boxes with no score row, an empty score row, or a
    best score below ``threshold`` are labelled narration / uncertain.

    Args:
        detection_result: MAGI prediction dict; reads "texts" and
            "text_character_matching_scores".
        threshold: Minimum matching score to attribute a text box to a
            character.

    Returns:
        One line per text box, joined with newlines ("" when there are
        no text boxes).
    """
    texts = detection_result.get("texts", [])
    scores = detection_result.get("text_character_matching_scores", [])

    # Fix: the original duplicated the fallback line in two branches and
    # read an unused "characters" entry; both removed, behavior unchanged.
    dialogue_lines = []
    for text_id in range(len(texts)):
        char_scores = scores[text_id] if text_id < len(scores) else []
        line = f"Text {text_id} → Narration / Uncertain"
        if char_scores:
            # First index of the maximum score (ties resolve like max()).
            best_char = max(range(len(char_scores)), key=char_scores.__getitem__)
            best_score = char_scores[best_char]
            if best_score >= threshold:
                line = (
                    f"Text {text_id} → "
                    f"Character {best_char} "
                    f"(score: {best_score:.2f})"
                )
        dialogue_lines.append(line)
    return "\n".join(dialogue_lines)
# ===============================
# UI HEADER
# ===============================
# Page header rendered as raw HTML; styling kept inline for simplicity.
_HEADER_HTML = """
<style>
.title {
font-size: 2.2em;
text-align: center;
color: #ffffff;
font-family: 'Comic Sans MS', cursive;
margin-bottom: 0.2em;
}
.subtitle {
font-size: 1.2em;
text-align: center;
color: #cccccc;
margin-bottom: 1em;
}
</style>
<div class="title">Manga Whisperer</div>
<div class="subtitle">Automatic Comic Transcription (CPU Optimized)</div>
"""

st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ===============================
# SIDEBAR CONTROLS
# ===============================
st.sidebar.markdown("### Mode")
generate_detections = st.sidebar.toggle("Generate detections", True)
generate_transcript_toggle = st.sidebar.toggle("Generate transcript (slow)", False)

st.sidebar.markdown("### Thresholds")
# Keys must match the keyword arguments accepted by
# model.predict_detections_and_associations (see run_detection).
params = {
    "character_detection_threshold": st.sidebar.slider(
        "Character detection", 0.0, 1.0, 0.30, 0.01
    ),
    "panel_detection_threshold": st.sidebar.slider(
        "Panel detection", 0.0, 1.0, 0.20, 0.01
    ),
    "text_detection_threshold": st.sidebar.slider(
        "Text detection", 0.0, 1.0, 0.25, 0.01
    ),
    "character_character_matching_threshold": st.sidebar.slider(
        "Character-character matching", 0.0, 1.0, 0.70, 0.01
    ),
    "text_character_matching_threshold": st.sidebar.slider(
        "Text-character matching", 0.0, 1.0, 0.40, 0.01
    ),
}
# ===============================
# INPUT IMAGE
# ===============================
image_input = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
# ===============================
# MAIN PIPELINE
# ===============================
if image_input is not None:
    st.markdown("### Prediction")

    # 1) Detection + association (cached per image/params combination).
    image, detection_result = run_detection(image_input, params)

    # 2) Visualization of detected boxes and associations.
    if generate_detections:
        vis = model.visualise_single_image_prediction(image, detection_result)
        st.image(vis, caption="Detections")

    # 3) Structural dialogue (no OCR): attribute each text box to a
    # character. Fix: wire the sidebar's "Text-character matching" slider
    # through — previously it was collected into params but the function's
    # hard-coded 0.4 default was always used, so the slider had no effect
    # here. (Slider default 0.40 keeps the out-of-the-box behavior identical.)
    if generate_transcript_toggle:
        structural_dialogue = generate_structural_dialogue(
            detection_result,
            threshold=params["text_character_matching_threshold"],
        )
        st.text_area(
            "Structural Dialogue (MAGI output)",
            structural_dialogue,
            height=300,
        )