Spaces:
Runtime error
Runtime error
| import os | |
| import base64 | |
| import json | |
| import pymongo | |
| from typing import List, Optional, Dict, Any, Tuple | |
| from PIL import Image | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| from langchain_community.llms import HuggingFaceEndpoint | |
| import gradio as gr | |
| from pymongo import MongoClient | |
| from bson import ObjectId | |
| import asyncio | |
| from PIL import Image, ImageOps | |
| from aiohttp.client_exceptions import ClientResponseError | |
# MongoDB logging backend. The connection string is read from the MONGOCONN
# environment variable and falls back to a local instance when it is unset.
MONGOCONN = os.environ.get("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"]  # Database name
collection = db["image_tagging_space"]  # Collection name
# Marker tokens used when assembling the multimodal prompt string.
img_spec_token = "<|im_image|>"  # placed before and after the encoded image payload
img_join_token = "<|and|>"  # separator between several encoded images
sos_token, eos_token = "[INST]", "[/INST]"  # open / close one prompt section

# Aspect ratios and the pixel resolution listed at the same position for each.
ASPECT_RATIOS = [(1, 1), (1, 4), (4, 1)]
RESOLUTIONS = [(672, 672), (336, 1344), (1344, 336)]
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    """Save a downsized copy of an image under /tmp and return its path.

    The image is resized in place with Pillow's ``thumbnail`` using the
    LANCZOS filter, bounded by ``(max_width, max_height)``.

    Args:
        image_path: Path of the image file to read.
        max_width: Width of the bounding box passed to ``thumbnail``.
        max_height: Height of the bounding box passed to ``thumbnail``.

    Returns:
        Path of the saved copy, ``/tmp/<basename of image_path>``. Inputs that
        share a file name therefore overwrite each other's copy.
    """
    resized_image_path = os.path.join("/tmp", os.path.basename(image_path))
    # Use the image as a context manager so the underlying file handle is
    # closed; a bare Image.open() left it open until garbage collection.
    with Image.open(image_path) as img:
        img.thumbnail((max_width, max_height), Image.LANCZOS)
        img.save(resized_image_path)
    return resized_image_path
def encode_image_to_base64(image_path: str) -> str:
    """Read the file at *image_path* and return its bytes as Base64 text.

    Args:
        image_path: Path of the file to read in binary mode.

    Returns:
        The Base64 encoding of the file contents, decoded as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def img_to_prompt(images: List[str]) -> str:
    """Build the image section of the prompt from a list of image paths.

    Each file is Base64-encoded, the encodings are joined with
    ``img_join_token``, and the result is wrapped in ``img_spec_token`` on
    both sides.

    Args:
        images: Paths of the image files to embed.

    Returns:
        The wrapped, joined Base64 payload as a single string.
    """
    payload = img_join_token.join(map(encode_image_to_base64, images))
    return "".join((img_spec_token, payload, img_spec_token))
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
    """Assemble the full prompt: system section, user section, assistant cue.

    Args:
        img_prompt: Image section, placed directly before the ``<image>`` tag.
        human_prompt: The user's question text.
        ai_role: Instruction text for the system section.

    Returns:
        The system and user sections, each wrapped in ``sos_token`` /
        ``eos_token``, followed by the trailing ``"assistant\\n"`` cue.
    """
    sections = [
        # System section.
        sos_token,
        f"system\n{ai_role}",
        eos_token,
        # User section: image payload, image tag, then the question.
        sos_token,
        f"user\n{img_prompt}<image>\n{human_prompt}",
        eos_token,
        # Cue after which the model's answer is expected.
        "assistant\n",
    ]
    return "".join(sections)
def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Return a new list of ``(user_input, response)`` tuples from *history*.

    Every entry must unpack into exactly two items; each pair is re-packed as
    a tuple, so the input list itself is never returned.
    """
    pairs: List[Tuple[str, str]] = []
    for user_input, response in history:
        pairs.append((user_input, response))
    return pairs
async def call_inference(user_prompt: str) -> str:
    """Send *user_prompt* to the hosted inference endpoint and return its text.

    Args:
        user_prompt: The fully assembled prompt string.

    Returns:
        The generated text, or a string starting with ``"API call failed: "``
        when the request raises ``ClientResponseError``. Other exceptions
        propagate to the caller.
    """
    endpoint_url = "https://yzzwmsj8y9ji99i8.us-east-1.aws.endpoints.huggingface.cloud"
    llm = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        max_new_tokens=2000,
        temperature=0.1,
        do_sample=True,
        use_cache=True,
        timeout=300,
    )
    try:
        # ainvoke is the public async entry point; the previous code called
        # the private _acall hook directly.
        response = await llm.ainvoke(user_prompt)
    except ClientResponseError as e:
        return f"API call failed: {e.message}"
    return response
