File size: 3,356 Bytes
726d8f2
fdaa09b
726d8f2
 
 
fdaa09b
 
 
 
 
 
 
 
a150c1b
726d8f2
a150c1b
 
 
fdaa09b
a150c1b
 
fdaa09b
726d8f2
fdaa09b
726d8f2
e93b7cc
726d8f2
 
 
a150c1b
fdaa09b
726d8f2
fdaa09b
726d8f2
 
 
a150c1b
 
fdaa09b
726d8f2
 
fdaa09b
726d8f2
 
 
 
 
a150c1b
fdaa09b
726d8f2
fdaa09b
726d8f2
 
fdaa09b
726d8f2
 
 
a150c1b
 
726d8f2
fdaa09b
726d8f2
 
fdaa09b
726d8f2
 
 
fdaa09b
 
54d49be
726d8f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdaa09b
726d8f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import logging
import gradio as gr
from huggingface_hub import InferenceClient

# --- Logging --------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Configuration from the environment -----------------------------------
# HF_TOKEN is optional; without it the Inference API is used anonymously
# (lower rate limits). Model ids can be overridden per deployment.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
CAPTION_MODEL = os.environ.get("CAPTION_MODEL", "Salesforce/blip-image-captioning-base")
VQA_MODEL = os.environ.get("VQA_MODEL", "dandelin/vilt-b32-finetuned-vqa")

logger.info(f"HF_TOKEN configured: {bool(HF_TOKEN)}")
logger.info(f"CAPTION_MODEL: {CAPTION_MODEL}")
logger.info(f"VQA_MODEL: {VQA_MODEL}")

# Authenticate only when a token was provided.
if HF_TOKEN:
    client = InferenceClient(token=HF_TOKEN)
else:
    client = InferenceClient()
logger.info("InferenceClient initialized")



def caption_image(image):
    """Generate a caption for the image.

    Args:
        image: PIL image from the Gradio component, or None when nothing
            has been uploaded yet.

    Returns:
        The generated caption string, or a user-facing prompt/error message.
    """
    logger.info(f"caption_image() called | image={image is not None}")

    # Guard clause: nothing to caption without an image.
    if image is None:
        logger.warning("No image provided")
        return "πŸ“· Upload an image first!"

    try:
        logger.info(f"Calling image_to_text | model={CAPTION_MODEL}")
        response = client.image_to_text(image, model=CAPTION_MODEL)
        caption = response.generated_text
        logger.info(f"Caption: {caption[:100]}")
        return caption
    except Exception as e:
        # Surface any API/network failure to the UI rather than crashing
        # the Gradio handler.
        logger.error(f"API error: {e}")
        return f"❌ Error: {e}"


def answer_question(image, question: str):
    """Answer a natural-language question about the image.

    Args:
        image: PIL image from the Gradio component, or None when nothing
            has been uploaded yet.
        question: Free-form question text; may be None or blank from the UI.

    Returns:
        A user-facing string: the top answer with its confidence, or a
        prompt/error message.
    """
    logger.info(f"answer_question() called | image={image is not None} | question={question[:30] if question else 'None'}")

    if image is None:
        logger.warning("No image provided")
        return "πŸ“· Upload an image first!"
    # Fix: also guard against question being None — the previous
    # `question.strip()` raised AttributeError in that case (the log line
    # above already anticipated None input).
    if not question or not question.strip():
        logger.warning("No question provided")
        return "❓ Ask a question!"

    try:
        logger.info(f"Calling visual_question_answering | model={VQA_MODEL}")
        result = client.visual_question_answering(image=image, question=question, model=VQA_MODEL)
        # Take the first returned candidate answer; an empty result list is
        # caught by the except below and reported as an error.
        top = result[0]
        logger.info(f"Answer: {top.answer} ({top.score:.1%})")
        return f"πŸ€– {top.answer} (confidence: {top.score:.1%})"
    except Exception as e:
        # Broad catch is deliberate: surface any API/network failure to the
        # UI rather than crashing the Gradio handler.
        logger.error(f"API error: {e}")
        return f"❌ Error: {e}"


# Assemble the two-column Gradio UI and wire the event handlers, then launch.
logger.info("Building Gradio interface...")

with gr.Blocks(title="Vision Chat") as demo:
    gr.Markdown("# πŸ‘οΈ Vision Chat\nUpload an image, get a caption, and ask questions about it!")

    with gr.Row(equal_height=True):
        # Left column: image input plus the caption trigger.
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="πŸ“· Your Image")
            caption_btn = gr.Button("✨ Generate Caption", variant="primary")

        # Right column: caption output and the question/answer widgets.
        with gr.Column(scale=1):
            caption_out = gr.Textbox(label="Caption", lines=2, interactive=False)
            question = gr.Textbox(label="❓ Ask a question", placeholder="What color is the animal?")
            ask_btn = gr.Button("Ask", variant="secondary")
            answer_out = gr.Textbox(label="Answer", lines=2, interactive=False)

    # Event wiring: each handler is defined earlier in this file.
    caption_btn.click(caption_image, inputs=img, outputs=caption_out)
    ask_btn.click(answer_question, inputs=[img, question], outputs=answer_out)
    # Pressing Enter in the question textbox behaves like clicking "Ask".
    question.submit(answer_question, inputs=[img, question], outputs=answer_out)

# Enable request queuing before launching the server.
demo.queue()
logger.info("Starting Gradio server...")
demo.launch()