Sanket17 commited on
Commit
3304bbb
·
verified ·
1 Parent(s): 5c1773f

Update app.py

Browse files
Files changed (1):
  1. app.py (+58 lines, -22 lines)
app.py CHANGED
@@ -1,30 +1,66 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
import torch

# Initialize FastAPI
app = FastAPI()

# Load the model and processor once at import time so all requests share them.
MODEL_NAME = "microsoft/OmniParser-blip2-caption"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForVisualQuestionAnswering.from_pretrained(MODEL_NAME)
model.eval()  # inference only; disables dropout etc.


@app.get("/")
async def home():
    """Health-check / landing endpoint."""
    return {"message": "Welcome to OmniParser API!"}


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Caption an uploaded image.

    Reads the uploaded file, runs the captioning model, and returns the
    decoded caption text as JSON.
    """
    # Read and preprocess the image; convert guards against palette/alpha modes.
    image = Image.open(file.file).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Bug fix: a caption is an autoregressive sequence — taking argmax over the
    # logits of a single forward pass does not decode one (and a bare forward
    # without decoder inputs typically errors on seq2seq caption models).
    # Use generate() and decode the produced token ids instead.
    with torch.no_grad():
        generated_ids = model.generate(**inputs)

    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return {"caption": caption}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from flask import Flask, request, jsonify
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import io
import torch
import base64
from ultralytics import YOLO
from utils import check_ocr_box, get_som_labeled_img  # Import utility functions

# Initialize Flask app
app = Flask(__name__)

# Bug fix: pick the device once — unconditionally calling .to('cuda')
# crashes with a RuntimeError on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load YOLO detection model (region/icon detector weights)
yolo_model = YOLO('best.pt').to(device)

# Load Florence captioning model. fp16 halves GPU memory; many fp16 ops are
# unsupported or very slow on CPU, so fall back to fp32 there.
torch_dtype = torch.float16 if device == 'cuda' else torch.float32
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "weights/icon_caption_florence",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
# Bundle both pieces; downstream helpers expect this dict shape.
caption_model_processor = {'processor': processor, 'model': model}
@app.route('/predict', methods=['POST'])
def predict():
    """Decode a base64 image from the JSON body, run detection + captioning,
    and return the annotated image (base64 PNG), parsed content, and
    coordinates as JSON.

    Expected body: {"image": <base64 str>, "box_threshold": float,
    "iou_threshold": float}. Thresholds are optional and default to the
    values the upstream demo uses.
    """
    # Bug fix: the original indexed data['image'] / data['box_threshold']
    # unconditionally, turning any malformed request into an HTTP 500.
    data = request.get_json(silent=True)
    if not data or 'image' not in data:
        return jsonify({'error': "JSON body with an 'image' field is required"}), 400

    try:
        image_data = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_data))
    except Exception:
        # Covers both bad base64 and non-image payloads.
        return jsonify({'error': 'invalid base64 image payload'}), 400

    # Defaults keep existing clients (which always sent both keys) working.
    # NOTE(review): 0.05 / 0.1 match the common OmniParser demo defaults — confirm.
    box_threshold = data.get('box_threshold', 0.05)
    iou_threshold = data.get('iou_threshold', 0.1)

    # Process the image and get predictions
    result_image, parsed_content, coordinates = process(
        image, box_threshold, iou_threshold)

    # Encode the result image back to base64 for the JSON response
    buffered = io.BytesIO()
    result_image.save(buffered, format="PNG")
    result_image_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return jsonify({
        'result_image': result_image_str,
        'parsed_content': parsed_content,
        'coordinates': coordinates
    })
def process(image_input, box_threshold, iou_threshold):
    """Run OCR + YOLO set-of-marks labeling over *image_input*.

    Returns a tuple of (annotated PIL image, parsed content as a newline-joined
    string, label coordinates as a string).
    """
    import os  # local import: only needed for the directory guard below

    # Persist the input image because the downstream helpers take a file path.
    image_save_path = 'imgs/saved_image_demo.png'
    # Bug fix: PIL's save() does not create parent directories — on a fresh
    # checkout without an imgs/ folder this raised FileNotFoundError.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)
    image = Image.open(image_save_path)

    # Scale annotation styling relative to a 3200px reference width so labels
    # stay readable on both small and large screenshots.
    box_overlay_ratio = image.size[0] / 3200
    draw_bbox_config = {
        'text_scale': 0.8 * box_overlay_ratio,
        'text_thickness': max(int(2 * box_overlay_ratio), 1),
        'text_padding': max(int(3 * box_overlay_ratio), 1),
        'thickness': max(int(3 * box_overlay_ratio), 1),
    }

    # OCR pass: word-level boxes in xyxy format.
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path, display_img=False, output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=True)
    text, ocr_bbox = ocr_bbox_rslt

    # Detection + captioning pass; returns the annotated image as base64.
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path, yolo_model, BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text, iou_threshold=iou_threshold)

    result_image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    parsed_content_list = '\n'.join(parsed_content_list)

    return result_image, str(parsed_content_list), str(label_coordinates)
if __name__ == '__main__':
    # Security fix: debug=True enables the Werkzeug interactive debugger,
    # which permits arbitrary code execution — never combine it with
    # host='0.0.0.0' (all interfaces). Debug must be opted into locally.
    app.run(debug=False, host='0.0.0.0', port=5000)