tddf commited on
Commit
8558925
·
verified ·
1 Parent(s): 3a8d636

Update Main.py

Browse files
Files changed (1) hide show
  1. Main.py +52 -46
Main.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
6
  from PIL import Image
7
 
8
- # Ускоряем скачивание на HF Spaces
9
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
10
 
11
  st.set_page_config(
@@ -15,22 +15,22 @@ st.set_page_config(
15
  initial_sidebar_state="expanded"
16
  )
17
 
18
- # Простой CSS
19
  st.markdown("""
20
- <style>
21
- .main { background: linear-gradient(180deg, #f8f9fa, #e9f0f7); }
22
- .result-box {
23
- background: #ffffff;
24
- border-radius: 16px;
25
- padding: 24px;
26
- box-shadow: 0 10px 30px rgba(0,0,0,0.08);
27
- margin-top: 20px;
28
- }
29
- .header-emoji { font-size: 3.5rem; text-align: center; margin: 15px 0; }
30
- </style>
 
31
  """, unsafe_allow_html=True)
32
 
33
- @st.cache_resource(show_spinner="⏳ Загрузка модели LightOnOCR-1B-1025...\nЭто может занять 2–6 минут при первом запуске на CPU")
34
  def load_model():
35
  model_name = "lightonai/LightOnOCR-1B-1025"
36
 
@@ -44,21 +44,12 @@ def load_model():
44
  ).to(device)
45
 
46
  processor = LightOnOcrProcessor.from_pretrained(model_name)
47
-
48
- return processor, model, device, dtype
49
-
50
- # ====================== Заголовок ======================
51
- st.markdown('<div class="header-emoji">📄✨</div>', unsafe_allow_html=True)
52
- st.title("LightOnOCR")
53
- st.markdown("**Распознавание текста с изображений**")
54
- st.caption("Модель: lightonai/LightOnOCR-1B-1025")
55
 
56
- # ====================== Загрузка модели ======================
57
- processor, model, device, dtype = load_model()
58
 
59
- st.sidebar.success(f"✅ Модель загружена на **{device.upper()}**")
60
 
61
- # ====================== Загрузка изображения ======================
62
  def load_image():
63
  uploaded_file = st.file_uploader(
64
  "📸 Загрузите изображение (png, jpg, jpeg, webp)",
@@ -70,69 +61,84 @@ def load_image():
70
  return Image.open(io.BytesIO(image_data)).convert('RGB')
71
  return None
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  img = load_image()
74
 
75
- # ====================== Распознавание ======================
76
  if st.button("🔍 Распознать текст", use_container_width=True, type="primary"):
77
  if img is None:
78
  st.error("Сначала загрузите изображение")
79
  else:
80
- with st.spinner("Распознавание текста... (может занять 5–20 сек на CPU)"):
81
 
82
- # Правильный формат для LightOnOCR (только изображение + промпт)
83
  conversation = [
84
  {
85
  "role": "user",
86
  "content": [
87
- {"type": "image"}, # изображение передаётся автоматически через processor
88
- {"type": "text", "text": "Extract all the text from this image as accurately as possible. Output clean text with preserved line breaks and formatting."}
89
  ]
90
  }
91
  ]
92
 
93
- # Подготовка inputs (processor обработает и изображение, и текст)
94
  inputs = processor.apply_chat_template(
95
  conversation,
96
  add_generation_prompt=True,
97
  tokenize=True,
98
  return_dict=True,
99
- return_tensors="pt",
100
- # Важно: передаём само PIL-изображение
101
- images=img
102
  )
103
 
104
- # Переносим на устройство
105
- inputs = {
106
- k: (v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device))
107
- for k, v in inputs.items()
108
- }
 
 
 
109
 
110
  # Генерация
111
  output_ids = model.generate(
112
  **inputs,
113
  max_new_tokens=2048,
114
- do_sample=False,
 
 
 
115
  temperature=0.0,
116
- num_beams=1, # для стабильности
117
  pad_token_id=processor.tokenizer.pad_token_id,
118
  eos_token_id=processor.tokenizer.eos_token_id,
119
  )
120
 
121
- # Убираем входной промпт — оставляем только сгенерированный текст
122
  prompt_length = inputs["input_ids"].shape[1]
123
  generated_ids = output_ids[0, prompt_length:]
124
 
125
  generated_text = processor.decode(
126
- generated_ids,
127
  skip_special_tokens=True,
128
  clean_up_tokenization_spaces=True
129
  ).strip()
130
 
131
- # Вывод результата
132
  st.success("✅ Распознавание завершено!")
133
  st.markdown('<div class="result-box">', unsafe_allow_html=True)
134
  st.subheader("📝 Распознанный текст")
135
- st.code(generated_text, language=None) # лучше чем markdown для больших блоков
136
  st.markdown('</div>', unsafe_allow_html=True)
