| import gradio as gr |
| import functools |
| from inference import CaptionGenerator |
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_model():
    """Construct the caption generator, reusing it on every later call.

    The function takes no arguments, so ``lru_cache(maxsize=1)`` stores its
    single result and every subsequent call returns that same
    ``CaptionGenerator`` instance instead of rebuilding it.
    """
    settings = dict(
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True,
    )
    return CaptionGenerator(**settings)
|
|
|
|
# Build the captioning model once, at import time. load_model() is memoized
# (lru_cache(maxsize=1)), so any later call to it returns this same instance.
generator = load_model()
|
|
|
|
|
|
def caption_image(img):
    """Return a caption for the uploaded image.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        Whatever ``generator.generate`` produces for the image; the
        interface displays it in the "Generated Caption" textbox.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message to the user instead of an internal model failure.
    """
    # Gradio passes None for an empty image component. Don't hand that to
    # the model, which presumably expects a real image object.
    if img is None:
        raise gr.Error("Please upload an image first.")
    return generator.generate(img)
|
|
|
|
# Single-input / single-output UI: a PIL image in, the generated caption out.
demo = gr.Interface(
    fn=caption_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title="Image Captioning with Transformers",
    description="Upload an image and get an AI-generated caption.",
)

# Only start the (blocking) web server when this file is run as a script;
# importing the module, e.g. from tests or another script, should not launch it.
if __name__ == "__main__":
    demo.launch()
|
|