|
|
import base64
import os
import subprocess
import sys
import threading
import time
import urllib.parse
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
|
|
|
|
|
|
|
|
# Best-effort install of flash-attn at startup; the app keeps running without it.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is layered on top of the current environment:
# passing a bare dict as ``env`` would REPLACE the whole environment (dropping
# PATH, HOME, virtualenv variables, ...). pip is run through the interpreter that
# is executing this app, so the package lands in the same environment, and no
# shell is needed.
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=True,
    )
except (subprocess.CalledProcessError, OSError) as e:
    # CalledProcessError: pip exited non-zero; OSError: the interpreter could not be spawned.
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")
|
|
|
|
|
|
|
|
# Run on the GPU when torch can see one; models and input tensors are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_florence(repo_id, label):
    """Load a Florence-2 checkpoint and its processor from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. ``'microsoft/Florence-2-base'``.
        label: Display name used in log messages (``"Base"`` / ``"Large"``).

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` if either load
        fails, so callers can detect an unavailable model with ``is None``.
    """
    try:
        # trust_remote_code lets transformers run the code bundled with the checkpoint repo.
        model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading {label.lower()} model: {e}")
        return None, None
    # NOTE(review): the leading "β" looks like a mis-encoded check-mark glyph; kept as-is.
    print(f"β {label} model loaded successfully")
    return model, processor


vision_language_model_base, vision_language_processor_base = _load_florence('microsoft/Florence-2-base', "Base")
vision_language_model_large, vision_language_processor_large = _load_florence('microsoft/Florence-2-large', "Large")
|
|
|
|
|
def load_image_from_url(image_url):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        ValueError: if the request fails, the server returns an HTTP error
            status, or the payload cannot be decoded as an image. The
            underlying exception is chained as ``__cause__`` so its traceback
            is preserved for debugging.
    """
    try:
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}") from e
|
|
|
|
|
def process_image_description(model, processor, image, task="<MORE_DETAILED_CAPTION>"):
    """Run a Florence-2 task prompt on ``image`` and return the parsed result.

    Args:
        model: A loaded Florence-2 model; it is expected to live on ``device``,
            since the processed inputs are moved there.
        processor: The matching ``AutoProcessor``.
        image: A ``PIL.Image.Image``, or anything ``Image.fromarray`` accepts.
        task: Florence-2 task prompt token. The same token builds the prompt
            and selects the entry returned by ``post_process_generation``.
            Defaults to the detailed-caption task.

    Returns:
        The value stored under ``task`` by ``processor.post_process_generation``
        (the caption string for the default task).
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise palette / RGBA / greyscale inputs to RGB, as the URL and
    # background-worker paths already do before captioning.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=task, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    # Special tokens are kept, presumably so post_process_generation can parse
    # the task markup — TODO confirm against the Florence-2 processor.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    return parsed[task]
|
|
|
|
|
def describe_image(uploaded_image, model_choice):
    """Caption an image uploaded through the Gradio UI.

    Returns the caption text. When no image was supplied, the model name is
    unknown, the selected model failed to load, or captioning raised, a
    human-readable message string is returned instead.
    """
    if uploaded_image is None:
        return "Please upload an image."

    known_choices = ("Florence-2-base", "Florence-2-large")
    if model_choice not in known_choices:
        return "Invalid model choice."

    if model_choice == "Florence-2-base":
        chosen_model, chosen_processor = vision_language_model_base, vision_language_processor_base
        missing_message = "Base model failed to load."
    else:
        chosen_model, chosen_processor = vision_language_model_large, vision_language_processor_large
        missing_message = "Large model failed to load."

    if chosen_model is None:
        return missing_message

    try:
        caption = process_image_description(chosen_model, chosen_processor, uploaded_image)
    except Exception as err:
        return f"Error generating caption: {err}"
    return caption
