tddf commited on
Commit
2e07297
·
verified ·
1 Parent(s): f79d9bc

Update Main.py

Browse files
Files changed (1) hide show
  1. Main.py +52 -79
Main.py CHANGED
@@ -1,10 +1,13 @@
 
1
  import io
2
  import streamlit as st
3
  import torch
4
  from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
5
  from PIL import Image
6
 
7
- # ==================== Настройки страницы ====================
 
 
8
  st.set_page_config(
9
  page_title="LightOnOCR • Распознай текст",
10
  page_icon="📄",
@@ -12,66 +15,53 @@ st.set_page_config(
12
  initial_sidebar_state="expanded"
13
  )
14
 
15
- # ==================== Кастомный CSS ====================
16
  st.markdown("""
17
- <style>
18
- .main { background: linear-gradient(180deg, #f8f9fa, #e9f0f7); }
19
- h1 { color: #1e3a8a; text-align: center; margin-bottom: 0.2rem; }
20
- .stButton > button {
21
- background: linear-gradient(90deg, #3b82f6, #1e40af);
22
- color: white;
23
- border-radius: 12px;
24
- padding: 12px 32px;
25
- font-weight: 600;
26
- border: none;
27
- box-shadow: 0 4px 15px rgba(59, 130, 246, 0.3);
28
- }
29
- .stButton > button:hover {
30
- transform: translateY(-2px);
31
- box-shadow: 0 8px 20px rgba(59, 130, 246, 0.4);
32
- }
33
- .result-box {
34
- background: #ffffff;
35
- border-radius: 16px;
36
- padding: 24px;
37
- box-shadow: 0 10px 30px rgba(0, 0, 0, 0.08);
38
- border: 1px solid #e5e7eb;
39
- margin-top: 20px;
40
- }
41
- .header-emoji { font-size: 3.5rem; text-align: center; margin: 10px 0; }
42
- </style>
43
  """, unsafe_allow_html=True)
44
 
45
- # ==================== Загрузка модели ====================
46
- @st.cache_resource(show_spinner="Загрузка модели LightOnOCR-1B-1025...")
47
  def load_model():
48
  model_name = "lightonai/LightOnOCR-1B-1025"
49
 
50
- if torch.backends.mps.is_available():
51
- device = "mps"
52
- dtype = torch.float32
53
- elif torch.cuda.is_available():
54
- device = "cuda"
55
- dtype = torch.bfloat16
56
- else:
57
- device = "cpu"
58
- dtype = torch.float32
59
 
60
  model = LightOnOcrForConditionalGeneration.from_pretrained(
61
  model_name,
62
  torch_dtype=dtype,
63
  trust_remote_code=True,
64
- device_map=None # загружаем вручную
65
  ).to(device)
66
 
67
  processor = LightOnOcrProcessor.from_pretrained(model_name)
68
-
69
  return processor, model, device, dtype
70
 
71
- # ==================== Загрузка изображения ====================
 
 
 
 
 
 
 
 
 
 
 
72
  def load_image():
73
  uploaded_file = st.file_uploader(
74
- "📸 Загрузите изображение (фото, скан, документ)",
75
  type=['png', 'jpg', 'jpeg', 'webp']
76
  )
77
  if uploaded_file is not None:
@@ -80,45 +70,26 @@ def load_image():
80
  return Image.open(io.BytesIO(image_data)).convert('RGB')
81
  return None
82
 
83
- # ==================== Основной интерфейс ====================
84
- st.markdown('<div class="header-emoji">📄✨</div>', unsafe_allow_html=True)
85
- st.title("LightOnOCR")
86
- st.markdown("**Мгновенное распознавание текста на английском и других языках**")
87
- st.caption("Модель LightOnOCR-1B-1025 • Отлично работает с документами, чеками, таблицами и фото")
88
-
89
- # Загружаем модель один раз
90
- processor, model, device, dtype = load_model()
91
-
92
- # Сайдбар
93
- with st.sidebar:
94
- st.markdown("### 🚀 О модели")
95
- st.info("LightOnOCR-1B-1025 — компактная end-to-end модель для OCR и понимания документов.")
96
- st.markdown("**Поддержка:** Английский + латиница, таблицы, сложная вёрстка")
97
- st.caption(f"Устройство: **{device.upper()}** • dtype: **{dtype}**")
98
-
99
- # Загрузка изображения
100
  img = load_image()
101
 
102
- # Кнопка распознавания
103
  if st.button("🔍 Распознать текст", use_container_width=True, type="primary"):
104
  if img is None:
105
- st.error("Пожалуйста, сначала загрузите изображение")
106
  else:
107
- with st.spinner("Распознавание текста… (на CPU может занять 10–30 секунд)"):
108
- # Правильный способ работы с этой моделью (chat template)
109
  conversation = [
110
  {
111
  "role": "user",
112
  "content": [
113
- {"type": "image"},
114
- {"type": "te
115
-
116
-
117
- xt", "text": "Extract all the text from this image accurately. Preserve formatting, tables, and line breaks as much as possible."}
118
  ]
119
  }
120
  ]
121
 
 
122
  inputs = processor.apply_chat_template(
123
  conversation,
124
  add_generation_prompt=True,
@@ -127,9 +98,9 @@ xt", "text": "Extract all the text from this image accurately. Preserve formatti
127
  return_tensors="pt"
128
  )
129
 
130
- # Переносим на устройство
131
  inputs = {
132
- k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device)
133
  for k, v in inputs.items()
134
  }
135
 
@@ -141,24 +112,26 @@ xt", "text": "Extract all the text from this image accurately. Preserve formatti
141
  temperature=0.0
142
  )
