Sanket17 commited on
Commit
26966db
·
verified ·
1 Parent(s): afdc45a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -54
app.py CHANGED
@@ -1,62 +1,50 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
- from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
- import io
6
  import torch
7
- import base64
8
- from ultralytics import YOLO
9
- from utils import check_ocr_box, get_som_labeled_img # Import utility functions
10
 
 
11
app = FastAPI()

# Detection model: YOLO weights moved onto the GPU.
yolo_model = YOLO('best.pt').to('cuda')

# Captioning model: Florence processor + OmniParser caption weights (fp16, GPU).
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/OmniParser/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True).to('cuda')

# Bundle processor and model so helpers can take a single argument.
caption_model_processor = {'processor': processor, 'model': model}
21
@app.post("/predict")
async def predict(image: UploadFile = File(...), box_threshold: float = 0.05, iou_threshold: float = 0.1):
    """Run detection + captioning on an uploaded image.

    Parameters:
    - image: uploaded image file.
    - box_threshold: detection confidence threshold passed to `process`.
    - iou_threshold: IoU threshold for box filtering passed to `process`.

    Returns a JSONResponse with the base64-encoded annotated image,
    the parsed content string, and the box coordinates string.
    """
    image_data = await image.read()
    # Fix: do not shadow the `image` UploadFile parameter with the PIL image.
    pil_image = Image.open(io.BytesIO(image_data))

    # Process the image and get predictions.
    result_image, parsed_content, coordinates = process(pil_image, box_threshold, iou_threshold)

    # Encode the annotated result image back to base64 for the JSON payload.
    buffered = io.BytesIO()
    result_image.save(buffered, format="PNG")
    result_image_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return JSONResponse(content={
        'result_image': result_image_str,
        'parsed_content': parsed_content,
        'coordinates': coordinates
    })
39
 
40
def process(image_input, box_threshold, iou_threshold):
    """Save the image to disk, run OCR + YOLO labeling, and return results.

    Parameters:
    - image_input: PIL image to analyze.
    - box_threshold: detection confidence threshold for `get_som_labeled_img`.
    - iou_threshold: IoU threshold for box de-duplication.

    Returns a tuple of (annotated PIL image, parsed content string,
    label coordinates string).
    """
    import os

    image_save_path = 'imgs/saved_image_demo.png'
    # Fix: PIL's save() raises if the target directory is missing.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)
    image = Image.open(image_save_path)

    # Scale drawing parameters relative to the image width.
    box_overlay_ratio = image.size[0] / 3200
    draw_bbox_config = {
        'text_scale': 0.8 * box_overlay_ratio,
        'text_thickness': max(int(2 * box_overlay_ratio), 1),
        'text_padding': max(int(3 * box_overlay_ratio), 1),
        'thickness': max(int(3 * box_overlay_ratio), 1),
    }

    # OCR pass: collects text boxes (xyxy) for the labeling step.
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=True)
    text, ocr_bbox = ocr_bbox_rslt

    # NOTE: BOX_TRESHOLD spelling is the callee's keyword — do not "fix" it here.
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text, iou_threshold=iou_threshold)

    # The labeled image comes back base64-encoded; decode it to a PIL image.
    result_image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    parsed_content_list = '\n'.join(parsed_content_list)

    return result_image, str(parsed_content_list), str(label_coordinates)
 
1
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from PIL import Image
import torch
import uvicorn
import os

# FastAPI application instance.
app = FastAPI()

# Hugging Face access token read from the environment (secret section).
# NOTE(review): env var is spelled "HP_token" — confirm it isn't meant to be "HF_token".
hf_token = os.getenv("HP_token")

# Load the VQA processor and model with the token.
# NOTE(review): `use_auth_token` is deprecated in newer transformers releases
# in favor of `token` — confirm the pinned library version before changing.
processor = AutoProcessor.from_pretrained("Sanket17/hello", use_auth_token=hf_token)
model = AutoModelForVisualQuestionAnswering.from_pretrained("Sanket17/hello", use_auth_token=hf_token)
18
+
19
@app.post("/vqa/")
async def visual_question_answer(file: UploadFile, question: str = Form(...)):
    """
    Endpoint for visual question answering.

    Parameters:
    - file: uploaded image file
    - question: textual question about the image

    Returns JSON {"question": ..., "answer": ...} on success, or
    {"error": ...} with HTTP 500 on any failure.
    """
    try:
        # Load and normalize the uploaded image.
        image = Image.open(file.file).convert("RGB")

        # Preprocess image + question into model tensors.
        inputs = processor(images=image, text=question, return_tensors="pt")

        # Fix: inference only — disable autograd so no graph is built per request.
        with torch.no_grad():
            outputs = model(**inputs)

        # Highest-scoring index from the logits.
        answer = outputs.logits.argmax(dim=-1).item()

        # NOTE(review): decoding a logits argmax as a token id is only valid for
        # some checkpoints; classification-style VQA heads typically need
        # model.config.id2label[answer] instead — confirm for this model.
        answer_str = processor.decode([answer])

        return JSONResponse(content={"question": question, "answer": answer_str})

    except Exception as e:
        # Surface the failure to the client instead of crashing the worker.
        return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
47
 
48
# Launch the ASGI server when this module is executed directly.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)