Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| from dotenv import load_dotenv,find_dotenv | |
| load_dotenv(find_dotenv()) | |
| import requests | |
| from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import numpy as np | |
# Hugging Face access token (optional). It is read from the environment,
# which load_dotenv() above may have populated from a .env file, and is
# forwarded to every model download below (previously it was read but never
# used). None keeps the library's default authentication behaviour.
token = os.getenv('HF_TOKEN')

# Run on GPU when available: half precision there, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Initialize models
# OCR model for text extraction: a LayoutLM document question-answering
# pipeline. NOTE(review): on raw images this pipeline presumably runs its own
# OCR step (e.g. pytesseract) — confirm that dependency is installed.
ocr_model = pipeline(
    "document-question-answering",
    model="impira/layoutlm-document-qa",
    device=device,
    token=token,
)

# Florence-2 model for image understanding (OCR, object detection, captioning).
# trust_remote_code=True executes modelling code shipped with the Hub repo.
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
    token=token,
)
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    token=token,
).to(device)

# TinyLlama chat model for game-control reasoning. It uses the same dtype
# policy as Florence-2, so it is loaded in float16 when CUDA is available.
llm_tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    token=token,
)
llm_model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch_dtype,
    token=token,
).to(device)
def preprocess_image(image):
    """Normalise an incoming image to the RGB PIL image the models expect.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), a
            base64 string (a data URL such as ``"data:image/png;base64,...."``
            or bare base64), or ``None``.

    Returns:
        The image in RGB mode, or ``None`` when ``image`` is ``None``.
        Objects without a ``mode`` attribute are returned unchanged.
    """
    if isinstance(image, str):
        # A data URL carries a "<header>," prefix before the payload, while
        # bare base64 contains no comma at all. rpartition handles both; the
        # previous split(",")[1] raised IndexError on bare base64.
        _, _, payload = image.rpartition(",")
        image = Image.open(io.BytesIO(base64.b64decode(payload)))
    # Uploaded PNGs and base64 payloads may be RGBA, palette or greyscale.
    # Convert them so every downstream model receives 3-channel input.
    if image is not None and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    return image
def extract_text_from_image(image):
    """Extract text from a screenshot with LayoutLM and Florence-2's OCR task.

    Args:
        image: PIL image or base64/data-URL string (see ``preprocess_image``).

    Returns:
        str: ``"LayoutLM OCR: <answer>\\n\\nFlorence-2 OCR: <text>"``.
    """
    image = preprocess_image(image)

    # LayoutLM document QA. The pipeline may return a list of answer dicts
    # (best first) or a single dict, depending on version and top_k, so accept
    # both. Indexing the list with ['answer'] raised TypeError.
    layout_result = ocr_model(image=image, question="What text is in this image?")
    if isinstance(layout_result, list):
        layout_result = layout_result[0] if layout_result else {}
    layout_text = layout_result.get("answer", "") if isinstance(layout_result, dict) else ""

    # Florence-2 OCR task.
    task = "<OCR>"
    inputs = florence_processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    # Keep special tokens: post_process_generation parses them.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    # post_process_generation returns {task: result}. Unwrap it so the output
    # shows the recognised text instead of a dict repr.
    florence_text = parsed.get(task, "") if isinstance(parsed, dict) else parsed

    # Combine results from both models.
    return f"LayoutLM OCR: {layout_text}\n\nFlorence-2 OCR: {florence_text}"
def analyze_image(image):
    """Detect game-relevant objects in a screenshot with Florence-2's ``<OD>`` task.

    Args:
        image: PIL image or base64/data-URL string (see ``preprocess_image``).

    Returns:
        dict with keys:
            ``detected_elements``: list of detected label strings that match a
                game-related category (first occurrence of each label kept).
            ``element_details``: ``{label: {"confidence": float, "box": [x1, y1, x2, y2]}}``.
            ``confidence``: average confidence over matched detections
                (0.0 when nothing matched).
    """
    image = preprocess_image(image)

    task = "<OD>"  # Object Detection task token
    inputs = florence_processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )

    # Florence-2 returns {"<OD>": {"bboxes": [...], "labels": [...]}}, not a
    # list of {"category", "score", "box"} dicts. The old loop iterated the
    # dict's keys and crashed on obj["category"].
    detections = parsed.get(task, {}) if isinstance(parsed, dict) else {}
    bboxes = detections.get("bboxes", []) or []
    labels = detections.get("labels", []) or []

    # Florence-2 <OD> output carries no per-box score. This placeholder keeps
    # the "confidence" fields numeric for downstream formatting.
    placeholder_confidence = 1.0

    game_related_categories = (
        "person", "player", "character", "enemy", "button", "screen", "display",
        "health", "bar", "score", "menu", "weapon", "item", "obstacle", "platform",
        "text", "number", "icon", "power-up", "door", "key", "coin", "vehicle",
    )

    detected_elements = {}
    confidence_sum = 0.0
    count = 0
    for category, box in zip(labels, bboxes):
        lowered = str(category).lower()
        # Substring match, so e.g. "health bar" counts as game-related.
        if not any(game_category in lowered for game_category in game_related_categories):
            continue
        if category not in detected_elements or placeholder_confidence > detected_elements[category]["confidence"]:
            detected_elements[category] = {
                "confidence": placeholder_confidence,
                "box": box,  # Keep the bounding box information
            }
        confidence_sum += placeholder_confidence
        count += 1

    # Guard against division by zero when nothing matched.
    avg_confidence = confidence_sum / max(count, 1)
    return {
        "detected_elements": list(detected_elements.keys()),
        "element_details": detected_elements,
        "confidence": avg_confidence,
    }
