Spaces:
Runtime error
Runtime error
| import os | |
| import string | |
| import copy | |
| import gradio as gr | |
| import PIL.Image | |
| import torch | |
| from transformers import BitsAndBytesConfig, pipeline | |
| import re | |
| import time | |
| import random | |
# Page title rendered as Markdown at the top of the demo (emoji restored from
# a mis-encoded "ππͺ" sequence).
DESCRIPTION = "# LLaVA 🌋💪 - Now with Arnold Mode and Bodybuilding Coaching Expertise!"
# Hugging Face Hub checkpoint served by this demo.
model_id = "llava-hf/llava-1.5-7b-hf"

# Load the weights in 4-bit (bitsandbytes) while running computation in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Shared image+prompt -> text pipeline used by every inference request.
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs=dict(quantization_config=quantization_config),
)
# Bodybuilding judging criteria, mapping each criterion name to a short
# coaching note. Built from ordered pairs so insertion order is explicit.
bodybuilding_criteria = dict(
    [
        ("Muscular Size", "Focus on overall muscle mass and development."),
        ("Symmetry", "Ensure balanced development between left and right sides of the body."),
        ("Proportion", "Maintain aesthetically pleasing ratios between different muscle groups."),
        ("Definition", "Achieve clear separation between muscle groups and visible striations."),
        ("Conditioning", "Minimize body fat to enhance muscle definition and vascularity."),
        ("Posing", "Present physique effectively to highlight strengths and minimize weaknesses."),
    ]
)

# One-line training tips that can be appended to responses.
bodybuilding_tips = [
    # Training volume and exercise selection
    "Train each muscle group at least twice a week for optimal growth.",
    "Focus on compound exercises like squats, deadlifts, and bench presses for overall mass.",
    "Don't neglect your legs! They're half your physique.",
    # Nutrition and recovery
    "Proper nutrition is key. Eat clean and maintain a caloric surplus for growth.",
    "Get enough rest. Muscles grow during recovery, not in the gym.",
    # Presentation and general health
    "Practice your posing regularly. It's not just for shows, it helps mind-muscle connection.",
    "Stay hydrated. Water is crucial for muscle function and recovery.",
]
def extract_response_pairs(text):
    """Split a ``USER:`` / ``ASSISTANT:`` transcript into ``[user, assistant]`` pairs.

    Any text before the first role marker (for example a system prompt) is
    ignored. Each ``USER:`` turn is paired with the ``ASSISTANT:`` turn that
    immediately follows it. User turns without a reply (such as a trailing
    question) and assistant turns without a preceding user turn are dropped.

    Args:
        text: Transcript such as ``"USER: hi\\nASSISTANT: hello"``.

    Returns:
        A list of ``[user_text, assistant_text]`` lists, each text stripped of
        surrounding whitespace and leading colons.
    """
    # With a capturing group, re.split returns
    # [preamble, marker, content, marker, content, ...]. Pair each marker with
    # its own content *before* discarding anything, so empty turns cannot shift
    # later contents onto the wrong role.
    parts = re.split(r"(USER:|ASSISTANT:)", text)
    turns = [
        (marker, content.strip().lstrip(":").strip())
        for marker, content in zip(parts[1::2], parts[2::2])
    ]

    conv_list = []
    for (role, content), (next_role, next_content) in zip(turns, turns[1:]):
        if role == "USER:" and next_role == "ASSISTANT:":
            conv_list.append([content, next_content])
    return conv_list
def add_text(history, text):
    """Append a pending user turn to the chat history.

    Returns a tuple of:
      * a new history list (the input list is not modified) ending with
        ``[text, None]``, where ``None`` is the placeholder for the reply;
      * an empty string, meant to clear the input field after submission.
    """
    pending_turn = [text, None]
    extended = history + [pending_turn]
    return extended, ""
def arnold_speak(text, *, rng=None, criteria=None, tips=None):
    """Rewrite a response in an over-the-top "Arnold" coaching voice.

    Sentence-ending periods become exclamation marks, a few fitness words are
    swapped for Arnold-isms, and then the following are appended:
      * with 70% probability, a bodybuilding judging criterion;
      * with 50% probability, a training tip;
      * always, an Arnold catchphrase.

    Args:
        text: The response text to transform.
        rng: Optional object with ``random()`` and ``choice()`` methods, such
            as ``random.Random(seed)`` for reproducible output. Defaults to the
            ``random`` module.
        criteria: Mapping of criterion name -> description to draw advice
            from. Defaults to the module-level ``bodybuilding_criteria``.
        tips: Sequence of tips to draw from. Defaults to the module-level
            ``bodybuilding_tips``.

    Returns:
        The transformed text.
    """
    rng = random if rng is None else rng
    criteria = bodybuilding_criteria if criteria is None else criteria
    tips = bodybuilding_tips if tips is None else tips

    arnold_phrases = [
        "Come with me if you want to lift!",
        "I'll be back... after my protein shake.",
        "Hasta la vista, baby weight!",
        "Get to da choppa... I mean, da squat rack!",
        "You lack discipline! But don't worry, I'm here to pump you up!"
    ]
    word_swaps = {
        "gym": "iron paradise",
        "exercise": "pump iron",
        "workout": "sculpt your physique",
    }

    # More enthusiastic punctuation. Only sentence-ending periods are replaced:
    # one followed by whitespace, another period, or the end of the text.
    # Decimals ("2.5") and dotted names ("huggingface.co") are left intact.
    text = re.sub(r"\.(?=[\s.]|$)", "!", text)

    def _swap(match):
        # Whole words only, case-insensitive, preserving an initial capital.
        word = match.group(0)
        replacement = word_swaps[word.lower()]
        if word[0].isupper():
            return replacement[0].upper() + replacement[1:]
        return replacement

    swap_pattern = r"\b(?:" + "|".join(map(re.escape, word_swaps)) + r")\b"
    text = re.sub(swap_pattern, _swap, text, flags=re.IGNORECASE)

    # Add bodybuilding advice (70% chance).
    if criteria and rng.random() < 0.7:
        name, description = rng.choice(list(criteria.items()))
        text += f" Remember, in bodybuilding, {name} is crucial! {description}"

    # Add a bodybuilding tip (50% chance).
    if tips and rng.random() < 0.5:
        tip = rng.choice(tips)
        text += f" Here's a pro tip: {tip}"

    # Always finish with a random Arnold phrase.
    text += " " + rng.choice(arnold_phrases)
    return text
