Spaces:
Sleeping
Sleeping
File size: 2,331 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import cv2
import pytesseract
import torch
from transformers import LlavaProcessor, LlavaForConditionalGeneration
# Application object; the /predict route below is registered on it.
app = FastAPI()
# =====================
# LOAD MODEL
# =====================
# Checkpoint identifier handed to both from_pretrained() calls below.
model_id = "llava-hf/llava-1.5-7b-hf"
# Processor and model are created once at import time and reused by
# get_caption() for every request.
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    # Load the weights as 16-bit floats.
    torch_dtype=torch.float16,
    # NOTE(review): presumably lets the library pick device placement
    # automatically -- confirm against the transformers documentation.
    device_map="auto"
)
# =====================
# REQUEST FORMAT
# =====================
class ImageRequest(BaseModel):
    # Request body accepted by the /predict route.

    # Address of the image to download and analyse.
    url: str
# =====================
# LOAD IMAGE
# =====================
def load_image_from_url(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests`` may block indefinitely on a host that
            never answers.

    Returns:
        The downloaded image, converted to RGB mode.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server responds with a 4xx/5xx status.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors here instead of passing an error page to PIL,
    # where it would show up as an unrelated image-decoding failure.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# =====================
# OCR
# =====================
def preprocess(image):
    """Return a grayscale PIL image built from ``image``.

    The input is turned into a NumPy array, converted from RGB to
    grayscale with OpenCV, and wrapped back into a PIL image.
    """
    rgb_pixels = np.array(image)
    return Image.fromarray(cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY))
def extract_text_ocr(image):
    """Run Tesseract on the grayscale version of ``image``.

    Returns the recognised text with surrounding whitespace removed.
    """
    # Options passed verbatim to Tesseract.
    tesseract_options = r'--oem 3 --psm 6'
    grayscale = preprocess(image)
    raw_text = pytesseract.image_to_string(grayscale, config=tesseract_options)
    return raw_text.strip()
# =====================
# LLaVA
# =====================
def get_caption(image):
    """Ask the LLaVA model to describe ``image`` and return only its answer.

    Args:
        image: The image to describe, passed to the module-level
            ``processor`` together with a fixed prompt.

    Returns:
        The text generated by the model (at most 200 new tokens), with
        the prompt removed and surrounding whitespace stripped.
    """
    prompt = "USER: <image>\nDescribe the image in detail and extract any visible text.\nASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=200)
    # generate() returns the prompt tokens followed by the newly generated
    # ones. Drop the prompt by position instead of splitting the decoded
    # string on "ASSISTANT:" / "USER:": an answer that quotes those markers
    # (for example text read from a chat screenshot) would be cut short.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = output[0][prompt_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True).strip()
# =====================
# MAIN PIPELINE
# =====================
def process_image(url):
    """Download the image at ``url`` and describe it with OCR and LLaVA.

    Args:
        url: Address of the image to process.

    Returns:
        A dict with ``"type"`` set to ``"image"`` and ``"processed_text"``
        holding the model caption followed by the OCR text, separated by a
        single space.
    """
    image = load_image_from_url(url)
    ocr_text = extract_text_ocr(image)
    caption = get_caption(image)
    # Both parts arrive already stripped; skip whichever is empty so a blank
    # OCR result (or caption) does not leave a stray space in the output.
    processed_text = " ".join(part for part in (caption, ocr_text) if part)
    return {
        "type": "image",
        "processed_text": processed_text,
    }
# =====================
# API ROUTE
# =====================
@app.post("/predict")
def predict(req: ImageRequest):
    # Route handler: run the full image pipeline on the submitted URL and
    # hand the resulting dict back as the response body.
    result = process_image(req.url)
    return result