# app.py — Gradio image-OCR demo for a Hugging Face Space (commit d5a7e96).
# (The lines previously here were Hugging Face file-viewer chrome accidentally
# captured with the source — "raw / history / blame", author, size — not code.)
#!/usr/bin/env python3
import os
import json
import base64
import requests
import gradio as gr
from PIL import Image
from io import BytesIO
# Get environment variables from HF Spaces secrets.
# ENDPOINT: full URL of the vLLM OpenAI-compatible chat-completions route.
# MODEL: model identifier sent in every request payload.
ENDPOINT = os.environ.get("VLLM_ENDPOINT")
MODEL = os.environ.get("VLLM_MODEL")
# Fail fast at import time so a misconfigured Space surfaces a clear error
# instead of failing on the first request.
if not ENDPOINT or not MODEL:
    raise ValueError("VLLM_ENDPOINT and VLLM_MODEL environment variables must be set. Please add them as secrets in your Space settings.")
def image_to_base64(image):
    """Serialize a PIL image to PNG and return it as a base64 string."""
    with BytesIO() as png_buffer:
        # PNG is lossless and universally accepted by vision endpoints.
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def process_image(image, temperature):
    """Send *image* to the vLLM chat-completions endpoint and stream the reply.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.
        temperature: sampling temperature forwarded verbatim to the model.

    Yields:
        ``(rendered, raw)`` tuples; both elements carry the full accumulated
        response so the Markdown pane and the raw textbox stay in sync.
    """
    if image is None:
        yield "Please upload an image first.", ""
        return

    # Convert image to base64 for the data URL the OpenAI-style API expects.
    b64_image = image_to_base64(image)

    # Image-only request: the text part is intentionally empty.
    payload = {
        "model": MODEL,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": ""},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_image}"}}
                ]
            }
        ],
        "temperature": temperature,
        "stream": True
    }

    try:
        # FIX: the original call had no timeout, so a dead endpoint would hang
        # the worker forever; (connect, read) timeouts bound both phases.
        # FIX: `with` guarantees the streaming connection is closed even when
        # the generator exits early (e.g. on '[DONE]' or a client disconnect).
        with requests.post(
            ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=payload,
            stream=True,
            timeout=(10, 300),
        ) as response:
            response.raise_for_status()
            accumulated_response = ""
            # The endpoint speaks server-sent events: 'data: {json}' lines,
            # terminated by a literal 'data: [DONE]'.
            for line in response.iter_lines():
                if not line:
                    continue  # keep-alive blank lines between events
                line = line.decode('utf-8')
                if not line.startswith('data: '):
                    continue
                line = line[6:]  # Remove 'data: ' prefix
                if line.strip() == '[DONE]':
                    break
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed / partial frames
                if chunk.get('choices'):
                    content = chunk['choices'][0].get('delta', {}).get('content', '')
                    if content:
                        accumulated_response += content
                        yield accumulated_response, accumulated_response
    except Exception as e:
        # UI boundary: surface any network/HTTP/parse failure in both panes
        # rather than crashing the Gradio worker.
        yield f"Error: {str(e)}", f"Error: {str(e)}"
# Build the Gradio Interface.
# Layout: image + controls in the left column, rendered Markdown on the right,
# with a copyable raw-text box in a second row underneath.
with gr.Blocks(title="📖 Image OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 📖 Image to Text Extraction
**💡 How to use:**
1. Upload an image using the upload box
2. Adjust temperature if needed
3. Click "Extract Text" to process
The model will extract and format text from your image.
"""
    )
    with gr.Row():
        with gr.Column():
            # Delivered to process_image as a PIL image (type="pil").
            image_input = gr.Image(
                type="pil",
                label="🖼️ Upload Image",
                sources=["upload", "clipboard"],
                height=400
            )
            # Low default (0.15) favors deterministic OCR-style output.
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.15,
                step=0.05,
                label="Temperature"
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column():
            # Rendered view; the placeholder HTML keeps the pane visibly sized
            # before the first result arrives.
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="<div style='min-height: 400px; padding: 10px; border: 1px solid #e0e0e0; border-radius: 4px; background-color: #f9f9f9;'><em>Extracted text will appear here...</em></div>",
                height=500
            )
    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=15,
                show_copy_button=True
            )
    # Event handlers.
    # process_image is a generator, so Gradio streams each yielded tuple into
    # both output components as tokens arrive.
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, temperature],
        outputs=[output_text, raw_output]
    )
    # Reset all three components: image -> None, both text panes -> empty.
    clear_btn.click(
        fn=lambda: (None, "", ""),
        outputs=[image_input, output_text, raw_output]
    )
    gr.Markdown("""
---
**Note:** Configure endpoint via `VLLM_ENDPOINT` and `VLLM_MODEL` environment variables.
""")
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()