Fred808 commited on
Commit
76de2e4
·
verified ·
1 Parent(s): 99c4852

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -59
app.py CHANGED
@@ -1,93 +1,84 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
- from fastapi.responses import JSONResponse
3
- from PIL import Image
4
- import torch
5
  import io
 
 
 
 
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
- import subprocess
8
 
9
- # Attempt to install flash-attn (if needed)
 
10
  try:
11
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
 
 
 
 
 
12
  except subprocess.CalledProcessError as e:
13
- print(f"Error installing flash-attn: {e}")
14
- print("Continuing without flash-attn.")
15
 
16
- # Determine device
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- # Load Florence-2 Base
20
- try:
21
- vision_language_model_base = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True, attn_implementation="eager").to(device).eval()
22
- vision_language_processor_base = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
23
- except Exception as e:
24
- print(f"Error loading base model: {e}")
25
- vision_language_model_base = None
26
- vision_language_processor_base = None
27
-
28
- # Load Florence-2 Large
29
  try:
30
- vision_language_model_large = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
31
- vision_language_processor_large = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
 
 
 
 
32
  except Exception as e:
33
- print(f"Error loading large model: {e}")
34
- vision_language_model_large = None
35
- vision_language_processor_large = None
36
 
37
- # Initialize FastAPI
38
- app = FastAPI()
39
 
40
  @app.post("/describe-image")
41
- async def describe_image(
42
- file: UploadFile = File(...),
43
- model_choice: str = Form("Base")
44
- ):
45
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
46
- return JSONResponse(status_code=400, content={"error": "Invalid image file type."})
47
 
48
  try:
49
- image_bytes = await file.read()
50
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
51
- except Exception as e:
52
- return JSONResponse(status_code=400, content={"error": f"Failed to process image: {str(e)}"})
53
 
54
- if model_choice == "Base":
55
- if vision_language_model_base is None:
56
- return JSONResponse(status_code=500, content={"error": "Base model not loaded."})
57
- model = vision_language_model_base
58
- processor = vision_language_processor_base
59
- elif model_choice == "Large":
60
- if vision_language_model_large is None:
61
- return JSONResponse(status_code=500, content={"error": "Large model not loaded."})
62
- model = vision_language_model_large
63
- processor = vision_language_processor_large
64
- else:
65
- return JSONResponse(status_code=400, content={"error": "Invalid model choice."})
66
 
67
- try:
68
- inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
69
  with torch.no_grad():
70
  generated_ids = model.generate(
71
  input_ids=inputs["input_ids"],
72
  pixel_values=inputs["pixel_values"],
73
- max_new_tokens=1024,
74
- early_stopping=False,
75
- do_sample=False,
76
- num_beams=3,
77
  )
78
 
79
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
80
- processed_description = processor.post_process_generation(
81
  generated_text,
82
  task="<MORE_DETAILED_CAPTION>",
83
  image_size=(image.width, image.height)
84
  )
85
- image_description = processed_description["<MORE_DETAILED_CAPTION>"]
86
- return JSONResponse(content={"description": image_description})
87
 
88
  except Exception as e:
89
- return JSONResponse(status_code=500, content={"error": f"Image processing failed: {str(e)}"})
90
 
91
  @app.get("/health")
92
  def health():
93
- return {"status": "ok", "device": device}
 
1
+ import os
 
 
 
2
  import io
3
+ import torch
4
+ from PIL import Image
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.responses import JSONResponse
7
  from transformers import AutoProcessor, AutoModelForCausalLM
 
8
 
9
+ # Auto-install flash-attn if needed
10
+ import subprocess
11
  try:
12
+ subprocess.run(
13
+ 'pip install flash-attn --no-build-isolation',
14
+ env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
15
+ check=True,
16
+ shell=True
17
+ )
18
  except subprocess.CalledProcessError as e:
19
+ print(f"Flash-attn install failed: {e}")
20
+ print("Continuing without flash-attn...")
21
 
22
+ # Device setup
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
+ # Load Florence-2-base model and processor
 
 
 
 
 
 
 
 
 
26
  try:
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ 'microsoft/Florence-2-base',
29
+ trust_remote_code=True,
30
+ attn_implementation="eager"
31
+ ).to(device).eval()
32
+ processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
33
  except Exception as e:
34
+ print(f"Error loading Florence-2-base: {e}")
35
+ model = None
36
+ processor = None
37
 
38
+ # FastAPI setup
39
+ app = FastAPI(title="Florence-2 Image Captioning API")
40
 
41
  @app.post("/describe-image")
42
+ async def describe_image(file: UploadFile = File(...)):
43
+ if model is None or processor is None:
44
+ return JSONResponse(status_code=500, content={"error": "Model not loaded"})
45
+
46
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
47
+ return JSONResponse(status_code=400, content={"error": "Invalid file type. Please upload an image."})
48
 
49
  try:
50
+ # Load image from upload
51
+ image_data = await file.read()
52
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
53
 
54
+ # Prepare inputs
55
+ inputs = processor(
56
+ text="<MORE_DETAILED_CAPTION>",
57
+ images=image,
58
+ return_tensors="pt"
59
+ ).to(device)
 
 
 
 
 
 
60
 
61
+ # Generate caption
 
62
  with torch.no_grad():
63
  generated_ids = model.generate(
64
  input_ids=inputs["input_ids"],
65
  pixel_values=inputs["pixel_values"],
66
+ max_new_tokens=512,
67
+ num_beams=3
 
 
68
  )
69
 
70
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
71
+ processed = processor.post_process_generation(
72
  generated_text,
73
  task="<MORE_DETAILED_CAPTION>",
74
  image_size=(image.width, image.height)
75
  )
76
+ description = processed["<MORE_DETAILED_CAPTION>"]
77
+ return {"description": description}
78
 
79
  except Exception as e:
80
+ return JSONResponse(status_code=500, content={"error": str(e)})
81
 
82
  @app.get("/health")
83
  def health():
84
+ return {"status": "ok", "model": "florence-2-base"}