File size: 4,945 Bytes
731a050
 
 
 
 
 
d7968e8
 
731a050
 
c380201
 
 
 
731a050
a3f5003
731a050
 
 
 
a3f5003
731a050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c380201
 
 
 
 
 
731a050
a3f5003
731a050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c380201
 
 
 
 
 
 
731a050
 
a3f5003
731a050
 
 
 
 
 
 
 
 
 
 
 
a3f5003
731a050
 
 
 
 
a3f5003
731a050
 
a3f5003
731a050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d64b8bb
731a050
 
 
 
 
 
d64b8bb
731a050
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import requests
import tempfile
from pathlib import Path
import secrets
from PIL import Image
import gradio as gr

# Set your Hugging Face API token
# Read the token from the environment so secrets never live in source control.
HUGGING_FACE_API_KEY = os.getenv("HUGGING_FACE_API_KEY")

# Fail fast at import time: every API call below requires this token.
if not HUGGING_FACE_API_KEY:
    raise ValueError("Please set the Hugging Face API key in the environment as 'HUGGING_FACE_API_KEY'.")

# Module-level chat history shared by process_image() / get_math_response().
# NOTE(review): a single global means all concurrent users share one history —
# acceptable for a single-user demo, not for multi-user deployment.
math_messages = []

# Function to process the image with Hugging Face API
def process_image(image, shouldConvert=False):
    """Caption an uploaded image via the Hugging Face BLIP inference API.

    Args:
        image: A PIL image (RGBA when coming from the sketchpad).
        shouldConvert: When True, flatten the image onto a white RGB
            background first (JPEG cannot store an alpha channel).

    Returns:
        The generated caption string, or an "Error: ..." message on failure.
    """
    global math_messages
    math_messages = []  # Reset the conversation whenever a new image arrives

    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(Path(tempfile.gettempdir()) / "gradio")
    os.makedirs(uploaded_file_dir, exist_ok=True)

    name = f"tmp{secrets.token_hex(20)}.jpg"
    filename = os.path.join(uploaded_file_dir, name)

    # Flatten RGBA sketch output onto white before saving as JPEG.
    if shouldConvert:
        new_img = Image.new('RGB', (image.width, image.height), (255, 255, 255))
        new_img.paste(image, (0, 0), mask=image)
        image = new_img
    image.save(filename)

    # The HF inference API expects the raw image bytes as the request body
    # (not a multipart form upload). try/finally guarantees temp-file cleanup
    # even if the request raises.
    try:
        with open(filename, "rb") as img_file:
            response = requests.post(
                "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base",
                headers={"Authorization": f"Bearer {HUGGING_FACE_API_KEY}"},
                data=img_file.read(),
                timeout=60,
            )
    finally:
        os.remove(filename)  # Clean up temp file

    # Check if response is successful and handle errors
    if response.status_code == 200:
        # Captioning models return a JSON list: [{"generated_text": "..."}]
        result = response.json()
        if isinstance(result, list) and result:
            caption = result[0].get("generated_text", "No description available.")
        elif isinstance(result, dict):
            caption = result.get("generated_text", "No description available.")
        else:
            caption = "No description available."
    else:
        # Error bodies are not always JSON (e.g. gateway HTML pages).
        try:
            detail = response.json().get('error', 'Unknown error')
        except ValueError:
            detail = response.text or 'Unknown error'
        caption = f"Error: {response.status_code} - {detail}"
    return caption

# Function for getting math responses from Hugging Face's text generation API
def get_math_response(image_description, user_question):
    """Stream a text-generation answer for the user's math question.

    Args:
        image_description: Caption of the uploaded image, or None/"" if absent.
        user_question: The question typed by the user.

    Yields:
        A single answer string (or an "Error: ..." message), then records it
        in the shared ``math_messages`` history.
    """
    global math_messages
    if not math_messages:
        math_messages.append({"role": "system", "content": "You are a helpful math assistant."})

    # Prepare the query content
    content = f"Image description: {image_description}\n\n" if image_description else ""
    query = f"{content}User question: {user_question}"
    math_messages.append({"role": "user", "content": query})

    # Make the text generation call
    payload = {
        "inputs": query,
        "parameters": {"max_length": 100, "temperature": 0.7},
    }
    response = requests.post(
        "https://api-inference.huggingface.co/models/gpt2",
        headers={"Authorization": f"Bearer {HUGGING_FACE_API_KEY}"},
        json=payload,
        timeout=60,
    )

    # Check if response is successful and handle errors
    if response.status_code == 200:
        # Text-generation models return a JSON list: [{"generated_text": "..."}]
        result = response.json()
        if isinstance(result, list) and result:
            answer = result[0].get("generated_text", "Sorry, I couldn't generate a response.")
        else:
            answer = "Sorry, I couldn't generate a response."
    else:
        # Error bodies are not always JSON (e.g. gateway HTML pages).
        try:
            detail = response.json().get('error', 'Unknown error')
        except ValueError:
            detail = response.text or 'Unknown error'
        answer = f"Error: {response.status_code} - {detail}"

    yield answer
    math_messages.append({"role": "assistant", "content": answer})

def math_chat_bot(image, sketchpad, question, state):
    """Describe whichever input tab is active, then stream the model's answer.

    Tab 0 = uploaded image, tab 1 = sketchpad composite; if neither yields an
    image, the question is answered without an image description.
    """
    active_tab = state["tab_index"]

    description = None
    if active_tab == 0 and image:
        description = process_image(image)
    elif active_tab == 1 and sketchpad and sketchpad["composite"]:
        # Sketchpad output carries alpha, so request the RGB conversion.
        description = process_image(sketchpad["composite"], True)

    # Delegate streaming to the text-generation helper.
    yield from get_math_response(description, question)

# Force KaTeX display-mode math inside the answer pane (#qwen-md) to render
# inline, so formulas flow with the surrounding markdown text.
css = """
#qwen-md .katex-display { display: inline; }
#qwen-md .katex-display>.katex { display: inline; }
#qwen-md .katex-display>.katex>.katex-html { display: inline; }
"""

def tabs_select(e: gr.SelectData, _state):
    """Remember which input tab (0 = Upload, 1 = Sketch) the user activated.

    Mutates the shared state dict in place; Gradio passes the same object back
    on the next event, so no return value is needed.
    """
    _state.update(tab_index=e.index)

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""<center><font size=8>📖 Math Assistant Demo</center>""")
    # Per-session state tracking the active input tab (0 = Upload, 1 = Sketch).
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            # Gradio injects the SelectData event arg; only state is an input.
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Input your question")
            with gr.Row():
                with gr.Column():
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            # elem_id matches the #qwen-md CSS selectors defined above.
            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
        # math_chat_bot is a generator, so the Markdown pane streams its output.
        submit_btn.click(
            fn=math_chat_bot,
            inputs=[input_image, input_sketchpad, input_text, state],
            outputs=output_md
        )
demo.launch()