SebastianRuff commited on
Commit
b92e4f6
·
verified ·
1 Parent(s): 4ffd51d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -34
handler.py CHANGED
@@ -3,17 +3,18 @@ from PIL import Image
3
  import requests
4
  import torch
5
 
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
8
- # Check if a GPU is available; use CPU if not
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
-
11
- # Load the model with trust_remote_code=True
12
  self.model = AutoModelForCausalLM.from_pretrained(
13
  model_dir,
14
  trust_remote_code=True
15
  ).eval().to(device) # Dynamically move to the correct device
16
-
17
  self.processor = AutoProcessor.from_pretrained(
18
  model_dir,
19
  trust_remote_code=True
@@ -21,37 +22,50 @@ class EndpointHandler:
21
  self.device = device
22
 
23
  def __call__(self, data):
24
- # Extract inputs from the request data
25
- task_prompt = data.get("task_prompt", "<MORE_DETAILED_CAPTION>")
26
- image_url = data.get("image_url")
27
-
28
- # Load and process the image
29
- image = self.load_image(image_url)
30
-
31
- # Prepare inputs for the model
32
- inputs = self.processor(
33
- text=task_prompt,
34
- images=image,
35
- return_tensors="pt"
36
- ).to(self.device) # Use the correct device
37
-
38
- # Generate output
39
- generated_ids = self.model.generate(
40
- input_ids=inputs["input_ids"],
41
- pixel_values=inputs["pixel_values"],
42
- max_new_tokens=1024,
43
- num_beams=3,
44
- )
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Decode and post-process the output
47
- generated_text = self.processor.batch_decode(
48
- generated_ids,
49
- skip_special_tokens=True
50
- )[0]
51
 
52
- return {"caption": generated_text}
 
53
 
54
  def load_image(self, image_url):
55
- # Load image from the provided URL
56
- image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
57
- return image
 
 
 
 
 
 
3
  import requests
4
  import torch
5
 
6
+
7
class EndpointHandler:
    def __init__(self, model_dir):
        """Load the Florence model and its processor from *model_dir*.

        The model is placed on CUDA when available, otherwise it stays on
        CPU; the chosen device is remembered for later tensor placement.
        """
        # Prefer the GPU when one is present.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # trust_remote_code=True is required: Florence checkpoints ship
        # their own modeling/processing code alongside the weights.
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model = model.eval().to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )
23
 
24
  def __call__(self, data):
25
+ try:
26
+ # Extract inputs from the request data
27
+ task_prompt = data.get("task_prompt", "<MORE_DETAILED_CAPTION>")
28
+ image_url = data.get("url") # Match the key sent from n8n
29
+
30
+ if not image_url or not image_url.startswith("http"):
31
+ raise ValueError("Invalid or missing 'url' field. Please provide a valid image URL.")
32
+
33
+ # Load and process the image
34
+ image = self.load_image(image_url)
35
+
36
+ # Prepare inputs for the Florence model
37
+ inputs = self.processor(
38
+ text=task_prompt,
39
+ images=image,
40
+ return_tensors="pt"
41
+ ).to(self.device)
42
+
43
+ # Generate detailed caption using Florence
44
+ generated_ids = self.model.generate(
45
+ input_ids=inputs["input_ids"],
46
+ pixel_values=inputs["pixel_values"],
47
+ max_new_tokens=512, # Adjust token limit for detailed captions
48
+ num_beams=3, # Use beam search for better captions
49
+ early_stopping=True # Stop when the best output is found
50
+ )
51
+
52
+ # Decode the generated text
53
+ generated_text = self.processor.batch_decode(
54
+ generated_ids,
55
+ skip_special_tokens=True
56
+ )[0]
57
 
58
+ return {"caption": generated_text}
 
 
 
 
59
 
60
+ except Exception as e:
61
+ return {"error": str(e)}
62
 
63
  def load_image(self, image_url):
64
+ try:
65
+ # Load image from URL
66
+ response = requests.get(image_url, stream=True)
67
+ response.raise_for_status() # Raise an error for failed requests
68
+ image = Image.open(response.raw).convert("RGB")
69
+ return image
70
+ except Exception as e:
71
+ raise ValueError(f"Failed to load image from URL: {image_url}. Error: {e}")