arshadrana committed on
Commit
731a050
·
verified ·
1 Parent(s): a3f5003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -55
app.py CHANGED
@@ -1,64 +1,116 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
 
5
- # Use a pipeline as a high-level helper
6
- from transformers import pipeline
 
7
 
8
- messages = [
9
- {"role": "user", "content": "Who are you?"},
10
- ]
11
- pipe = pipeline("text-generation", model="Qwen/Qwen2.5-Math-1.5B")
12
- pipe(messages)
13
- # Load the model and tokenizer
14
- # model_name = "Qwen/Qwen2-Math-1.5B"
15
- # device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # model = AutoModelForCausalLM.from_pretrained(
18
- # model_name,
19
- # torch_dtype="auto",
20
- # device_map="auto"
21
- # ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # # Define a function for Gradio to handle user input
26
- # def solve_math(prompt):
27
- # messages = [
28
- # {"role": "system", "content": "You are a helpful assistant."},
29
- # {"role": "user", "content": prompt}
30
- # ]
31
- # text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
32
- # model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
 
 
 
33
 
34
- # generation_config = GenerationConfig(
35
- # do_sample=False, # For greedy decoding
36
- # max_new_tokens=512
37
- # )
 
38
 
39
- # generated_ids = model.generate(
40
- # **model_inputs,
41
- # generation_config=generation_config
42
- # )
43
 
44
- # # Remove the input tokens from the output
45
- # generated_ids = [
46
- # output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
47
- # ]
48
-
49
- # # Decode the generated output and return the result
50
- # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
51
- # return response
52
-
53
- # # Create the Gradio interface
54
- # iface = gr.Interface(
55
- # fn=solve_math, # Function to call
56
- # inputs="text", # Text input for the user prompt
57
- # outputs="text", # Text output for the model's response
58
- # title="Math Solver", # App title
59
- # description="Provide a math problem and the model will solve it."
60
- # )
61
-
62
- # Launch the app
63
- if __name__ == "__main__":
64
- iface.launch()
 
 
 
 
 
 
import os
import requests
import tempfile
from pathlib import Path
import secrets
from PIL import Image
import gradio as gr

# Hugging Face Inference API token, read from the environment.
# Set HUGGING_FACE_API_KEY before launching the app; requests are sent
# with this value as a Bearer token.
HUGGING_FACE_API_KEY = os.getenv("HUGGING_FACE_API_KEY")

# Module-level conversation history for the math chat.
# Reset to [] whenever a new image is processed (see process_image).
math_messages = []
12
 
13
# Function to caption an image with the Hugging Face Inference API
def process_image(image, shouldConvert=False):
    """Save *image* to a temp file, send it to the BLIP captioning model
    and return the generated caption.

    Args:
        image: PIL.Image to describe.
        shouldConvert: when True, flatten an image with transparency
            (e.g. a sketchpad drawing) onto a white RGB background so it
            can be saved as JPEG.

    Returns:
        The caption text, or a fallback/error message string.

    Side effects:
        Resets the module-level ``math_messages`` history.
    """
    global math_messages
    math_messages = []  # Reset chat history when a new image is uploaded

    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )
    os.makedirs(uploaded_file_dir, exist_ok=True)

    name = f"tmp{secrets.token_hex(20)}.jpg"
    filename = os.path.join(uploaded_file_dir, name)

    # Flatten transparency onto white (JPEG has no alpha channel).
    if shouldConvert:
        new_img = Image.new('RGB', (image.width, image.height), (255, 255, 255))
        new_img.paste(image, (0, 0), mask=image)
        image = new_img
    image.save(filename)

    try:
        # BUG FIX: the Inference API expects the raw image bytes as the
        # request body (data=...), not a multipart "files" upload.
        with open(filename, "rb") as img_file:
            response = requests.post(
                "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base",
                headers={"Authorization": f"Bearer {HUGGING_FACE_API_KEY}"},
                data=img_file.read(),
                timeout=60,
            )
    finally:
        os.remove(filename)  # Clean up the temp file even if the request fails

    if not response.ok:
        return f"Image captioning failed (HTTP {response.status_code})."

    # BUG FIX: image-to-text models return a list like
    # [{"generated_text": "..."}], so calling .get() directly on the
    # parsed JSON raised AttributeError. Handle both shapes defensively.
    result = response.json()
    if isinstance(result, list) and result and isinstance(result[0], dict):
        return result[0].get("generated_text", "No description available.")
    if isinstance(result, dict):
        return result.get("generated_text", "No description available.")
    return "No description available."
