Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import gradio as gr | |
# Load the tokenizer and the causal language model once at import time so
# every request served by this process reuses the same weights.
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# torch_dtype="auto" keeps the precision the checkpoint was saved in instead
# of upcasting every weight to float32 on load, which lowers peak memory.
# NOTE(review): a 7B-parameter model may still exceed the memory of a small
# host, and reduced precision can be slower on some CPUs — confirm against
# the deployment hardware.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Define the image-to-prompt function
# Lazily built vision pipeline, shared by every call to image_to_prompt.
_vision_model = None


def _get_vision_model():
    """Return the shared vision pipeline, building it on first use."""
    global _vision_model
    if _vision_model is None:
        # Building a pipeline loads model weights, so do it once per process
        # rather than once per request.
        # NOTE(review): 'feature-extraction' looks like a text-oriented task;
        # confirm it accepts image input with this ViT checkpoint.
        _vision_model = pipeline(
            'feature-extraction',
            model='google/vit-base-patch16-224-in21k',
        )
    return _vision_model


def image_to_prompt(input_image, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2):
    """Build a text prompt from the features extracted from an image.

    Args:
        input_image: Image passed straight to the vision pipeline.
        max_length: Unused; kept so existing callers keep working.
        num_return_sequences: Unused; kept so existing callers keep working.
        no_repeat_ngram_size: Unused; kept so existing callers keep working.

    Returns:
        The string "Describe the image in detail: " followed by ``str()`` of
        whatever the vision pipeline returned.
    """
    features = _get_vision_model()(input_image)
    # NOTE(review): the raw feature values are stringified into the prompt;
    # presumably this yields a very long prompt — verify it fits the language
    # model's context window.
    return "Describe the image in detail: " + str(features)
# Define the text generation function
def generate_text(prompt, enable_nsfw=False):
    """Sample a continuation of *prompt* from the module-level model.

    Args:
        prompt: Text handed to the module-level tokenizer.
        enable_nsfw: When truthy, sets ``model.config.nsfw = True`` before
            generating.

    Returns:
        Only the newly generated text, decoded with special tokens skipped;
        the prompt itself is not echoed back.
    """
    # Keep the inputs on the same device as the weights so generation also
    # works when the model has been moved off the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": 200,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        # The end-of-sequence token doubles as the padding token.
        "pad_token_id": tokenizer.eos_token_id,
    }

    if enable_nsfw:
        # NOTE(review): this writes to the shared model config and nothing
        # visible here resets or reads it, so once set it stays set for later
        # calls — confirm the attribute has any effect on generation.
        model.config.nsfw = True

    outputs = model.generate(**inputs, **generation_kwargs)

    # For a causal LM the returned sequence starts with the prompt tokens;
    # drop them so the caller receives just the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Define the Gradio callback
def gradio_demo(input_image, enable_nsfw):
    """Turn an uploaded image into generated text for the Gradio UI.

    Args:
        input_image: Value of the image component; ``None`` when nothing has
            been uploaded.
        enable_nsfw: Value of the checkbox, forwarded to ``generate_text``.

    Returns:
        The generated text, or a short instruction when no image was given.
    """
    # Submitting the form without an image would otherwise hand None to the
    # vision pipeline and fail with an opaque error.
    if input_image is None:
        return "Please upload an image first."

    prompt = image_to_prompt(input_image)
    return generate_text(prompt, enable_nsfw)
# Build the Gradio UI: an image upload and a checkbox feed gradio_demo, and
# the returned string is shown as plain text.
_demo_inputs = [
    gr.Image(label="Input Image"),
    gr.Checkbox(label="Enable NSFW", value=False),
]
iface = gr.Interface(
    fn=gradio_demo,
    inputs=_demo_inputs,
    outputs="text",
    title="Image to Prompt with NSFW Support",
    description="Convert images to prompts and generate NSFW content using the HuggingFaceH4/zephyr-7b-beta model",
)

# Start serving the interface.
iface.launch()