Fred808 committed on
Commit 99c4852 · verified · 1 Parent(s): 1838e15

Update app.py

Files changed (1)
  1. app.py +52 -62
app.py CHANGED
@@ -1,30 +1,31 @@
-import gradio as gr
-import subprocess
-import torch
 from PIL import Image
 from transformers import AutoProcessor, AutoModelForCausalLM
 
-# Attempt to install flash-attn
 try:
     subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
 except subprocess.CalledProcessError as e:
     print(f"Error installing flash-attn: {e}")
     print("Continuing without flash-attn.")
 
-# Determine the device to use
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load the base model and processor
 try:
-    vision_language_model_base = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True,
-                                                                      attn_implementation="eager").to(device).eval()
     vision_language_processor_base = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
 except Exception as e:
     print(f"Error loading base model: {e}")
     vision_language_model_base = None
     vision_language_processor_base = None
 
-# Load the large model and processor
 try:
     vision_language_model_large = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
     vision_language_processor_large = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
@@ -33,71 +34,60 @@ except Exception as e:
     vision_language_model_large = None
     vision_language_processor_large = None
 
-def describe_image(uploaded_image, model_choice):
-    """
-    Generates a detailed description of the input image using the selected model.
-    Args:
-        uploaded_image (PIL.Image.Image): The image to describe.
-        model_choice (str): The model to use, either "Base" or "Large".
-    Returns:
-        str: A detailed textual description of the image or an error message.
-    """
-    if uploaded_image is None:
-        return "Please upload an image."
 
     if model_choice == "Base":
         if vision_language_model_base is None:
-            return "Base model failed to load."
         model = vision_language_model_base
         processor = vision_language_processor_base
     elif model_choice == "Large":
         if vision_language_model_large is None:
-            return "Large model failed to load."
         model = vision_language_model_large
         processor = vision_language_processor_large
     else:
-        return "Invalid model choice."
 
-    if not isinstance(uploaded_image, Image.Image):
-        uploaded_image = Image.fromarray(uploaded_image)
 
-    inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)
-    with torch.no_grad():
-        generated_ids = model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            early_stopping=False,
-            do_sample=False,
-            num_beams=3,
         )
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    processed_description = processor.post_process_generation(
-        generated_text,
-        task="<MORE_DETAILED_CAPTION>",
-        image_size=(uploaded_image.width, uploaded_image.height)
-    )
-    image_description = processed_description["<MORE_DETAILED_CAPTION>"]
-    print("\nImage description generated!:", image_description)
-    return image_description
-
-# Description for the interface
-description = "Select the model to use for generating the image description. 'Base' is smaller and faster, while 'Large' is more accurate but slower."
-if device == "cpu":
-    description += " Note: Running on CPU, which may be slow for large models."
 
-# Create the Gradio interface
-image_description_interface = gr.Interface(
-    fn=describe_image,
-    inputs=[
-        gr.Image(label="Upload Image", type="pil"),
-        gr.Radio(["Base", "Large"], label="Model Choice", value="Base")
-    ],
-    outputs=gr.Textbox(label="Generated Caption", lines=4, show_copy_button=True),
-    live=False,
-    title="Florence-2 Models Image Captions",
-    description=description
-)
 
-# Launch the interface
-image_description_interface.launch(debug=True, ssr_mode=False)
 
 
+from fastapi import FastAPI, File, UploadFile, Form
+from fastapi.responses import JSONResponse
 from PIL import Image
+import torch
+import io
 from transformers import AutoProcessor, AutoModelForCausalLM
+import subprocess
 
+# Attempt to install flash-attn (if needed)
 try:
     subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
 except subprocess.CalledProcessError as e:
     print(f"Error installing flash-attn: {e}")
     print("Continuing without flash-attn.")
 
+# Determine device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Load Florence-2 Base
 try:
+    vision_language_model_base = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True, attn_implementation="eager").to(device).eval()
     vision_language_processor_base = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
 except Exception as e:
     print(f"Error loading base model: {e}")
     vision_language_model_base = None
     vision_language_processor_base = None
 
+# Load Florence-2 Large
 try:
     vision_language_model_large = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
     vision_language_processor_large = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
@@ -33,71 +34,60 @@ except Exception as e:
     vision_language_model_large = None
     vision_language_processor_large = None
 
+# Initialize FastAPI
+app = FastAPI()
+
+@app.post("/describe-image")
+async def describe_image(
+    file: UploadFile = File(...),
+    model_choice: str = Form("Base")
+):
+    if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
+        return JSONResponse(status_code=400, content={"error": "Invalid image file type."})
+
+    try:
+        image_bytes = await file.read()
+        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+    except Exception as e:
+        return JSONResponse(status_code=400, content={"error": f"Failed to process image: {str(e)}"})
 
     if model_choice == "Base":
         if vision_language_model_base is None:
+            return JSONResponse(status_code=500, content={"error": "Base model not loaded."})
         model = vision_language_model_base
         processor = vision_language_processor_base
     elif model_choice == "Large":
         if vision_language_model_large is None:
+            return JSONResponse(status_code=500, content={"error": "Large model not loaded."})
         model = vision_language_model_large
        processor = vision_language_processor_large
     else:
+        return JSONResponse(status_code=400, content={"error": "Invalid model choice."})
 
+    try:
+        inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            generated_ids = model.generate(
+                input_ids=inputs["input_ids"],
+                pixel_values=inputs["pixel_values"],
+                max_new_tokens=1024,
+                early_stopping=False,
+                do_sample=False,
+                num_beams=3,
+            )
 
+        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+        processed_description = processor.post_process_generation(
+            generated_text,
+            task="<MORE_DETAILED_CAPTION>",
+            image_size=(image.width, image.height)
         )
+        image_description = processed_description["<MORE_DETAILED_CAPTION>"]
+        return JSONResponse(content={"description": image_description})
 
+    except Exception as e:
+        return JSONResponse(status_code=500, content={"error": f"Image processing failed: {str(e)}"})
 
+@app.get("/health")
+def health():
+    return {"status": "ok", "device": device}
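The updated app.py only defines FastAPI routes; the old image_description_interface.launch(...) call is gone, so the server must now be started externally. Below is a minimal client sketch for the new /describe-image and /health endpoints. It assumes the app is served locally, for example with uvicorn; the launch command, host, port, and example.jpg filename are illustrative assumptions, not part of this commit.

# Hypothetical client for the new endpoints; start the API first, e.g.:
#   uvicorn app:app --host 0.0.0.0 --port 8000
# (command, host, port, and file name are assumptions, not part of this commit)
import requests

with open("example.jpg", "rb") as f:  # any local .jpg/.jpeg/.png file
    response = requests.post(
        "http://localhost:8000/describe-image",
        files={"file": ("example.jpg", f, "image/jpeg")},
        data={"model_choice": "Base"},  # or "Large", matching the Form field default
    )

print(response.status_code)
print(response.json())  # {"description": "..."} on success, {"error": "..."} otherwise

# Liveness check against the new /health route
print(requests.get("http://localhost:8000/health").json())  # e.g. {"status": "ok", "device": "cpu"}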