143
 
144
- # Убираем промпт, оставляем только сгенерированный текст
145
  generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
146
- generated_text = processor.decode(generated_ids, skip_special_tokens=True)
 
 
 
147
 
148
- # Вывод результата
149
- st.success("✅ Распознавание завершено!")
150
  st.markdown('<div class="result-box">', unsafe_allow_html=True)
151
  st.subheader("📝 Распознанный текст")
152
- st.markdown(f"```\n{generated_text}\n```")
153
  st.markdown('</div>', unsafe_allow_html=True)
154
 
155
- # Кнопка скачивания
156
  st.download_button(
157
- label="💾 Скачать текст (.txt)",
158
  data=generated_text,
159
  file_name="recognized_text.txt",
160
  mime="text/plain"
161
  )
162
 
163
  st.markdown("---")
164
- st.markdown("**Сделано на базе [lightonai/LightOnOCR-1B-1025](https://huggingface.co/lightonai/LightOnOCR-1B-1025)**")
 
1
+ import os
2
  import io
3
  import streamlit as st
4
  import torch
5
  from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
6
  from PIL import Image
7
 
8
+ # Ускоряем скачивание на HF Spaces
9
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
10
+
11
  st.set_page_config(
12
  page_title="LightOnOCR • Распознай текст",
13
  page_icon="📄",
 
15
  initial_sidebar_state="expanded"
16
  )
17
 
18
+ # Простой CSS
19
  st.markdown("""
20
+ <style>
21
+ .main { background: linear-gradient(180deg, #f8f9fa, #e9f0f7); }
22
+ .result-box {
23
+ background: #ffffff;
24
+ border-radius: 16px;
25
+ padding: 24px;
26
+ box-shadow: 0 10px 30px rgba(0,0,0,0.08);
27
+ margin-top: 20px;
28
+ }
29
+ .header-emoji { font-size: 3.5rem; text-align: center; margin: 15px 0; }
30
+ </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """, unsafe_allow_html=True)
32
 
33
+ @st.cache_resource(show_spinner="⏳ Загрузка модели LightOnOCR-1B-1025...\nЭто может занять 2–6 минут при первом запуске на CPU")
 
34
  def load_model():
35
  model_name = "lightonai/LightOnOCR-1B-1025"
36
 
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 
 
 
 
 
 
39
 
40
  model = LightOnOcrForConditionalGeneration.from_pretrained(
41
  model_name,
42
  torch_dtype=dtype,
43
  trust_remote_code=True,
 
44
  ).to(device)
45
 
46
  processor = LightOnOcrProcessor.from_pretrained(model_name)
47
+
48
  return processor, model, device, dtype
49
 
50
+ # ====================== Заголовок ======================
51
+ st.markdown('<div class="header-emoji">📄✨</div>', unsafe_allow_html=True)
52
+ st.title("LightOnOCR")
53
+ st.markdown("**Распознавание текста с изображений**")
54
+ st.caption("Модель: lightonai/LightOnOCR-1B-1025")
55
+
56
+ # ====================== Загрузка модели ======================
57
+ processor, model, device, dtype = load_model()
58
+
59
+ st.sidebar.success(f"✅ Модель загружена на **{device.upper()}**")
60
+
61
+ # ====================== Загрузка изображения ======================
62
  def load_image():
63
  uploaded_file = st.file_uploader(
64
+ "📸 Загрузите изображение (png, jpg, jpeg, webp)",
65
  type=['png', 'jpg', 'jpeg', 'webp']
66
  )
67
  if uploaded_file is not None:
 
70
  return Image.open(io.BytesIO(image_data)).convert('RGB')
71
  return None
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  img = load_image()
74
 
75
+ # ====================== Распознавание ======================
76
  if st.button("🔍 Распознать текст", use_container_width=True, type="primary"):
77
  if img is None:
78
+ st.error("Сначала загрузите изображение")
79
  else:
80
+ with st.spinner("Распознавание текста..."):
81
+ # Правильный формат для LightOnOCR (по официальному примеру)
82
  conversation = [
83
  {
84
  "role": "user",
85
  "content": [
86
+ {"type": "image"}, # изображение передаётся отдельно
87
+ {"type": "text", "text": "Extract all the text from this image accurately. Preserve original formatting, tables, and line breaks as much as possible."}
 
 
 
88
  ]
89
  }
90
  ]
91
 
92
+ # Подготовка inputs
93
  inputs = processor.apply_chat_template(
94
  conversation,
95
  add_generation_prompt=True,
 
98
  return_tensors="pt"
99
  )
100
 
101
+ # Перенос на устройство
102
  inputs = {
103
+ k: (v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device))
104
  for k, v in inputs.items()
105
  }
106
 
 
112
  temperature=0.0
113
  )
114
 
115
+ # Убираем промпт
116
  generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
117
+ generated_text = processor.decode(generated_ids,
118
+
119
+
120
+ skip_special_tokens=True)
121
 
122
+ # Результат
123
+ st.success("✅ Готово!")
124
  st.markdown('<div class="result-box">', unsafe_allow_html=True)
125
  st.subheader("📝 Распознанный текст")
126
+ st.code(generated_text, language=None)
127
  st.markdown('</div>', unsafe_allow_html=True)
128
 
 
129
  st.download_button(
130
+ "💾 Скачать как .txt",
131
  data=generated_text,
132
  file_name="recognized_text.txt",
133
  mime="text/plain"
134
  )
135
 
136
  st.markdown("---")
137
+ st.caption("Сделано на базе [lightonai/LightOnOCR-1B-1025](https://huggingface.co/lightonai/LightOnOCR-1B-1025)")