Spaces:
Sleeping
Sleeping
| import base64 | |
| import requests | |
| import json | |
| import pandas as pd | |
| import os | |
| from tqdm import tqdm | |
| import re | |
| import torch | |
| import io | |
| from PIL import Image | |
def image_to_bytes(image):
    """Serialize a PIL image to JPEG-encoded bytes.

    JPEG cannot store an alpha channel or a colour palette, so an image
    whose mode the JPEG encoder rejects (for example ``RGBA``, ``LA`` or
    ``P``) is converted to ``RGB`` first instead of failing with ``OSError``.
    The caller's image object is not modified; ``convert`` returns a copy.

    Args:
        image: A ``PIL.Image.Image`` (any object exposing ``mode``,
            ``convert`` and ``save`` works).

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    # Modes that Pillow's JPEG writer can encode without conversion.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()
def query_clip(data, hf_token, timeout=60):
    """Send an image to the hosted CLIP model for zero-shot classification.

    Args:
        data: Mapping with an ``"image"`` entry (a PIL image) and a
            ``"parameters"`` entry that is forwarded verbatim to the API.
        hf_token: Hugging Face access token, sent as a bearer token.
        timeout: Seconds to wait for the HTTP request; without a timeout
            ``requests`` would block indefinitely on a stalled connection.

    Returns:
        The decoded JSON body of the API response (its shape is decided by
        the remote service and is not validated here).
    """
    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
    headers = {"Authorization": f"Bearer {hf_token}"}

    # The image travels inside the JSON payload as base64-encoded JPEG bytes.
    img_bytes = image_to_bytes(data["image"])
    payload = {
        "parameters": data["parameters"],
        "inputs": base64.b64encode(img_bytes).decode("utf-8"),
    }

    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
def get_sentiment(img, hf_token):
    """Classify an image as ``"angry"`` or ``"happy"`` via ``query_clip``.

    Args:
        img: PIL image to classify.
        hf_token: Hugging Face access token.

    Returns:
        The ``label`` of the first entry in the API response (presumably the
        highest-scoring candidate -- confirm against the API), or ``None``
        when the response does not have the expected list-of-dicts shape,
        e.g. an error payload while the model loads or after a rate limit.
    """
    print("Getting the sentiment of the image...")
    output = query_clip(
        {
            "image": img,
            "parameters": {"candidate_labels": ["angry", "happy"]},
        },
        hf_token,
    )

    # Only a malformed response is handled here; anything else (including
    # KeyboardInterrupt) must propagate instead of being swallowed.
    try:
        label = output[0]['label']
    except (KeyError, IndexError, TypeError):
        print(output)
        print("If the model is loading, try again in a minute. If you've reached a query limit (300 per hour), try within the next hour.")
        return None

    print("Sentiment:", label)
    return label
def query_blip(img, hf_token, timeout=60):
    """Send an image to the hosted BLIP captioning model.

    The image is uploaded as the raw request body. Passing a ``files``-style
    dict through ``data=`` makes ``requests`` form-encode the tuple members
    rather than upload the image, so the bytes are sent directly instead.

    Args:
        img: PIL image to caption.
        hf_token: Hugging Face access token, sent as a bearer token.
        timeout: Seconds to wait for the HTTP request; without a timeout
            ``requests`` would block indefinitely on a stalled connection.

    Returns:
        The decoded JSON body of the API response (its shape is decided by
        the remote service and is not validated here).
    """
    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {
        "Authorization": f"Bearer {hf_token}",
        # image_to_bytes always produces JPEG data.
        "Content-Type": "image/jpeg",
    }
    img_bytes = image_to_bytes(img)

    response = requests.post(API_URL, headers=headers, data=img_bytes, timeout=timeout)
    return response.json()
def get_description(img, hf_token):
    """Obtain a text caption describing an image via ``query_blip``.

    Args:
        img: PIL image to describe.
        hf_token: Hugging Face access token.

    Returns:
        The ``generated_text`` of the first entry in the API response, or
        ``None`` when the response does not have the expected
        list-of-dicts shape (for example an error payload).
    """
    print("Getting the context of the image...")
    output = query_blip(img, hf_token)

    # Only a malformed response is handled here; anything else (including
    # KeyboardInterrupt) must propagate instead of being swallowed.
    try:
        caption = output[0]['generated_text']
    except (KeyError, IndexError, TypeError):
        print(output)
        print("The model is not available right now due to query limits. Try running again now or within the next hour")
        return None

    print("Context:", caption)
    return caption
def get_model_caption(img_path, base_model, tokenizer, hf_token, device='cuda'):
    """Generate a meme caption for an image.

    The image's sentiment selects which adapter of ``base_model`` is
    activated, and the image's text description becomes the topic of the
    generation prompt.

    Args:
        img_path: Despite the name, this value is handed straight to
            ``get_sentiment`` / ``get_description``, which treat it as a PIL
            image rather than a filesystem path.
        base_model: Model exposing ``set_adapter``, ``to`` and ``generate``
            (presumably a PEFT-wrapped causal LM with adapters named after
            the sentiment labels -- confirm against the caller).
        tokenizer: Tokenizer matching ``base_model``.
        hf_token: Hugging Face access token for the hosted helper models.
        device: Device the inputs and the model are moved to.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.

    Raises:
        ValueError: If the sentiment or the description could not be
            obtained. Previously a missing description silently produced a
            caption for the topic "None", and a missing sentiment only
            failed later inside ``set_adapter``.
    """
    sentiment = get_sentiment(img_path, hf_token)
    description = get_description(img_path, hf_token)
    if sentiment is None:
        raise ValueError("Could not determine the sentiment of the image; no adapter can be selected.")
    if description is None:
        raise ValueError("Could not obtain a description of the image; no caption topic is available.")

    # The template text (including its line breaks) is part of the prompt and
    # is kept exactly as-is.
    # NOTE(review): "\\n" below yields a literal backslash followed by "n",
    # not a newline -- confirm this matches the fine-tuning prompt format
    # before changing it.
    prompt_template = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n
You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
Topic: {query}
<end_of_turn>\\n<start_of_turn>model Caption:
"""
    prompt = prompt_template.format(query=description)

    print("Generating captions...")
    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encodeds.to(device)

    print("sentiment", sentiment)
    # Activate the adapter named after the detected sentiment label.
    base_model.set_adapter(sentiment)
    base_model.to(device)

    generated_ids = base_model.generate(
        **model_inputs,
        max_new_tokens=20,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)