|
|
|
|
|
def describe_image_from_url(image_url, model_choice):
    """Caption the image at ``image_url`` and return a JSON-serialisable dict.

    Args:
        image_url: URL of the image. Surrounding whitespace (e.g. a trailing
            newline typed into the multi-line URL textbox) is ignored.
        model_choice: ``"Florence-2-base"`` or ``"Florence-2-large"``.

    Returns:
        On success: ``{"status": "success", "model": ..., "caption": ...,
        "image_size": {"width": ..., "height": ...}}``.
        On any failure: ``{"error": <message>}``. This function does not raise.
    """
    try:
        if isinstance(image_url, str):
            image_url = image_url.strip()
        if not image_url:
            return {"error": "image_url is required"}

        if model_choice not in ["Florence-2-base", "Florence-2-large"]:
            return {"error": "Invalid model choice. Use 'Florence-2-base' or 'Florence-2-large'"}

        # Resolve the model BEFORE downloading, so an unavailable model does
        # not cost a (possibly large) network fetch.
        if model_choice == "Florence-2-base":
            model = vision_language_model_base
            processor = vision_language_processor_base
            if model is None:
                return {"error": "Base model not available"}
        else:
            model = vision_language_model_large
            processor = vision_language_processor_large
            if model is None:
                return {"error": "Large model not available"}

        image = load_image_from_url(image_url)
        caption = process_image_description(model, processor, image)

        return {
            "status": "success",
            "model": model_choice,
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
        }
    except Exception as e:
        return {"error": f"Error processing image: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Remote services used by the background captioning worker; each one can be
# overridden through the environment.
# Serves frame downloads (/download) and the /middleware next/release endpoints.
IMAGE_SERVER_BASE = os.environ.get("IMAGE_SERVER_BASE", "https://fred808-vssee.hf.space")
# Receives caption submissions on /submit.
DATA_COLLECTION_BASE = os.environ.get("DATA_COLLECTION_BASE", "https://fred808-flow.hf.space")
# Identifies this worker to the image server; defaults to a per-process id.
REQUESTER_ID = os.environ.get("FLO_REQUESTER_ID", "florence-2-%d" % os.getpid())
# Checkpoint used by the background worker; any value other than
# "Florence-2-base" makes the worker use the large model.
MODEL_CHOICE = os.environ.get("FLO_MODEL_CHOICE", "Florence-2-base")
|
|
|
|
|
|
|
|
def _build_download_url(course: str, video: str, frame: str) -> str:
    """Return the image-server ``/download`` URL for one extracted frame.

    The frame is addressed as ``frame:<course>/<video>/<frame>`` in the
    ``file`` query parameter. Both query values are fully percent-encoded
    (``safe=''``), so ``/`` and ``:`` are escaped as well.
    """
    def _encode(value):
        return urllib.parse.quote(value, safe='')

    server = IMAGE_SERVER_BASE.rstrip('/')
    frame_ref = "frame:{}/{}/{}".format(course, video, frame)
    return "{}/download?course={}&file={}".format(server, _encode(course), _encode(frame_ref))
|
|
|
|
|
|
|
|
def _download_bytes(url: str, timeout: int = 30):
    """Fetch ``url`` and return ``(body_bytes, content_type_header)``.

    Best-effort: any failure (network error, HTTP error status, ...) is logged
    and reported as ``(None, None)`` instead of raising.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
        payload, mime = resp.content, resp.headers.get('content-type')
    except Exception as exc:
        print(f"[BACKGROUND] download failed {url}: {exc}")
        return None, None
    return payload, mime
|
|
|
|
|
|
|
|
def _post_submit(caption: str, image_name: str, course: str, image_url: str, image_bytes: bytes):
    """POST a caption plus its source image to the data-collection ``/submit`` endpoint.

    The request is multipart: the image bytes go in the ``image`` file field
    and the caption metadata in the form fields.

    Returns:
        ``(status_code, body)``, where ``body`` is the decoded JSON response or
        the raw text when the response is not JSON. Returns ``(None, None)``
        when the request itself fails; that error is logged, not raised.
    """
    endpoint = DATA_COLLECTION_BASE.rstrip('/') + "/submit"
    upload = {'image': (image_name, image_bytes, 'application/octet-stream')}
    form_fields = {
        'caption': caption,
        'image_name': image_name,
        'course': course,
        'image_url': image_url,
    }
    try:
        resp = requests.post(endpoint, data=form_fields, files=upload, timeout=30)
        try:
            body = resp.json()
        except Exception:
            body = resp.text
        return resp.status_code, body
    except Exception as exc:
        print(f"[BACKGROUND] submit POST failed: {exc}")
        return None, None
|
|
|
|
|
|
|
|
def _post_release(kind: str, *segments):
    """POST to the image server's ``/middleware/release/<kind>/...`` endpoint.

    Best-effort: failures are logged rather than raised, so the background loop
    keeps running. A non-2xx response is treated as a failure, so a rejected
    release shows up in the logs instead of being silently ignored.

    Args:
        kind: ``"frame"`` or ``"course"``; used in the URL and the log message.
        segments: Path segments, each percent-encoded with ``safe=''``.
    """
    try:
        encoded = "/".join(urllib.parse.quote(segment, safe='') for segment in segments)
        release_url = f"{IMAGE_SERVER_BASE.rstrip('/')}/middleware/release/{kind}/{encoded}"
        r = requests.post(release_url, params={"requester_id": REQUESTER_ID}, timeout=10)
        r.raise_for_status()
    except Exception as e:
        print(f"[BACKGROUND] release {kind} failed: {e}")


def _release_frame(course: str, video: str, frame: str):
    """Tell the image server this requester is done with one frame (best-effort)."""
    _post_release("frame", course, video, frame)


def _release_course(course: str):
    """Tell the image server this requester is done with a course (best-effort)."""
    _post_release("course", course)
|
|
|
|
|
|
|
|
def _worker_model_pair():
    """Return the ``(model, processor)`` pair selected by ``MODEL_CHOICE``.

    Any value other than ``"Florence-2-base"`` selects the large checkpoint.
    Either element may be ``None`` when that checkpoint failed to load.
    """
    if MODEL_CHOICE == "Florence-2-base":
        return vision_language_model_base, vision_language_processor_base
    return vision_language_model_large, vision_language_processor_large


def _wait_for_worker_model(max_polls: int = 120):
    """Poll once per second until the selected model and processor are loaded.

    Returns the ``(model, processor)`` pair, or ``None`` if they are still
    missing after ``max_polls`` checks.
    """
    for _ in range(max_polls):
        model, processor = _worker_model_pair()
        if model is not None and processor is not None:
            return model, processor
        time.sleep(1)
    return None


def _caption_frame(course: str, video, frame, file_id, model, processor):
    """Download, caption and submit one frame, always releasing it afterwards.

    A caption is submitted only when captioning succeeded. Download, decode
    and captioning failures all release the frame without a submission, so no
    empty caption is ever posted.

    Returns:
        Seconds the caller should pause before requesting the next frame:
        1 after a download/decode failure, 0.2 otherwise.
    """
    try:
        download_url = _build_download_url(course, video, frame)
        print(f"[BACKGROUND] downloading {download_url}")
        img_bytes, _content_type = _download_bytes(download_url)
        if not img_bytes:
            print(f"[BACKGROUND] failed to download image, releasing frame {file_id}")
            return 1
        try:
            pil_img = Image.open(BytesIO(img_bytes)).convert('RGB')
        except Exception as e:
            print(f"[BACKGROUND] failed to open image bytes: {e}")
            return 1
        try:
            caption = process_image_description(model, processor, pil_img)
        except Exception as e:
            print(f"[BACKGROUND] captioning failed: {e}; releasing {frame} without submitting")
            return 0.2
        status, _resp = _post_submit(caption, frame, course, download_url, img_bytes)
        print(f"[BACKGROUND] submitted caption for {frame}: status={status}")
        return 0.2
    finally:
        # Runs on every path above, including unexpected exceptions.
        _release_frame(course, video, frame)


def _process_course(course: str, model, processor, max_consecutive_errors: int = 10):
    """Caption frames of ``course`` until the image server reports none left.

    A 404 from the next-image endpoint ends the course. Failed or malformed
    next-image responses are retried, but after ``max_consecutive_errors`` in
    a row the course is given up, so one broken course cannot stall the worker
    indefinitely. The course is released on every exit path, including when
    an unexpected exception propagates.
    """
    consecutive_errors = 0
    try:
        next_image_url = f"{IMAGE_SERVER_BASE.rstrip('/')}/middleware/next/image/{urllib.parse.quote(course, safe='')}"
        while True:
            if consecutive_errors >= max_consecutive_errors:
                print(f"[BACKGROUND] giving up on course {course} after {consecutive_errors} consecutive errors")
                return
            try:
                rimg = requests.get(next_image_url, params={"requester_id": REQUESTER_ID}, timeout=15)
                if rimg.status_code == 404:
                    print(f"[BACKGROUND] no images for course {course}")
                    return
                rimg.raise_for_status()
                img_json = rimg.json()
            except Exception as e:
                print(f"[BACKGROUND] failed to get next image: {e}")
                consecutive_errors += 1
                time.sleep(1)
                continue

            # Guard against non-object JSON so .get() cannot raise.
            entry = img_json if isinstance(img_json, dict) else {}
            video, frame, file_id = entry.get('video'), entry.get('frame'), entry.get('file_id')
            if not (video and frame and file_id):
                print(f"[BACKGROUND] unexpected image entry: {img_json}")
                consecutive_errors += 1
                time.sleep(0.5)
                continue

            consecutive_errors = 0
            pause = _caption_frame(course, video, frame, file_id, model, processor)
            time.sleep(pause)
    finally:
        _release_course(course)


def background_worker():
    """Long-running loop that captions frames handed out by the image server.

    Waits up to 120 one-second polls for the model selected by MODEL_CHOICE
    and exits if it never appears. Otherwise it loops forever:
    - it asks ``/middleware/next/course`` for a course;
    - a 404 means nothing to do, and it retries after 3 s;
    - it captions that course's frames, then moves on.
    Errors are logged and retried; the loop never returns once polling has
    started.
    """
    print("[BACKGROUND] Worker waiting for model to be available...")
    loaded = _wait_for_worker_model()
    if loaded is None:
        print("[BACKGROUND] Model not available after timeout; background worker exiting.")
        return
    model, processor = loaded
    print("[BACKGROUND] Model loaded; starting polling loop")

    next_course_url = f"{IMAGE_SERVER_BASE.rstrip('/')}/middleware/next/course"
    while True:
        try:
            try:
                r = requests.get(next_course_url, params={"requester_id": REQUESTER_ID}, timeout=15)
                if r.status_code == 404:
                    time.sleep(3)
                    continue
                r.raise_for_status()
                course_json = r.json()
            except Exception as e:
                print(f"[BACKGROUND] failed to get next course: {e}")
                time.sleep(3)
                continue

            course = None
            if isinstance(course_json, dict):
                course = course_json.get('course_id') or course_json.get('course')
            if not course:
                print(f"[BACKGROUND] invalid course response: {course_json}")
                time.sleep(2)
                continue
            # Ids may arrive as JSON numbers; URL quoting below needs a str.
            course = str(course)

            print(f"[BACKGROUND] processing course: {course}")
            _process_course(course, model, processor)
            time.sleep(1)
        except Exception as e:
            print(f"[BACKGROUND] unexpected loop error: {e}")
            time.sleep(5)
|
|
|
|
|
|
|
|
def _start_worker_thread():
    """Start ``background_worker`` on a daemon thread and return the thread.

    As a daemon thread it does not keep the process alive on its own. The
    returned ``threading.Thread`` lets a caller check ``is_alive()`` or keep a
    reference; callers that ignore the return value are unaffected.

    NOTE(review): nothing in the visible source calls this function, so the
    background worker does not currently run — confirm whether it should be
    started before ``demo.launch``.
    """
    worker = threading.Thread(target=background_worker, name="florence-background-worker", daemon=True)
    worker.start()
    return worker
|
|
|
|
|
|
|
|
|
|
|
# Markdown blurb shown under the page title; a CPU warning is appended only
# when no CUDA device was detected.
description = (
    "> Select the model to use for generating the image description. "
    "'Base' is smaller and faster, while 'Large' is more accurate but slower."
) + (" Note: Running on CPU, which may be slow for large models." if device == "cpu" else "")
|
|
|
|
|
|
|
|
# Sample inputs for the "Upload Image" tab, as [image URL, model choice] rows.
_BOARDWALK_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
_FENCING_IMAGE_URL = "https://huggingface.co/spaces/Fred808/NNE/resolve/main/young-woman-doing-fencing-special-equipment.jpg"
examples = [
    [_BOARDWALK_IMAGE_URL, "Florence-2-large"],
    [_FENCING_IMAGE_URL, "Florence-2-base"],
]

# Styling for buttons tagged with elem_classes="submit-btn".
css = """
.submit-btn {
    background-color: #4682B4 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #87CEEB !important;
}
"""
|
|
|
|
|
|
|
|
def process_url_request(image_url, model_choice):
    """Adapter for the URL tab: return ``(result_dict, caption_or_error_text)``.

    The dict feeds the JSON viewer; the text is the caption on success or the
    error message otherwise (empty string if neither key is present).
    """
    result = describe_image_from_url(image_url, model_choice)
    return result, result.get("caption", result.get("error", ""))


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Florence-2 Models Image Captions")
    gr.Markdown(description)

    with gr.Tab("Upload Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image", type="pil")
                generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")

            with gr.Column():
                model_choice = gr.Radio(
                    ["Florence-2-base", "Florence-2-large"],
                    label="Model Choice",
                    value="Florence-2-base",
                )
                output = gr.Textbox(label="Generated Caption", lines=4, show_copy_button=True)

        gr.Examples(
            examples=examples,
            inputs=[image_input, model_choice],
            outputs=[output],
            fn=describe_image,
            run_on_click=True,
        )

        generate_btn.click(
            fn=describe_image,
            inputs=[image_input, model_choice],
            outputs=output,
        )

    with gr.Tab("Image from URL"):
        gr.Markdown("## Generate caption from image URL")
        gr.Markdown("Enter an image URL below to generate a caption.")

        with gr.Row():
            with gr.Column():
                url_input = gr.Textbox(
                    label="Image URL",
                    placeholder="https://example.com/image.jpg",
                    lines=2,
                )
                url_model_choice = gr.Radio(
                    ["Florence-2-base", "Florence-2-large"],
                    label="Model Choice",
                    value="Florence-2-large",
                )
                url_generate_btn = gr.Button("Generate Caption from URL", variant="primary")

            with gr.Column():
                url_output = gr.JSON(label="API Response")
                url_caption = gr.Textbox(label="Caption", lines=4, show_copy_button=True)

        url_examples = [
            ["https://huggingface.co/spaces/Fred808/NNE/resolve/main/young-woman-doing-fencing-special-equipment.jpg", "Florence-2-large"],
            ["https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "Florence-2-base"],
        ]

        gr.Examples(
            examples=url_examples,
            inputs=[url_input, url_model_choice],
            outputs=[url_output, url_caption],
            # Two output components need two return values. process_url_request
            # returns (dict, text); describe_image_from_url alone returns one
            # dict, which fails when an example is clicked.
            fn=process_url_request,
            run_on_click=True,
        )

        url_generate_btn.click(
            fn=process_url_request,
            inputs=[url_input, url_model_choice],
            outputs=[url_output, url_caption],
        )
|
|
|
|
|
|
|
|
# Port to listen on: $PORT when set, otherwise 7860.
port = int(os.environ.get("PORT", 7860))

# NOTE(review): _start_worker_thread() is not called anywhere in the visible
# source, so the background captioning worker never runs — confirm intent.
try:
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        debug=False,
        show_error=True,
        quiet=True,
    )
except Exception as e:
    # Retry once with a reduced argument set, and ONLY when the first launch
    # failed; a normal return from launch() must not start the server again.
    print(f"Error launching app: {e}")
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, quiet=True)