Rishabh2234 committed on
Commit
f90ca7e
·
1 Parent(s): 87cea5c

files for inference generation

Browse files
Files changed (2) hide show
  1. app.py +2 -11
  2. model.py +1 -1
app.py CHANGED
@@ -4,14 +4,12 @@ from PIL import Image
4
  import torchvision.transforms as transforms
5
  import torch
6
  from model import load_model
7
- import os
8
 
9
- # Initialize FastAPI
10
  app = FastAPI()
11
 
12
- # Set device and checkpoint path (use a relative path so it stays within your Space's storage)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- checkpoint_path = "checkpoint.pth"
15
 
16
  # Load the model and tokenizer
17
  model, tokenizer = load_model(checkpoint_path, device)
@@ -30,18 +28,11 @@ def read_root():
30
  @app.post("/generate_caption/")
31
  async def generate_caption(file: UploadFile = File(...)):
32
  try:
33
- # Read image file from the request
34
  contents = await file.read()
35
  image = Image.open(io.BytesIO(contents)).convert("RGB")
36
-
37
- # Preprocess the image
38
  image_tensor = transform(image).unsqueeze(0).to(device)
39
-
40
- # Generate caption using your model's generate method
41
  output_ids = model.generate(pixel_values=image_tensor, max_length=30, num_beams=4)
42
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
43
-
44
  return {"caption": caption}
45
-
46
  except Exception as e:
47
  return {"error": str(e)}
 
4
  import torchvision.transforms as transforms
5
  import torch
6
  from model import load_model
 
7
 
 
8
  app = FastAPI()
9
 
10
+ # Set device and use a writable checkpoint path (e.g., /tmp)
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ checkpoint_path = "/tmp/checkpoint.pth" # Updated path
13
 
14
  # Load the model and tokenizer
15
  model, tokenizer = load_model(checkpoint_path, device)
 
28
  @app.post("/generate_caption/")
29
  async def generate_caption(file: UploadFile = File(...)):
30
  try:
 
31
  contents = await file.read()
32
  image = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
33
  image_tensor = transform(image).unsqueeze(0).to(device)
 
 
34
  output_ids = model.generate(pixel_values=image_tensor, max_length=30, num_beams=4)
35
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
36
  return {"caption": caption}
 
37
  except Exception as e:
38
  return {"error": str(e)}
model.py CHANGED
@@ -55,7 +55,6 @@ class ViTT5(nn.Module):
55
  temperature=0.9, # More diverse outputs
56
  **kwargs
57
  )
58
-
59
  def download_checkpoint(checkpoint_path):
60
  """
61
  Downloads the checkpoint from Hugging Face Model Hub if not found locally.
@@ -72,6 +71,7 @@ def download_checkpoint(checkpoint_path):
72
  else:
73
  raise RuntimeError(f"Error downloading model, status code: {response.status_code}")
74
 
 
75
  def load_model(checkpoint_path, device):
76
  """
77
  Loads the ViTT5 model along with the T5 tokenizer.
 
55
  temperature=0.9, # More diverse outputs
56
  **kwargs
57
  )
 
58
  def download_checkpoint(checkpoint_path):
59
  """
60
  Downloads the checkpoint from Hugging Face Model Hub if not found locally.
 
71
  else:
72
  raise RuntimeError(f"Error downloading model, status code: {response.status_code}")
73
 
74
+
75
  def load_model(checkpoint_path, device):
76
  """
77
  Loads the ViTT5 model along with the T5 tokenizer.