42
 
43
# Function for getting math responses from Hugging Face's text generation API
def get_math_response(image_description, user_question):
    """Query the text-generation Inference API and yield the answer.

    Args:
        image_description: optional caption of an uploaded image, or None.
        user_question: the user's math question.

    Yields:
        The generated answer string (a single yield; the generator form
        lets Gradio stream the output).

    Side effects:
        Appends the user query and the assistant answer to the
        module-level ``math_messages`` history.
    """
    global math_messages
    if not math_messages:
        math_messages.append({"role": "system", "content": "You are a helpful math assistant."})

    # Prepare the query content
    content = f"Image description: {image_description}\n\n" if image_description else ""
    query = f"{content}User question: {user_question}"
    math_messages.append({"role": "user", "content": query})

    # Make the text generation call
    payload = {
        "inputs": query,
        "parameters": {"max_length": 100, "temperature": 0.7},
    }
    response = requests.post(
        "https://api-inference.huggingface.co/models/gpt2",
        headers={"Authorization": f"Bearer {HUGGING_FACE_API_KEY}"},
        json=payload,
        timeout=60,
    )

    # BUG FIX: text-generation models return [{"generated_text": "..."}]
    # (a list), so calling .get() on the parsed JSON raised
    # AttributeError. Index the list before reading the key, and fall
    # back gracefully on HTTP errors or unexpected shapes.
    fallback = "Sorry, I couldn't generate a response."
    answer = fallback
    if response.ok:
        result = response.json()
        if isinstance(result, list) and result and isinstance(result[0], dict):
            answer = result[0].get("generated_text", fallback)
        elif isinstance(result, dict):
            answer = result.get("generated_text", fallback)

    yield answer
    math_messages.append({"role": "assistant", "content": answer})
68
 
69
def math_chat_bot(image, sketchpad, question, state):
    """Route the active tab's image (if any) through captioning, then
    stream the math model's answer for *question*.

    Tab 0 uses the uploaded image; tab 1 uses the sketchpad composite
    (flattened onto white). With no image, only the question is sent.
    """
    tab = state["tab_index"]

    # Describe whichever image source matches the selected tab.
    description = None
    if tab == 0 and image:
        description = process_image(image)
    elif tab == 1 and sketchpad and sketchpad["composite"]:
        description = process_image(sketchpad["composite"], True)

    # Stream the answer from the text-generation helper.
    yield from get_math_response(description, question)
81
 
82
# CSS tweaks so KaTeX-rendered math displays inline (rather than as
# block elements) inside the Markdown output component (#qwen-md).
css = """
#qwen-md .katex-display { display: inline; }
#qwen-md .katex-display>.katex { display: inline; }
#qwen-md .katex-display>.katex>.katex-html { display: inline; }
"""
87
 
88
def tabs_select(e: gr.SelectData, _state):
    """Record the newly selected input tab's index in the session state.

    Mutates the state dict in place and returns nothing.
    NOTE(review): this relies on in-place mutation of a gr.State value
    persisting across events without being returned as an output —
    confirm this holds for the Gradio version in use.
    """
    _state["tab_index"] = e.index
 
 
90
 
91
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""<center><font size=8>📖 Math Assistant Demo</center>""")
    # Per-session state tracking which input tab (0=Upload, 1=Sketch)
    # is active; read by math_chat_bot, written by tabs_select.
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            # Gradio injects the SelectData event argument automatically;
            # only the state needs to be declared as an input.
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Input your question")
            with gr.Row():
                with gr.Column():
                    # BUG FIX: gr.Image is not iterable, so
                    # `[*input_image, ...]` raised a TypeError at build
                    # time — pass the components directly in a list.
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
    submit_btn.click(
        fn=math_chat_bot,
        # BUG FIX: same unpacking error as above — list the component,
        # don't splat it.
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md
    )
demo.launch()