async def submit(message, history, doc_ids, last_image):
    """Handle one chat submission: run inference on the image and log the turn.

    Args:
        message: Mapping with a ``"text"`` entry and a ``"files"`` list. Each
            file is either a dict with a ``"path"`` key or a plain path string.
        history: List of ``(user_prompt, response)`` tuples; the new turn is
            appended in place.
        doc_ids: List of MongoDB document id strings; the new id is appended
            in place so that ``doc_ids[i]`` belongs to ``history[i]``.
        last_image: ``(path, extension)`` of the most recently uploaded image,
            reused when this message carries no file.

    Returns:
        A 5-tuple: formatted history, a cleared interactive ``gr.Textbox``,
        ``doc_ids``, the (possibly updated) ``last_image``, and a ``gr.Image``
        showing the image in use (empty when no image is available).
    """
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])

    if message["files"]:
        # Only the most recent upload is used.
        latest_file = message["files"][-1]
        image = latest_file["path"] if isinstance(latest_file, dict) else latest_file
        image_filetype = os.path.splitext(image)[1].lower()
        last_image = (image, image_filetype)
    else:
        image, image_filetype = last_image

    if not image:
        # Nothing to analyse: leave the conversation untouched.
        return format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)

    human_prompt = message["text"]
    img_prompt = img_to_prompt([image])
    ai_role = """Your role is to validate the type of the document in the image by thoroughly examining the content and characteristics of the document."""
    user_prompt = combine_img_with_text(img_prompt, human_prompt, ai_role)

    # Call inference asynchronously
    response = await call_inference(user_prompt)
    selected_output = response.split("assistant\n")[-1].strip()

    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        'image_prompt': img_prompt,
        'user_prompt': human_prompt,
        'response': selected_output,
        'image_filetype': image_filetype,
        'likes': 0,
        'dislikes': 0,
        'like_dislike_reason': None
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)
    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")

    # Record the turn only after inference and storage both succeeded. The
    # earlier "<processing>" placeholder was never returned to the UI, and an
    # exception above would have left it in history without a matching entry
    # in doc_ids, breaking index-based lookups.
    history.append((human_prompt, selected_output))
    doc_ids.append(document_id)
    return format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    """Record a like or dislike for a chat message in its MongoDB document.

    Args:
        x: Like event; ``x.index`` identifies the message and ``x.liked``
            says whether it was a like or a dislike.
        history: Current chat history; nothing is recorded when it is empty.
        doc_ids: Document id strings, looked up by message index.
        reason: Text stored as ``like_dislike_reason`` alongside the vote.
    """
    if not history:
        return
    # Normalise the index the same way select_message does: a bare int is used
    # directly, any sequence form (list or tuple) contributes its first item.
    index = x.index if isinstance(x.index, int) else x.index[0]
    if not 0 <= index < len(doc_ids):
        # No stored document for this message; skip instead of raising.
        print(f"No stored document for message index {index}; feedback ignored.")
        return
    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}},
    )
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")
def submit_reason_only(doc_ids, reason, selected_index, history):
    """Store *reason* on the document of the selected (or latest) message.

    Args:
        doc_ids: Document id strings, looked up by message index.
        reason: Text written to the document's ``like_dislike_reason`` field.
        selected_index: Index of the chosen message, or ``None`` to use the
            last entry of *history*.
        history: Current chat history; only its length is used.

    Returns:
        ``"Reason submitted."`` on success, or a notice string when no stored
        document matches the index (for example before any message was sent).
    """
    if selected_index is None:
        selected_index = len(history) - 1  # Select the last message if no message is selected
    # Guard against an empty conversation (index -1) and stale indices, which
    # previously raised IndexError or silently hit the wrong document.
    if not 0 <= selected_index < len(doc_ids):
        return "No message available to attach a reason to."
    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}}
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return "Reason submitted."
# HTML passed to the chatbot as its placeholder: a centred logo, the model
# name, and a short hosting note.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-NeXT-Mistral-7B-LFX</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is hosted by LFX</p>
</div>
"""
# UI definition: chat column on the left, image preview and feedback controls
# on the right, followed by the event wiring.
# NOTE(review): the original indentation was lost; placing reason_box and
# submit_reason_btn inside the right-hand column is an assumption - confirm
# against the deployed layout.
with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
            chat_input = gr.MultimodalTextbox(
                interactive=True,
                file_types=["image"],
                placeholder="Enter message or upload file...",
                show_label=False,
            )
        with gr.Column(scale=1):
            image_display = gr.Image(
                type="filepath",
                interactive=False,
                show_label=False,
                height=400,
            )
            reason_box = gr.Textbox(
                label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.",
                visible=True,
            )
            submit_reason_btn = gr.Button("Submit Reason", visible=True)

    # State holders passed to the handlers below.
    history_state = gr.State([])
    doc_ids_state = gr.State([])
    last_image_state = gr.State((None, None))
    selected_index_state = gr.State(None)  # Initializing the state

    def select_message(evt: gr.SelectData, history, doc_ids):
        """Return a visibility update for the reason box and the clicked message index."""
        if isinstance(evt.index, int):
            selected_index = evt.index
        else:
            # Non-int index: its first element is taken as the message position.
            selected_index = evt.index[0]
        print(f"Selected Index: {selected_index}")  # Debugging print statement
        return gr.update(visible=True), selected_index

    # Sending a message runs inference and refreshes chat, input box and preview.
    chat_msg = chat_input.submit(
        submit,
        inputs=[chat_input, history_state, doc_ids_state, last_image_state],
        outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display],
    )
    # Like/dislike clicks are recorded together with the current reason text.
    chatbot.like(
        print_like_dislike,
        inputs=[history_state, doc_ids_state, reason_box],
        outputs=[],
    )
    # Clicking a chat message stores its index for the reason submission.
    chatbot.select(
        select_message,
        inputs=[history_state, doc_ids_state],
        outputs=[reason_box, selected_index_state],
    )  # Using the state
    submit_reason_btn.click(
        submit_reason_only,
        inputs=[doc_ids_state, reason_box, selected_index_state, history_state],
        outputs=[reason_box],
    )  # Using the state

demo.queue(api_open=False)
demo.launch(show_api=False, share=True, debug=True)