Omnia-cy commited on
Commit
2ed8a2b
·
verified ·
1 Parent(s): 1634971

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+ import pytesseract
9
+ import torch
10
+ from transformers import LlavaProcessor, LlavaForConditionalGeneration
11
+
12
+ app = FastAPI()
13
+
14
+ # =====================
15
+ # LOAD MODEL
16
+ # =====================
17
+ model_id = "llava-hf/llava-1.5-7b-hf"
18
+
19
+ processor = LlavaProcessor.from_pretrained(model_id)
20
+ model = LlavaForConditionalGeneration.from_pretrained(
21
+ model_id,
22
+ torch_dtype=torch.float16,
23
+ device_map="auto"
24
+ )
25
+
26
+ # =====================
27
+ # REQUEST FORMAT
28
+ # =====================
29
+ class ImageRequest(BaseModel):
30
+ url: str
31
+
32
+ # =====================
33
+ # LOAD IMAGE
34
+ # =====================
35
+ def load_image_from_url(url):
36
+ response = requests.get(url)
37
+ image = Image.open(BytesIO(response.content)).convert("RGB")
38
+ return image
39
+
40
+ # =====================
41
+ # OCR
42
+ # =====================
43
+ def preprocess(image):
44
+ img = np.array(image)
45
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
46
+ return Image.fromarray(gray)
47
+
48
+ def extract_text_ocr(image):
49
+ processed = preprocess(image)
50
+ config = r'--oem 3 --psm 6'
51
+ return pytesseract.image_to_string(processed, config=config).strip()
52
+
53
+ # =====================
54
+ # LLaVA
55
+ # =====================
56
+ def get_caption(image):
57
+ prompt = "USER: <image>\nDescribe the image in detail and extract any visible text.\nASSISTANT:"
58
+
59
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
60
+
61
+ with torch.no_grad():
62
+ output = model.generate(**inputs, max_new_tokens=200)
63
+
64
+ return processor.decode(output[0], skip_special_tokens=True)
65
+
66
+ # =====================
67
+ # MAIN PIPELINE
68
+ # =====================
69
+ def process_image(url):
70
+ image = load_image_from_url(url)
71
+
72
+ ocr_text = extract_text_ocr(image)
73
+ caption = get_caption(image)
74
+
75
+ return {
76
+ "type": "image",
77
+ "processed_text": f"{caption} {ocr_text}"
78
+ }
79
+
80
+ # =====================
81
+ # API ROUTE
82
+ # =====================
83
+ @app.post("/predict")
84
+ def predict(req: ImageRequest):
85
+ return process_image(req.url)