File size: 5,858 Bytes
2e07297
fad1dff
 
 
 
 
 
a345f45
2e07297
 
fad1dff
 
 
 
 
6e8c2a0
7de5a1a
fad1dff
8558925
 
 
 
 
 
 
 
 
 
 
 
fad1dff
7de5a1a
22cde00
fad1dff
 
9477b8f
2e07297
 
fad1dff
 
 
 
 
6e8c2a0
9477b8f
fad1dff
3a8d636
8558925
 
3a8d636
8558925
3a8d636
fad1dff
 
2e07297
fad1dff
9477b8f
fad1dff
 
 
 
 
 
a345f45
8558925
 
 
 
 
 
 
 
 
 
fad1dff
 
8558925
fad1dff
 
3a8d636
fad1dff
a345f45
1586730
de827c2
 
 
 
fad1dff
 
 
 
8558925
de827c2
fad1dff
 
 
de827c2
fad1dff
de827c2
 
 
 
 
 
 
 
a345f45
de827c2
 
fad1dff
 
de827c2
a345f45
 
 
 
fad1dff
4b7419e
3a8d636
9477b8f
fad1dff
bf74a50
1586730
8558925
1586730
 
9477b8f
fad1dff
de827c2
1586730
 
a345f45
1586730
8558925
1586730
 
 
 
de827c2
 
 
 
 
 
 
 
 
 
 
 
 
6e8c2a0
fad1dff
2e07297
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import io
import streamlit as st
import torch
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
from PIL import Image

# Ускоряем скачивание
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

st.set_page_config(
    page_title="LightOnOCR • Распознай текст",
    page_icon="📄",
    layout="centered",
    initial_sidebar_state="expanded"
)

st.markdown("""
    <style>
        .main { background: linear-gradient(180deg, #f8f9fa, #e9f0f7); }
        .header-emoji { font-size: 3.5rem; text-align: center; margin: 15px 0; }
        .result-box {
            background: #ffffff;
            border-radius: 16px;
            padding: 24px;
            box-shadow: 0 10px 30px rgba(0, 0, 0, 0.08);
            border: 1px solid #e5e7eb;
            margin-top: 20px;
        }
    </style>
""", unsafe_allow_html=True)

@st.cache_resource(show_spinner="⏳ Загрузка модели LightOnOCR-1B-1025... (2–6 минут при первом запуске)")
def load_model():
    model_name = "lightonai/LightOnOCR-1B-1025"
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    model = LightOnOcrForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)

    processor = LightOnOcrProcessor.from_pretrained(model_name)

    if processor.tokenizer.pad_token is None:
        processor.tokenizer.pad_token = processor.tokenizer.eos_token

    return processor, model, device, dtype

def load_image():
    uploaded_file = st.file_uploader(
        "📸 Загрузите изображение (png, jpg, jpeg, webp)",
        type=['png', 'jpg', 'jpeg', 'webp']
    )
    if uploaded_file is not None:
        image_data = uploaded_file.getvalue()
        st.image(image_data, use_container_width=True, caption="Загруженное изображение")
        return Image.open(io.BytesIO(image_data)).convert('RGB')
    return None

# ==================== Интерфейс ====================
st.markdown('<div class="header-emoji">📄✨</div>', unsafe_allow_html=True)
st.title("LightOnOCR")
st.markdown("**Распознавание текста с изображений**")
st.caption("Модель: lightonai/LightOnOCR-1B-1025")

processor, model, device, dtype = load_model()

with st.sidebar:
    st.success(f"✅ Модель загружена на **{device.upper()}**")

img = load_image()

# ==================== Распознавание ====================
if st.button("🔍 Распознать текст", use_container_width=True, type="primary"):
    if img is None:
        st.error("Сначала загрузите изображение")
    else:
        with st.spinner("Распознавание текста... (5–30 сек на CPU)"):
            
            # Промпт для модели
            prompt = "Extract all the text from this image as accurately as possible. Preserve line breaks, formatting and tables."

            # 1. Получаем текстовый шаблон чата (без токенизации)
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
            text_prompt = processor.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True
            )

            # 2. Правильный вызов процессора (ключевой момент!)
            inputs = processor(
                text=[text_prompt],
                images=[[img]],                    # двойной список — обязательно!
                return_tensors="pt",
                padding=True,
                size={"longest_edge": 1540}        # рекомендуемый размер модели
            )

            # Переносим на устройство
            inputs = {
                k: (v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device))
                for k, v in inputs.items()
            }

            # Генерация
            output_ids = model.generate(
                **inputs,
                max_new_tokens=2048,
                do_sample=False,
                temperature=0.0,
                num_beams=1,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

            # Убираем промпт
            prompt_length = inputs["input_ids"].shape[1]
            generated_ids = output_ids[0, prompt_length:]
            
            generated_text = processor.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            ).strip()

            # Результат
            st.success("✅ Распознавание завершено!")
            st.markdown('<div class="result-box">', unsafe_allow_html=True)
            st.subheader("📝 Распознанный текст")
            st.code(generated_text, language=None)
            st.markdown('</div>', unsafe_allow_html=True)

            st.download_button(
                label="💾 Скачать как .txt",
                data=generated_text,
                file_name="recognized_text.txt",
                mime="text/plain"
            )

st.markdown("---")
st.caption("Сделано на базе [lightonai/LightOnOCR-1B-1025](https://huggingface.co/lightonai/LightOnOCR-1B-1025)")