SebastianRuff commited on
Commit
68a8f79
·
verified ·
1 Parent(s): b92e4f6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -31
handler.py CHANGED
@@ -3,18 +3,13 @@ from PIL import Image
3
  import requests
4
  import torch
5
 
6
-
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
9
- # Check if GPU is available, otherwise use CPU
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
-
12
- # Load the Florence model and processor
13
  self.model = AutoModelForCausalLM.from_pretrained(
14
  model_dir,
15
  trust_remote_code=True
16
- ).eval().to(device) # Dynamically move to the correct device
17
-
18
  self.processor = AutoProcessor.from_pretrained(
19
  model_dir,
20
  trust_remote_code=True
@@ -23,49 +18,43 @@ class EndpointHandler:
23
 
24
  def __call__(self, data):
25
  try:
26
- # Extract inputs from the request data
27
- task_prompt = data.get("task_prompt", "<MORE_DETAILED_CAPTION>")
28
- image_url = data.get("url") # Match the key sent from n8n
29
-
30
  if not image_url or not image_url.startswith("http"):
31
- raise ValueError("Invalid or missing 'url' field. Please provide a valid image URL.")
32
-
33
- # Load and process the image
34
  image = self.load_image(image_url)
35
-
36
- # Prepare inputs for the Florence model
37
- inputs = self.processor(
38
- text=task_prompt,
39
  images=image,
40
  return_tensors="pt"
41
  ).to(self.device)
42
 
43
- # Generate detailed caption using Florence
44
  generated_ids = self.model.generate(
45
- input_ids=inputs["input_ids"],
46
- pixel_values=inputs["pixel_values"],
47
- max_new_tokens=512, # Adjust token limit for detailed captions
48
- num_beams=3, # Use beam search for better captions
49
- early_stopping=True # Stop when the best output is found
 
50
  )
51
 
52
- # Decode the generated text
53
  generated_text = self.processor.batch_decode(
54
  generated_ids,
55
  skip_special_tokens=True
56
  )[0]
57
-
58
  return {"caption": generated_text}
59
-
60
  except Exception as e:
61
  return {"error": str(e)}
62
 
63
  def load_image(self, image_url):
64
  try:
65
- # Load image from URL
66
  response = requests.get(image_url, stream=True)
67
- response.raise_for_status() # Raise an error for failed requests
68
- image = Image.open(response.raw).convert("RGB")
69
- return image
70
  except Exception as e:
71
- raise ValueError(f"Failed to load image from URL: {image_url}. Error: {e}")
 
3
  import requests
4
  import torch
5
 
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
9
  self.model = AutoModelForCausalLM.from_pretrained(
10
  model_dir,
11
  trust_remote_code=True
12
+ ).eval().to(device)
 
13
  self.processor = AutoProcessor.from_pretrained(
14
  model_dir,
15
  trust_remote_code=True
 
18
 
19
  def __call__(self, data):
20
  try:
21
+ inputs_data = data.get("inputs", {})
22
+ params = data.get("parameters", {})
23
+
24
+ image_url = inputs_data.get("url")
25
  if not image_url or not image_url.startswith("http"):
26
+ raise ValueError("Invalid or missing 'url' field")
27
+
 
28
  image = self.load_image(image_url)
29
+ model_inputs = self.processor(
30
+ text=inputs_data.get("task_prompt", "<MORE_DETAILED_CAPTION>"),
 
 
31
  images=image,
32
  return_tensors="pt"
33
  ).to(self.device)
34
 
 
35
  generated_ids = self.model.generate(
36
+ input_ids=model_inputs["input_ids"],
37
+ pixel_values=model_inputs["pixel_values"],
38
+ max_new_tokens=params.get("max_new_tokens", 512),
39
+ num_beams=params.get("num_beams", 3),
40
+ early_stopping=params.get("early_stopping", True),
41
+ do_sample=params.get("do_sample", False)
42
  )
43
 
 
44
  generated_text = self.processor.batch_decode(
45
  generated_ids,
46
  skip_special_tokens=True
47
  )[0]
48
+
49
  return {"caption": generated_text}
50
+
51
  except Exception as e:
52
  return {"error": str(e)}
53
 
54
  def load_image(self, image_url):
55
  try:
 
56
  response = requests.get(image_url, stream=True)
57
+ response.raise_for_status()
58
+ return Image.open(response.raw).convert("RGB")
 
59
  except Exception as e:
60
+ raise ValueError(f"Failed to load image: {str(e)}")