def generate_game_control(text_content, image_analysis, user_input):
    """Ask the TinyLlama chat model for game-control suggestions.

    Args:
        text_content: OCR text for the screenshot (from ``extract_text_from_image``).
        image_analysis: dict returned by ``analyze_image`` (keys
            ``detected_elements``, ``element_details``, ``confidence``).
        user_input: the player's question. Empty or ``None`` falls back to a
            generic "what next" query.

    Returns:
        str: the model's generated answer only (the prompt is not echoed).
    """
    detected_elements = image_analysis['detected_elements']
    element_details = image_analysis['element_details']

    # Describe each element with its bounding box so the LLM has positional context.
    detailed_elements = []
    for element in detected_elements:
        details = element_details.get(element)
        if details is None:
            continue
        box = details['box']
        confidence = details['confidence']
        position = f"at position x:{box[0]:.1f}-{box[2]:.1f}, y:{box[1]:.1f}-{box[3]:.1f}"
        detailed_elements.append(f"{element} ({position}, confidence: {confidence:.2f})")
    detailed_elements_text = (
        "\n - ".join([""] + detailed_elements) if detailed_elements else "None detected with high confidence"
    )

    # A Gradio textbox can be submitted empty; give the model a concrete question.
    query = str(user_input or "").strip() or "What should I do next in this game?"

    prompt = "\n".join([
        "",
        "You are an AI game assistant that helps players understand game screenshots and provides control suggestions.",
        "Game screenshot analysis:",
        "- Text content detected:",
        f"{text_content}",
        f"- Visual elements detected: {detailed_elements_text}",
        f"- Overall detection confidence: {image_analysis['confidence']:.2f}",
        f"User query: {query}",
        "Based on the game screenshot analysis above, provide specific game control suggestions.",
        "Focus on:",
        "1. What UI elements the player should interact with",
        "2. Which buttons or controls they should use",
        "3. Gameplay strategy based on what's visible",
        "4. Clear next steps or actions",
        "Your response:",
        "",
    ])

    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    # Passing pad_token_id explicitly avoids relying on a tokenizer without a pad token.
    pad_token_id = llm_tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = llm_tokenizer.eos_token_id

    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            num_beams=3,
            pad_token_id=pad_token_id,
        )

    # The model echoes the prompt before its continuation. Decode only the new
    # tokens, rather than splitting on "Your response:" (which broke whenever
    # the OCR text or query contained that marker).
    prompt_length = inputs["input_ids"].shape[-1]
    return llm_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
def process_game_screenshot(image, user_input):
    """Run the full analysis pipeline and render a Markdown report.

    Args:
        image: screenshot (PIL image or base64/data-URL string), or ``None``.
        user_input: the player's question.

    Returns:
        str: Markdown with the scene description, OCR text, detected
        elements and control suggestions, or an upload prompt when no
        image was given.
    """
    if image is None:
        return "Please upload a game screenshot."

    # Keep the model calls in this order: OCR, detection, caption, then LLM.
    text_content = extract_text_from_image(image)
    image_analysis = analyze_image(image)
    # Extra scene-level context from Florence-2 captioning.
    image_desc = get_florence_image_description(image)
    control_suggestions = generate_game_control(text_content, image_analysis, user_input)

    details = image_analysis['element_details']
    bullet_items = [
        f"{name} (confidence: {details[name]['confidence']:.2f})"
        for name in image_analysis['detected_elements']
        if name in details
    ]
    if bullet_items:
        elements_text = "".join(f"\n- {item}" for item in bullet_items)
    else:
        elements_text = "None detected with high confidence"

    report_lines = [
        "",
        "## Game Screenshot Analysis",
        "### Scene Description:",
        f"{image_desc}",
        "### Text Content Detected:",
        f"{text_content}",
        "### Visual Elements Detected:",
        f"{elements_text}",
        "## Game Control Suggestions:",
        f"{control_suggestions}",
        "",
    ]
    return "\n".join(report_lines)