137
 
138
  st.download_button(
 
5
  from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
6
  from PIL import Image
7
 
8
+ # Ускоряем скачивание
9
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
10
 
11
  st.set_page_config(
 
15
  initial_sidebar_state="expanded"
16
  )
17
 
 
18
  st.markdown("""
19
+ <style>
20
+ .main { background: linear-gradient(180deg, #f8f9fa, #e9f0f7); }
21
+ .header-emoji { font-size: 3.5rem; text-align: center; margin: 15px 0; }
22
+ .result-box {
23
+ background: #ffffff;
24
+ border-radius: 16px;
25
+ padding: 24px;
26
+ box-shadow: 0 10px 30px rgba(0, 0, 0, 0.08);
27
+ border: 1px solid #e5e7eb;
28
+ margin-top: 20px;
29
+ }
30
+ </style>
31
  """, unsafe_allow_html=True)
32
 
33
+ @st.cache_resource(show_spinner="⏳ Загрузка модели LightOnOCR-1B-1025...\n(2–6 минут при первом запуске)")
34
  def load_model():
35
  model_name = "lightonai/LightOnOCR-1B-1025"
36
 
 
44
  ).to(device)
45
 
46
  processor = LightOnOcrProcessor.from_pretrained(model_name)
 
 
 
 
 
 
 
 
47
 
48
+ if processor.tokenizer.pad_token is None:
49
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
50
 
51
+ return processor, model, device, dtype
52
 
 
53
  def load_image():
54
  uploaded_file = st.file_uploader(
55
  "📸 Загрузите изображение (png, jpg, jpeg, webp)",
 
61
  return Image.open(io.BytesIO(image_data)).convert('RGB')
62
  return None
63
 
64
+ # ==================== Интерфейс ====================
65
+ st.markdown('<div class="header-emoji">📄✨</div>', unsafe_allow_html=True)
66
+ st.title("LightOnOCR")
67
+ st.markdown("**Распознавание текста с изображений**")
68
+ st.caption("Модель: lightonai/LightOnOCR-1B-1025")
69
+
70
+ processor, model, device, dtype = load_model()
71
+
72
+ with st.sidebar:
73
+ st.success(f"✅ Модель загружена на **{device.upper()}**")
74
+
75
  img = load_image()
76
 
77
+ # ==================== Распознавание ====================
78
  if st.button("🔍 Распознать текст", use_container_width=True, type="primary"):
79
  if img is None:
80
  st.error("Сначала загрузите изображение")
81
  else:
82
+ with st.spinner("Распознавание текста... (5–20 сек на CPU)"):
83
 
84
+ # Правильный формат разговора (без передачи images здесь)
85
  conversation = [
86
  {
87
  "role": "user",
88
  "content": [
89
+ {"type": "image"},
90
+ {"type": "text", "text": "Extract all the text from this image accurately. Preserve original formatting, tables, and line breaks as much as possible."}
91
  ]
92
  }
93
  ]
94
 
95
+ # Применяем шаблон чата
96
  inputs = processor.apply_chat_template(
97
  conversation,
98
  add_generation_prompt=True,
99
  tokenize=True,
100
  return_dict=True,
101
+ return_tensors="pt"
 
 
102
  )
103
 
104
+ # Важно: добавляем pixel_values отдельно
105
+ pixel_values = processor.image_processor(img, return_tensors="pt").pixel_values
106
+ inputs["pixel_values"] = pixel_values.to(device=device, dtype=dtype)
107
+
108
+ # Переносим остальные тензоры
109
+ for k, v in inputs.items():
110
+ if isinstance(v, torch.Tensor) and k != "pixel_values":
111
+ inputs[k] = v.to(device=device)
112
 
113
  # Генерация
114
  output_ids = model.generate(
115
  **inputs,
116
  max_new_tokens=2048,
117
+ do_sa
118
+
119
+
120
+ mple=False,
121
  temperature=0.0,
122
+ num_beams=1,
123
  pad_token_id=processor.tokenizer.pad_token_id,
124
  eos_token_id=processor.tokenizer.eos_token_id,
125
  )
126
 
127
+ # Убираем промпт
128
  prompt_length = inputs["input_ids"].shape[1]
129
  generated_ids = output_ids[0, prompt_length:]
130
 
131
  generated_text = processor.decode(
132
+ generated_ids,
133
  skip_special_tokens=True,
134
  clean_up_tokenization_spaces=True
135
  ).strip()
136
 
137
+ # Результат
138
  st.success("✅ Распознавание завершено!")
139
  st.markdown('<div class="result-box">', unsafe_allow_html=True)
140
  st.subheader("📝 Распознанный текст")
141
+ st.code(generated_text, language=None)
142
  st.markdown('</div>', unsafe_allow_html=True)
143
 
144
  st.download_button(