TruongLeThanh commited on
Commit
5b72f69
·
1 Parent(s): b7dc11d
Files changed (1) hide show
  1. app.py +50 -20
app.py CHANGED
@@ -1,31 +1,61 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
  from PIL import Image
5
  import torch
6
- import io
7
-
8
- app = FastAPI()
9
-
10
- # Load model and processor
11
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
16
- @app.get("/")
17
- async def root():
18
- return {"message": "BLIP Image Captioning API is running"}
19
-
20
- @app.post("/predict/")
21
- async def predict_caption(file: UploadFile = File(...)):
22
- contents = await file.read()
23
- image = Image.open(io.BytesIO(contents)).convert("RGB")
24
-
25
  inputs = processor(images=image, return_tensors="pt").to(device)
26
  output = model.generate(**inputs, max_new_tokens=20)
27
  caption = processor.decode(output[0], skip_special_tokens=True)
28
-
29
- return JSONResponse(content={"caption": caption})
30
-
31
-
 
 
 
 
 
 
1
+ # from fastapi import FastAPI, File, UploadFile
2
+ # from fastapi.responses import JSONResponse
3
+ # from transformers import BlipProcessor, BlipForConditionalGeneration
4
+ # from PIL import Image
5
+ # import torch
6
+ # import io
7
+
8
+ # app = FastAPI()
9
+
10
+ # # Load model and processor
11
+ # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
+ # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
13
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ # model.to(device)
15
+
16
+ # @app.get("/")
17
+ # async def root():
18
+ # return {"message": "BLIP Image Captioning API is running"}
19
+
20
+ # @app.post("/predict/")
21
+ # async def predict_caption(file: UploadFile = File(...)):
22
+ # contents = await file.read()
23
+ # image = Image.open(io.BytesIO(contents)).convert("RGB")
24
+
25
+ # inputs = processor(images=image, return_tensors="pt").to(device)
26
+ # output = model.generate(**inputs, max_new_tokens=20)
27
+ # caption = processor.decode(output[0], skip_special_tokens=True)
28
+
29
+ # return JSONResponse(content={"caption": caption})
30
+
31
+
32
+
33
+
34
import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the BLIP captioning model and its processor once, at import time,
# so every request reuses the same weights.
MODEL_ID = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)

# Prefer the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
45
 
46
# Inference function used by the Gradio interface.
def predict_caption(image):
    """Generate a short caption for an uploaded image using BLIP.

    Args:
        image: A ``PIL.Image.Image`` supplied by Gradio, or ``None`` when
            the user submits without selecting a file.

    Returns:
        The decoded caption string (capped at 20 new tokens), or a short
        prompt asking for an image when none was provided.
    """
    # Gradio passes None when no image was uploaded; return a friendly
    # message instead of crashing with AttributeError on `image.mode`.
    if image is None:
        return "Please upload an image."

    # BLIP expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pure inference: skip autograd bookkeeping to save memory/time.
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        output = model.generate(**inputs, max_new_tokens=20)
    caption = processor.decode(output[0], skip_special_tokens=True)
    return caption
54
+
55
# Gradio UI: one image in, one plain-text caption out.
demo = gr.Interface(
    fn=predict_caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="BLIP Image Captioning",
    description="Tải ảnh lên và nhận mô tả tự động bằng BLIP từ Salesforce.",
)
demo.launch()