Spaces:
Runtime error
Runtime error
| import requests | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
# Checkpoint served by this app; passed unchanged to both loaders below.
model_id_or_path = "rhymes-ai/Aria"

# Causal-LM weights, requested in bfloat16 with automatic device placement.
# trust_remote_code=True permits running the modelling code shipped with the
# checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Matching processor, used for chat templating, input preparation and decoding.
processor = AutoProcessor.from_pretrained(
    model_id_or_path,
    trust_remote_code=True,
)
def _load_image(source):
    """Return a PIL image for *source*.

    Strings starting with ``http://`` or ``https://`` are downloaded with
    ``requests``; any other string is treated as a local file path. Values
    that are not strings are returned unchanged.

    Raises:
        requests.RequestException: If the download fails or the server
            answers with an error status.
        OSError: If the file or payload cannot be opened as an image.
    """
    if not isinstance(source, str):
        return source

    if source.startswith(("http://", "https://")):
        # A timeout keeps a stalled server from hanging the request forever.
        response = requests.get(source, stream=True, timeout=30)
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")

    # Local path (what a ``type="filepath"`` Gradio input supplies).
    # ``convert`` returns an in-memory copy, so the file handle can be closed.
    with Image.open(source) as img:
        return img.convert("RGB")


def generate_response(image):
    """Ask the model "what is the image?" about *image* and return its answer.

    Args:
        image: A local file path, an http(s) URL, or an already-loaded image
            object accepted by the processor.

    Returns:
        The decoded text generated by the model, without the prompt tokens
        and without special tokens.

    Raises:
        ValueError: If *image* is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided.")
    image = _load_image(image)

    # Single-turn conversation: one image placeholder followed by the question.
    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")

    # Pixel values must match the model's dtype; every tensor must live on
    # the model's device.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # torch.autocast(device_type="cuda", ...) is the non-deprecated spelling
    # of torch.cuda.amp.autocast(...).
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=500,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.9,
        )

    # Drop the prompt tokens so that only the newly generated part is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    output_ids = output[0][prompt_length:]
    return processor.decode(output_ids, skip_special_tokens=True)
# Gradio interface: one image input, plain-text output.
# Components are taken from the top-level namespace (`gr.Image`); the legacy
# `gr.inputs` namespace no longer exists in current Gradio releases and
# accessing it raises AttributeError at start-up.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="Image-to-Text Model",
    description="Upload an image, and the model will describe it.",
)

# Launch the app
iface.launch()