Fred808 committed (verified)
Commit a32396c · Parent: 76de2e4

Update app.py
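- Drop the runtime flash-attn install; the model loads with attn_implementation="eager", so flash-attn is not needed.
- Rename model/processor to vision_model/vision_processor.
- Remove the file-extension check on uploads.
- Raise max_new_tokens from 512 to 1024.
- Return the uploaded filename alongside the caption.
- Replace the /health endpoint with a root endpoint.
- Add a __main__ entry point that runs uvicorn on the PORT env var.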

Files changed (1):
  1. app.py +40 -46
app.py CHANGED
@@ -1,84 +1,78 @@
-import os
 import io
+import os
 import torch
 from PIL import Image
-from fastapi import FastAPI, File, UploadFile
+from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import JSONResponse
 from transformers import AutoProcessor, AutoModelForCausalLM
 
-# Auto-install flash-attn if needed
-import subprocess
-try:
-    subprocess.run(
-        'pip install flash-attn --no-build-isolation',
-        env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-        check=True,
-        shell=True
-    )
-except subprocess.CalledProcessError as e:
-    print(f"Flash-attn install failed: {e}")
-    print("Continuing without flash-attn...")
-
-# Device setup
+# Setup
 device = "cuda" if torch.cuda.is_available() else "cpu"
+app = FastAPI(title="Florence-2 Base Image Captioning API")
 
-# Load Florence-2-base model and processor
+# Load Florence-2 base model
 try:
-    model = AutoModelForCausalLM.from_pretrained(
+    vision_model = AutoModelForCausalLM.from_pretrained(
         'microsoft/Florence-2-base',
         trust_remote_code=True,
         attn_implementation="eager"
     ).to(device).eval()
-    processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
-except Exception as e:
-    print(f"Error loading Florence-2-base: {e}")
-    model = None
-    processor = None
 
-# FastAPI setup
-app = FastAPI(title="Florence-2 Image Captioning API")
+    vision_processor = AutoProcessor.from_pretrained(
+        'microsoft/Florence-2-base',
+        trust_remote_code=True
+    )
+except Exception as e:
+    vision_model = None
+    vision_processor = None
+    print(f"Model loading error: {e}")
 
 @app.post("/describe-image")
 async def describe_image(file: UploadFile = File(...)):
-    if model is None or processor is None:
+    if vision_model is None or vision_processor is None:
         return JSONResponse(status_code=500, content={"error": "Model not loaded"})
 
-    if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
-        return JSONResponse(status_code=400, content={"error": "Invalid file type. Please upload an image."})
-
     try:
-        # Load image from upload
-        image_data = await file.read()
-        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents)).convert("RGB")
 
-        # Prepare inputs
-        inputs = processor(
+        # Preprocess
+        inputs = vision_processor(
             text="<MORE_DETAILED_CAPTION>",
             images=image,
             return_tensors="pt"
         ).to(device)
 
-        # Generate caption
         with torch.no_grad():
-            generated_ids = model.generate(
+            generated_ids = vision_model.generate(
                 input_ids=inputs["input_ids"],
                 pixel_values=inputs["pixel_values"],
-                max_new_tokens=512,
-                num_beams=3
+                max_new_tokens=1024,
+                num_beams=3,
             )
 
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-        processed = processor.post_process_generation(
+        generated_text = vision_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+        processed = vision_processor.post_process_generation(
             generated_text,
             task="<MORE_DETAILED_CAPTION>",
-            image_size=(image.width, image.height)
+            image_size=image.size
         )
-        description = processed["<MORE_DETAILED_CAPTION>"]
-        return {"description": description}
+        caption = processed["<MORE_DETAILED_CAPTION>"]
+
+        return JSONResponse(content={
+            "filename": file.filename,
+            "description": caption
+        })
 
     except Exception as e:
         return JSONResponse(status_code=500, content={"error": str(e)})
 
-@app.get("/health")
-def health():
-    return {"status": "ok", "model": "florence-2-base"}
+@app.get("/")
+def root():
+    return {"message": "Florence-2 Base Image Captioning API is running"}
+
+# Run the app when executed directly
+if __name__ == "__main__":
+    import uvicorn
+    port = int(os.getenv("PORT", 7860))  # Spaces set PORT env var
+    uvicorn.run("app:app", host="0.0.0.0", port=port)
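
For reference, a minimal client sketch for exercising the updated endpoints. This snippet is not part of the commit: the base URL assumes a local run via python app.py on the default port 7860, example.jpg is a placeholder path, and it relies on the third-party requests package.

import requests

BASE_URL = "http://localhost:7860"  # assumption: local run; a deployed Space has its own URL

# Root endpoint: confirms the API is up
print(requests.get(f"{BASE_URL}/").json())

# POST an image as multipart/form-data under the "file" field, as the endpoint expects
with open("example.jpg", "rb") as f:  # placeholder image path
    response = requests.post(f"{BASE_URL}/describe-image", files={"file": f})
print(response.json())  # e.g. {"filename": "example.jpg", "description": "..."}

A successful response mirrors the JSONResponse built in describe_image; on failure the same endpoint returns a 500 with an "error" key, so a client should check response.status_code before reading "description".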