def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    """Run the LLaVA pipeline on one image/prompt pair and return the decoded text.

    Args:
        image: The input image passed to the pipeline.
        prompt: Full prompt string, including the ``<image>`` placeholder.
        temperature: Sampling temperature.
        length_penalty: Forwarded to ``generate``.
        repetition_penalty: Penalty applied to repeated tokens.
        max_length: Upper bound on the number of *generated* tokens.
        min_length: Lower bound on the number of *generated* tokens (capped at
            ``max_length``).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The pipeline's ``generated_text`` on success. If anything raises, a
        human-readable error message string is returned instead.
    """
    try:
        max_new_tokens = int(max_length)
        generate_kwargs = {
            # Without sampling, generate() decodes greedily and silently
            # ignores temperature / top_p, making those sliders no-ops.
            "do_sample": True,
            "temperature": temperature,
            "length_penalty": length_penalty,
            "repetition_penalty": repetition_penalty,
            # Bound the *new* tokens. max_length/min_length would also count the
            # prompt (system prompt plus image tokens) against the budget.
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min(int(min_length), max_new_tokens),
            "top_p": top_p,
        }
        outputs = pipe(images=image, prompt=prompt, generate_kwargs=generate_kwargs)
        return outputs[0]["generated_text"]
    except Exception as e:
        # UI boundary: surface the failure as chat text rather than crashing the event.
        print(f"Error during inference: {e}")
        return f"An error occurred during inference: {e}"
def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
    """Answer one chat turn about the uploaded image, streaming the reply.

    Generator used as a Gradio event handler. Each yielded value is the full
    chat history (a list of ``[user, assistant]`` pairs) to display.

    Args:
        history: Current chat history. It is not mutated; a new list is built.
        text_input: The user's message. Blank or whitespace-only input is rejected.
        image: The uploaded image, or ``None`` if nothing was uploaded.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, top_p: Generation settings forwarded to ``infer``.
        arnold_mode: When true, use the Arnold persona prompt and post-process
            the reply with ``arnold_speak``.

    Yields:
        The updated history, growing the assistant reply one character at a time.
    """
    if not (text_input or "").strip():
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return

    if arnold_mode:
        system_prompt = (
            "You are Arnold Schwarzenegger, the famous bodybuilder, actor, and former Mr. Olympia.\n"
            "Respond in his iconic style, using his catchphrases and focusing on fitness, bodybuilding, and motivation.\n"
            "Incorporate bodybuilding judging criteria and tips in your responses when relevant."
        )
    else:
        system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."

    # Use only the current input for generating the response (no chat history).
    prompt = f"{system_prompt}\nUSER: <image>\n{text_input}\nASSISTANT:"
    response = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)

    # The LLaVA image-to-text pipeline echoes the prompt before the completion.
    # Keep only the text after the first "ASSISTANT:" marker. An error message
    # from infer() has no marker and passes through unchanged.
    _, marker, answer = response.partition("ASSISTANT:")
    if marker:
        response = answer.strip()

    if arnold_mode:
        response = arnold_speak(response)

    history = history + [[text_input, ""]]
    # Show the user's message right away, even if the reply turns out empty.
    yield history
    for end in range(1, len(response) + 1):
        history[-1][1] = response[:end]
        time.sleep(0.05)
        yield history
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        "## LLaVA, one of the greatest multimodal chat models is now available "
        "in Transformers with 4-bit quantization! ⚡️\n"
        "See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava."
    )
    chatbot = gr.Chatbot()
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
        length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
        max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
        min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
        top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)
    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")

    # Stream the reply into the chatbot, then clear the input box.
    submit_button.click(
        fn=bot,
        inputs=[chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode],
        outputs=chatbot
    ).then(
        fn=lambda: "",
        outputs=text_input
    )
    # Reset both the conversation and the uploaded image.
    clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)

    examples = [
        ["./examples/bodybuilder.jpeg", "What do you think of this physique?"],
        ["./examples/gym.jpeg", "How can I improve my workout routine?"],
    ]
    # Offer only examples whose image file is present (paths resolve against
    # the working directory). gr.Examples may fail at startup on a missing file.
    examples = [example for example in examples if os.path.isfile(example[0])]
    if examples:
        gr.Examples(examples=examples, inputs=[image, text_input])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)