from fastapi import FastAPI
from pydantic import BaseModel
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import cv2
import pytesseract
import torch
from transformers import LlavaProcessor, LlavaForConditionalGeneration

app = FastAPI()

# =====================
# LOAD MODEL
# =====================
model_id = "llava-hf/llava-1.5-7b-hf"

processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# =====================
# REQUEST FORMAT
# =====================
class ImageRequest(BaseModel):
    url: str

# =====================
# LOAD IMAGE
# =====================
def load_image_from_url(url):
    # Fail fast on hung connections and HTTP errors instead of handing an error page to PIL
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return image

# =====================
# OCR
# =====================
def preprocess(image):
    # Grayscale conversion usually helps Tesseract on colored backgrounds
    img = np.array(image)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return Image.fromarray(gray)

def extract_text_ocr(image):
    processed = preprocess(image)
    # --oem 3: default OCR engine mode; --psm 6: assume a single uniform block of text
    config = r"--oem 3 --psm 6"
    return pytesseract.image_to_string(processed, config=config).strip()

# =====================
# LLaVA
# =====================
def get_caption(image):
    # LLaVA-1.5 needs the <image> placeholder in the prompt to know where to insert the image features
    prompt = "USER: <image>\nDescribe the image in detail and extract any visible text.\nASSISTANT:"

    # Move inputs to the model's device and cast floating-point tensors (pixel_values) to float16
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=200)

    decoded = processor.decode(output[0], skip_special_tokens=True)

    # ================= CLEANING STEP =================
    # The decoded text echoes the prompt, so keep only what follows the last "ASSISTANT:"
    if "ASSISTANT:" in decoded:
        decoded = decoded.split("ASSISTANT:")[-1]
    # Drop any follow-up "USER:" turn the model may have hallucinated after its answer
    if "USER:" in decoded:
        decoded = decoded.split("USER:")[0]

    return decoded.strip()

# =====================
# MAIN PIPELINE
# =====================
def process_image(url):
    image = load_image_from_url(url)

    ocr_text = extract_text_ocr(image)
    caption = get_caption(image)

    return {
        "type": "image",
        # strip() avoids a trailing space when Tesseract finds no text
        "processed_text": f"{caption} {ocr_text}".strip(),
    }

# =====================
# API ROUTE
# =====================
@app.post("/predict")
def predict(req: ImageRequest):
    return process_image(req.url)