MandreOCR / app.py
sterepando's picture
Update app.py
31b0bf1 verified
import io
import uvicorn
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Response
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
# --- 1. Глобальная загрузка компонентов ---
model = None
processor = None
device = "cpu"
try:
print(">>> Инициализация загрузки LightOnOCR-1B (с trust_remote_code=True)...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f">>> Устройство: {device}")
repo_id = "lightonai/LightOnOCR-1B-1025"
# 1. Загружаем процессор
# ВАЖНО: trust_remote_code=True позволяет загрузить кастомный код процессора из репозитория,
# который умеет правильно обрабатывать аргумент 'images' и вставлять токены.
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
# 2. Загружаем модель
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = AutoModelForImageTextToText.from_pretrained(
repo_id,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device)
print(">>> Все компоненты успешно загружены!")
except Exception as e:
print(f"КРИТИЧЕСКАЯ ОШИБКА загрузки: {e}")
app = FastAPI(title="LightOnOCR Final API", version="5.0.0")
@app.post("/api/ocr")
async def run_ocr(file: UploadFile = File(...)):
if model is None or processor is None:
return Response(content="Сервер не готов.", status_code=503)
try:
# 1. Загрузка картинки
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
# 2. Формирование промпта
# Для этой модели обычно достаточно простого промпта, но важно,
# чтобы процессор сам обработал вставку <image> токенов.
prompt = "<image>\nTranscribe the text in this image."
# 3. Обработка через процессор
# Теперь, с trust_remote_code=True, этот вызов должен работать корректно
# и вернуть input_ids, pixel_values и, возможно, image_sizes.
inputs = processor(text=prompt, images=image, return_tensors="pt")
# Переносим все тензоры на устройство
inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
# 4. Генерация
with torch.inference_mode():
generated_ids = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=False,
pad_token_id=processor.tokenizer.pad_token_id
)
# 5. Декодирование
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Очистка результата от промпта (простая эвристика)
clean_text = generated_text.replace(prompt.replace("<image>", ""), "").strip()
# Дополнительная очистка, если модель возвращает мусор в начале
if "Transcribe" in clean_text:
clean_text = clean_text.split("image.")[-1].strip()
return {"text": clean_text}
except Exception as e:
import traceback
traceback.print_exc()
return Response(content=f"Server Error: {str(e)}", status_code=500)
@app.get("/")
async def home():
return {"message": "OCR API Ready. POST image to /api/ocr"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)