Fred808 committed on
Commit
061f058
·
verified ·
1 Parent(s): 7c70faf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  from io import BytesIO
6
  from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel, HttpUrl
8
- from transformers import AutoProcessor, AutoModelForCausalLM
9
  import uvicorn
10
 
11
  # ===== CONFIG =====
@@ -31,13 +31,14 @@ class ImageAnalysisResponse(BaseModel):
31
  error_message: str = None
32
 
33
  # ===== Load Florence-2 Base Model =====
34
- print("[INFO] Loading Florence-2-base model on CPU...")
35
  try:
36
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- "microsoft/Florence-2-Large",
 
39
  trust_remote_code=True,
40
- attn_implementation="eager"
41
  ).to(DEVICE).eval()
42
  print("[INFO] Model loaded successfully!")
43
  except Exception as e:
@@ -82,14 +83,16 @@ def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") ->
82
  try:
83
  # Resize image for faster processing
84
  image = image.resize(RESIZE_DIM, Image.BILINEAR)
85
-
86
  # Prepare inputs
87
  inputs = processor(
88
  text=task,
89
  images=image,
90
- return_tensors="pt"
 
 
91
  ).to(DEVICE)
92
-
93
  # Generate caption
94
  with torch.no_grad():
95
  generated_ids = model.generate(
@@ -100,21 +103,25 @@ def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") ->
100
  do_sample=False,
101
  repetition_penalty=1.2 # Helps avoid repetitive outputs
102
  )
103
-
104
  # Decode and post-process
105
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
106
  result = processor.post_process_generation(
107
  generated_text,
108
  task=task,
109
  image_size=RESIZE_DIM
110
  )
111
-
112
  if result is None:
 
113
  raise ValueError("Post-processing returned None. The model may not have generated a valid output for the given task.")
114
 
 
115
  return result.get(task, "No caption generated.")
116
-
117
  except Exception as e:
 
118
  raise ValueError(f"Failed to analyze image: {e}")
119
 
120
  # ===== API Endpoints =====
 
5
  from io import BytesIO
6
  from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel, HttpUrl
8
+ from transformers import AutoProcessor, AutoModelForVision2Seq
9
  import uvicorn
10
 
11
  # ===== CONFIG =====
 
31
  error_message: str = None
32
 
33
  # ===== Load Florence-2 Base Model =====
34
+ print("[INFO] Loading Florence-2 model on CPU...")
35
  try:
36
+ MODEL_ID = "microsoft/Florence-2-large"
37
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
38
+ model = AutoModelForVision2Seq.from_pretrained(
39
+ MODEL_ID,
40
  trust_remote_code=True,
41
+ torch_dtype=torch.float32
42
  ).to(DEVICE).eval()
43
  print("[INFO] Model loaded successfully!")
44
  except Exception as e:
 
83
  try:
84
  # Resize image for faster processing
85
  image = image.resize(RESIZE_DIM, Image.BILINEAR)
86
+
87
  # Prepare inputs
88
  inputs = processor(
89
  text=task,
90
  images=image,
91
+ return_tensors="pt",
92
+ padding=True,
93
+ truncation=True
94
  ).to(DEVICE)
95
+
96
  # Generate caption
97
  with torch.no_grad():
98
  generated_ids = model.generate(
 
103
  do_sample=False,
104
  repetition_penalty=1.2 # Helps avoid repetitive outputs
105
  )
106
+
107
  # Decode and post-process
108
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
109
+ print(f"[DEBUG] Generated text: {generated_text}")
110
  result = processor.post_process_generation(
111
  generated_text,
112
  task=task,
113
  image_size=RESIZE_DIM
114
  )
115
+
116
  if result is None:
117
+ print("[ERROR] Post-processing returned None. The model may not have generated a valid output for the given task.")
118
  raise ValueError("Post-processing returned None. The model may not have generated a valid output for the given task.")
119
 
120
+ print(f"[DEBUG] Post-processed result: {result}")
121
  return result.get(task, "No caption generated.")
122
+
123
  except Exception as e:
124
+ print(f"[ERROR] Exception in analyze_image: {e}")
125
  raise ValueError(f"Failed to analyze image: {e}")
126
 
127
  # ===== API Endpoints =====