def get_florence_image_description(image):
    """Describe the whole screenshot in natural language with Florence-2.

    Args:
        image: PIL image or base64/data-URL string (see ``preprocess_image``).

    Returns:
        str: the generated caption.
    """
    image = preprocess_image(image)

    # "<IC>" is not a Florence-2 task token, so the model was fed an untrained
    # raw prompt. "<DETAILED_CAPTION>" is Florence-2's built-in captioning task.
    task = "<DETAILED_CAPTION>"
    inputs = florence_processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            # Deterministic beam search, matching the OCR and detection calls.
            do_sample=False,
            num_beams=3,
        )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    # The processor returns {task: caption}. Unwrap it so the report shows
    # plain text instead of a dict repr.
    if isinstance(parsed, dict):
        return parsed.get(task, "")
    return parsed
def create_api():
    """Build the Gradio Blocks app (UI plus API endpoints) without launching it.

    The interactive button is exposed as the ``analyze_with_details`` API.
    A hidden trigger exposes ``process_game_screenshot`` (Markdown-only) as the
    ``process_game_screenshot`` API. This replaces the former nested
    ``gr.Interface(...).launch(share=True)``, which blocked inside this function
    (so the returned app was never launched) and referenced nonexistent
    example files.

    Returns:
        gr.Blocks: the app; the caller is responsible for ``launch()``.
    """
    with gr.Blocks(title="Game Control Interface AI") as app:
        gr.Markdown("# Game Control Interface AI")
        gr.Markdown("Upload a game screenshot and provide your query to get game control suggestions")

        with gr.Row():
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Game Screenshot")
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'How do I defeat this enemy?', 'What should I do next?'",
                    lines=3,
                )
                submit_button = gr.Button("Analyze Screenshot", variant="primary")
                # Example queries to help users get started.
                example_queries = [
                    ["What should I do next in this game?"],
                    ["How do I defeat this enemy?"],
                    ["What items should I collect in this scene?"],
                    ["How do I solve this puzzle?"],
                    ["What controls should I use in this situation?"],
                ]
                gr.Examples(
                    examples=example_queries,
                    inputs=text_input,
                )

        with gr.Row():
            with gr.Column():
                # Tabs for the rendered report and the raw model outputs.
                with gr.Tabs():
                    with gr.TabItem("Game Control Suggestions"):
                        output = gr.Markdown(label="Game Control Interface Suggestions")
                    with gr.TabItem("Raw Analysis Data"):
                        with gr.Accordion("OCR Results", open=False):
                            ocr_output = gr.Textbox(label="Text Detection Results", lines=5)
                        with gr.Accordion("Object Detection", open=False):
                            object_output = gr.JSON(label="Detected Objects")

        def process_with_details(image, user_input):
            """Return (markdown report, OCR text, element details) for the three outputs."""
            if image is None:
                return "Please upload a game screenshot.", "No text detected", {}

            text_content = extract_text_from_image(image)
            image_analysis = analyze_image(image)
            image_desc = get_florence_image_description(image)
            control_suggestions = generate_game_control(text_content, image_analysis, user_input)

            detected_elements_formatted = []
            for elem in image_analysis['detected_elements']:
                if elem in image_analysis['element_details']:
                    conf = image_analysis['element_details'][elem]['confidence']
                    detected_elements_formatted.append(f"{elem} (confidence: {conf:.2f})")
            elements_text = (
                "\n- ".join([""] + detected_elements_formatted)
                if detected_elements_formatted
                else "None detected with high confidence"
            )

            response = "\n".join([
                "",
                "## Game Screenshot Analysis",
                "### Scene Description:",
                f"{image_desc}",
                "### Text Content Detected:",
                f"{text_content}",
                "### Visual Elements Detected:",
                f"{elements_text}",
                "## Game Control Suggestions:",
                f"{control_suggestions}",
                "",
            ])
            return response, text_content, image_analysis['element_details']

        # Interactive UI; also reachable programmatically under this api_name.
        submit_button.click(
            fn=process_with_details,
            inputs=[image_input, text_input],
            outputs=[output, ocr_output, object_output],
            api_name="analyze_with_details",
        )

        # Markdown-only API endpoint. The trigger is hidden, so it adds no UI,
        # but the event listener is still registered under its api_name.
        api_trigger = gr.Button(visible=False)
        api_trigger.click(
            fn=process_game_screenshot,
            inputs=[image_input, text_input],
            outputs=output,
            api_name="process_game_screenshot",
        )

    return app
# Script entry point: build the Gradio app, then serve it with a public share
# link and the API docs page enabled.
if __name__ == "__main__":
    app = create_api()
    app.launch(
        share=True,
        show_api=True,
    )