Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import pytesseract | |
| import torch | |
| from transformers import LlavaProcessor, LlavaForConditionalGeneration | |
app = FastAPI()

# =====================
# LOAD MODEL
# =====================
# The processor and model are created once, when this module is imported, and
# are then shared by every request handled by this process.
model_id = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)
# =====================
# REQUEST FORMAT
# =====================
class ImageRequest(BaseModel):
    """JSON request body for the prediction endpoint."""

    # Address of the image to analyse; it is downloaded server-side.
    url: str
# =====================
# LOAD IMAGE
# =====================
def load_image_from_url(url, *, timeout=10.0):
    """Download *url* and decode the response body as an RGB PIL image.

    Args:
        url: Address of the image. It arrives straight from the API request
            body, so treat it as untrusted input.
        timeout: Seconds passed to ``requests.get``. This is the limit for
            connecting and for each read, not for the whole download.
            Without it, a slow host could block the worker indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.RequestException: On a network error, a timeout, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: If the body is not a decodable image.
    """
    # NOTE(review): the URL is user-controlled, so callers can make this server
    # fetch internal addresses (SSRF). The whole body is also read into memory
    # with no size cap. Restrict schemes/hosts and bound the download size if
    # this service is publicly reachable.
    response = requests.get(url, timeout=timeout)
    # Surface the HTTP error itself rather than letting PIL fail on an HTML
    # error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# =====================
# OCR
# =====================
def preprocess(image):
    """Return a grayscale PIL copy of an RGB PIL image, for OCR.

    The pixels are copied into a NumPy array and reduced to one channel with
    OpenCV's RGB-to-gray conversion. The input is expected in RGB order, which
    is what ``load_image_from_url`` produces.
    """
    rgb_pixels = np.array(image)
    return Image.fromarray(cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY))
def extract_text_ocr(image):
    """Run Tesseract on a grayscale copy of *image*.

    Returns the recognised text with leading and trailing whitespace removed.
    The Tesseract option string is passed through unchanged.
    """
    tesseract_options = r'--oem 3 --psm 6'
    grayscale = preprocess(image)
    raw_text = pytesseract.image_to_string(grayscale, config=tesseract_options)
    return raw_text.strip()
# =====================
# LLaVA
# =====================
def get_caption(image, *, max_new_tokens=200):
    """Ask LLaVA to describe *image* and return only the model's answer.

    Args:
        image: PIL image passed to the LLaVA processor.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated answer, stripped of surrounding whitespace. The prompt
        is never included. Any invented follow-up "USER:" turn is cut off.
    """
    prompt = "USER: <image>\nDescribe the image in detail and extract any visible text.\nASSISTANT:"
    # Move the inputs to the model's device, and cast floating-point inputs
    # (the pixel values) to the model's dtype so they match the fp16 weights.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(
        model.device, model.dtype
    )
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns the prompt tokens followed by the new ones. Decoding
    # only the new tokens avoids string-splitting the prompt back out.
    prompt_length = inputs["input_ids"].shape[-1]
    answer = processor.decode(output[0][prompt_length:], skip_special_tokens=True)

    # The model sometimes continues into an invented next user turn. Keep the
    # text *before* that marker; keeping what follows it would discard the
    # real answer.
    answer = answer.split("USER:", 1)[0]
    return answer.strip()
# =====================
# MAIN PIPELINE
# =====================
def process_image(url):
    """Download the image at *url* and combine its LLaVA caption with OCR text.

    Returns:
        A dict with ``"type": "image"``. Its ``"processed_text"`` value is the
        caption followed by the OCR text, joined by a single space.
    """
    image = load_image_from_url(url)
    ocr_text = extract_text_ocr(image)
    caption = get_caption(image)
    # Join only the non-empty parts, so an image with no OCR text (or an empty
    # caption) does not leave a dangling space.
    combined = " ".join(part for part in (caption, ocr_text) if part)
    return {
        "type": "image",
        "processed_text": combined,
    }
# =====================
# API ROUTE
# =====================
# Without a path-operation decorator the handler is never registered on `app`
# and the service exposes no endpoint. It is a plain `def`, so FastAPI runs it
# in its threadpool rather than on the event loop.
@app.post("/predict")
def predict(req: ImageRequest):
    """Analyse the image at ``req.url`` and return the combined caption and OCR result."""
    return process_image(req.url)