Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import clip | |
| from PIL import Image | |
| import numpy as np | |
| import torch.nn as nn | |
| import copy | |
| import base64 | |
| import os | |
| import requests | |
# Prefer the GPU for CLIP inference when one is present; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Base ViT-B/32 CLIP model together with the image transform that matches it.
model, preprocess = clip.load("ViT-B/32", device=device)
# =====================OpenAI===========================
# Endpoint and credentials for the OpenAI-compatible chat-completions service
# that is called with plain `requests`.
API_KEY = os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL")
if not API_KEY:
    raise ValueError("API_KEY is not set in the environment variables.")
if not API_BASE_URL:
    # Fail with a clear message instead of a TypeError on `None + str` below.
    raise ValueError("API_BASE_URL is not set in the environment variables.")
# Join robustly whether or not the configured base URL ends with "/".
url = API_BASE_URL.rstrip("/") + "/chat/completions"
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer ' + API_KEY,
}
# Separate key for the official OpenAI SDK client.
# NOTE(review): when My_API_KEY is unset the SDK falls back to the
# OPENAI_API_KEY environment variable and raises at construction time if that
# is missing too, so startup depends on one of them being configured.
My_API_KEY = os.getenv("My_API_KEY")
from openai import OpenAI  # noqa: E402  (kept at its original position)
client = OpenAI(
    api_key=My_API_KEY
)
class Combined_model(nn.Module):
    """Cascade of fine-tuned CLIP classifiers that labels one map image.

    The map-type classifier runs first and decides which follow-up
    classifiers are used:

    * ``"topographic map"`` -> location, century and note classifiers;
      ``forward`` returns ``(maptype, location, century, note)``.
    * ``"pictorial map"`` -> area and topic classifiers;
      ``forward`` returns ``(maptype, area, topic)``.

    Each sub-model is called as ``sub_model(image, tokens)`` and must return
    ``(logits_per_image, logits_per_text)`` like an OpenAI CLIP model.
    """

    # Candidate text labels for each zero-shot classification step.
    MAPTYPES = ["topographic map", "pictorial map"]
    LOCATIONS = ["greece", "italy", "iberian peninsula", "france", "eastern hemisphere", "europe",
                 "middle east", "asia minor", "germany", "british isles", "world", "egypt", "part of italy",
                 "part of france", "part of germany", "india", "holy land", "asia", "caucasus", "sri lanka",
                 "south america", "americas", "switzerland", "scandinavia", "netherlands", "africa",
                 "part of greece"]
    CENTURIES = ["19th century", "18th century", "17th century", "16th century"]
    NOTES = ["hand colored", "hand colored with decorative elements and pictorial relief", "pictorial relief",
             "hand colored with pictorial relief", "engraved", "decorative elements and pictorial relief"]
    AREAS = ["united states", "world"]
    TOPICS = ['flight network', 'news during world war 2', 'world war 2', 'transport routes', 'tourist sights',
              'playing card', 'satirical representation', 'people', 'educational drawings',
              'food and agriculture', 'animals', 'military', 'stamps']

    def __init__(self, model_maptype, model_location, model_century, model_note, model_area, model_topic):
        super().__init__()
        # Assigning modules as attributes registers them as submodules of this one.
        self.model_maptype = model_maptype
        self.model_location = model_location
        self.model_century = model_century
        self.model_note = model_note
        self.model_area = model_area
        self.model_topic = model_topic

    @staticmethod
    def _classify(sub_model, x, labels):
        """Return the entry of *labels* that *sub_model* scores highest for image tensor *x*."""
        tokens = clip.tokenize(labels).to(device)
        logits_per_image, _ = sub_model(x, tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        # Use the first image's row. np.argmax on the whole 2-D array would
        # flatten it and could index past `labels` for batches larger than one.
        return labels[int(np.argmax(probs[0]))]

    def forward(self, x):
        """Classify the first image in *x* (a batch of one in this app).

        Returns a tuple of label strings; see the class docstring for its layout.
        """
        maptype = self._classify(self.model_maptype, x, self.MAPTYPES)
        if maptype == "topographic map":
            location = self._classify(self.model_location, x, self.LOCATIONS)
            century = self._classify(self.model_century, x, self.CENTURIES)
            note = self._classify(self.model_note, x, self.NOTES)
            return maptype, location, century, note
        # MAPTYPES has only two entries, so anything else is a "pictorial map".
        area = self._classify(self.model_area, x, self.AREAS)
        topic = self._classify(self.model_topic, x, self.TOPICS)
        return maptype, area, topic
# Six independent deep copies of the base CLIP model, one per classifier,
# created in declaration order.
(
    model_maptype,
    model_location,
    model_century,
    model_note,
    model_area,
    model_topic,
) = (copy.deepcopy(model) for _ in range(6))
def freeze_network(model):
    """Switch off gradient tracking for every parameter of *model*.

    The module is changed in place and also returned so the call can be chained.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model
# Fine-tuned checkpoint for each classifier (paths relative to the working directory).
model_path_maptype = "Models_CLIP/best_model_MapType.pt"
model_path_location = "Models_CLIP/best_model_27Countries.pt"
model_path_century = "Models_CLIP/best_model_Date.pt"
model_path_note = "Models_CLIP/best_model_Note.pt"
model_path_area = "Models_CLIP/best_model_Pictorial_Area.pt"
model_path_topic = "Models_CLIP/best_model_Pictorial_Topic_V2.pt"

# Load each checkpoint onto the active device, then freeze that model.
# NOTE(review): torch.load unpickles the file; only trusted checkpoints should be loaded.
for _network, _checkpoint in (
    (model_maptype, model_path_maptype),
    (model_location, model_path_location),
    (model_century, model_path_century),
    (model_note, model_path_note),
    (model_area, model_path_area),
    (model_topic, model_path_topic),
):
    _network.load_state_dict(torch.load(_checkpoint, map_location=device))
    freeze_network(_network)
del _network, _checkpoint
# ===================interface of GUI========================
# Prompt fragments shared by the caption requests. Adjacent literals are
# joined with explicit spaces so sentences do not run together.
_STORY_SYSTEM_PROMPT = (
    "You are a helpful assistant that creates precise, formal, and meaningful historical map descriptions in natural language paragraph. "
    "Your response should be accurate and coherent, and use only the given keywords without adding any invented information."
)
_COMPARISON_SYSTEM_PROMPT = (
    "You are a helpful assistant that analyzes the provided map and creates precise, formal, and meaningful historical map descriptions in natural language paragraph. "
    "The map caption should be accurate and coherent, and only based on information from the map."
)
_LENGTH_INSTRUCTION = (
    "Ensure the output is a single paragraph and must strictly no longer than 50 words. "
    "Do not include any generated information or fabricated details."
)
_COMPARISON_FORMAT_INSTRUCTION = (
    "Exceeding the word limit will be considered a failure of the task. "
    "Avoid lists, bullet points, or multi-paragraph responses."
)
_MAX_TOKENS = 100
# Seconds to wait for the chat-completions endpoint before giving up.
_REQUEST_TIMEOUT = 120


def _build_question_prompt(what, where, when, why):
    """Return the extra instruction for the ticked questions, or '' if none are ticked."""
    questions = [
        question
        for enabled, question in (
            (what, "What is this map about?"),
            (where, "Where is this map about?"),
            (when, "When is this map about?"),
            (why, "What could this map be used for?"),
        )
        if enabled
    ]
    if not questions:
        return ""
    return (" Please also address the following aspects in a concise and coherent paragraph, in under 40 words, about: "
            + " ".join(questions))


def _chat_via_requests(model_name, messages):
    """POST a chat-completions request to the configured endpoint and return the reply text."""
    response = requests.post(
        url,
        json={"model": model_name, "messages": messages, "max_tokens": _MAX_TOKENS},
        headers=headers,
        timeout=_REQUEST_TIMEOUT,
    )
    # Report HTTP failures (bad key, unknown model, rate limit) directly rather
    # than as a KeyError on the error payload.
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']


def _format_section(label, text):
    """Render one model answer as '== label == \\n<text>' with surrounding double quotes stripped."""
    return f"== {label} == \n" + (text or "").strip('"')


def _read_image_as_data_url(path):
    """Return the image file at *path* as a base64 ``data:`` URL for vision-enabled models."""
    import mimetypes

    mime_type, _ = mimetypes.guess_type(path)
    if not mime_type or not mime_type.startswith("image/"):
        # Unknown extension: fall back to the previously hard-coded JPEG type.
        mime_type = "image/jpeg"
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"


def map_interface(map, what, where, when, why,
                  story_35, story_4o, story_4omini,
                  compare_4, compare_4o, compare_4omini):
    """Gradio callback: classify the uploaded map and generate captions.

    Args:
        map: value of the ``gr.ImageEditor`` input; its ``"composite"`` entry
            is the (possibly edited) image, as a file path or an array.
        what, where, when, why: flags that add the matching question to the prompts.
        story_35, story_4o, story_4omini: request a keyword-based caption from
            GPT-3.5-turbo / GPT-4o / GPT-4o-mini using the CLIP labels.
        compare_4, compare_4o, compare_4omini: request a caption written
            directly from the image by GPT-4-turbo / GPT-4o / GPT-4o-mini.

    Returns:
        ``(story_results, comparison_results)``. Each is the selected models'
        answers as ``"== <model> == \\n<text>"`` sections separated by a blank
        line, or ``""`` when no model in that group was selected.

    Raises:
        gr.Error: no image was supplied, or a comparison was requested for an
            image that is not a file path.
        requests.HTTPError / requests.Timeout: the chat-completions call failed.
    """
    composite = map.get('composite') if isinstance(map, dict) else None
    if composite is None:
        raise gr.Error("Please upload a map image first.")

    # Our model: CLIP classifiers produce keyword labels for the map.
    if isinstance(composite, str):
        with Image.open(composite) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(device)
    else:
        image = preprocess(Image.fromarray(composite)).unsqueeze(0).to(device)
    combined_model = Combined_model(model_maptype, model_location, model_century,
                                    model_note, model_area, model_topic)
    combined_model.eval()
    with torch.no_grad():
        results = combined_model(image)

    question_prompt = _build_question_prompt(what, where, when, why)

    # Storytelling: text-only GPT models phrase the CLIP keywords as a caption.
    story_messages = [
        {"role": "system", "content": _STORY_SYSTEM_PROMPT},
        {"role": "user",
         "content": f"Please create a concise sentence that encapsulates these keywords: {results}.{question_prompt} "
                    f"{_LENGTH_INSTRUCTION}"},
    ]
    story_sections = [
        _format_section(label, _chat_via_requests(model_name, story_messages))
        for enabled, model_name, label in (
            # The UI advertises GPT-3.5-turbo; "gpt-3.5" is not a valid model id.
            (story_35, "gpt-3.5-turbo", "GPT-3.5-turbo"),
            (story_4o, "gpt-4o", "GPT-4o"),
            (story_4omini, "gpt-4o-mini", "GPT-4o-mini"),
        )
        if enabled
    ]
    story_results = "\n\n".join(story_sections)

    # Comparison: vision-enabled GPT models caption the image directly.
    # https://cookbook.openai.com/examples/tag_caption_images_with_gpt4v
    # https://platform.openai.com/docs/guides/vision
    comparison_results = ""
    if compare_4 or compare_4o or compare_4omini:
        if not isinstance(composite, str):
            raise gr.Error("Comparison requires the map to be provided as an image file.")
        comparison_messages = [
            {"role": "system", "content": _COMPARISON_SYSTEM_PROMPT},
            {"role": "user",
             "content": [
                 {
                     "type": "text",
                     "text": f"Please succinctly caption the current map about its topic, location, time, and purpose.{question_prompt} "
                             f"{_LENGTH_INSTRUCTION} {_COMPARISON_FORMAT_INSTRUCTION}",
                 },
                 {
                     "type": "image_url",
                     "image_url": {"url": _read_image_as_data_url(composite)},
                 },
             ]},
        ]
        comparison_sections = []
        if compare_4:
            # GPT-4-turbo goes through the OpenAI SDK client rather than `url`.
            response_gpt4 = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=comparison_messages,
                max_tokens=_MAX_TOKENS,
            )
            comparison_sections.append(
                _format_section("GPT-4-turbo", response_gpt4.choices[0].message.content))
        for enabled, model_name, label in (
            (compare_4o, "gpt-4o", "GPT-4o"),
            (compare_4omini, "gpt-4o-mini", "GPT-4o-mini"),
        ):
            if enabled:
                comparison_sections.append(
                    _format_section(label, _chat_via_requests(model_name, comparison_messages)))
        comparison_results = "\n\n".join(comparison_sections)

    return story_results, comparison_results
def update_checkboxes(select_all, *args):
    """Copy the "Select/Deselect All" state onto each dependent checkbox.

    Gradio passes the master checkbox value first and then the current values
    of the dependent checkboxes, which are only counted. The result is a tuple
    holding one copy of *select_all* per dependent checkbox.
    """
    return tuple(select_all for _ in args)
with gr.Blocks() as demo:
    with gr.Tab("Demo"):
        # gr.Row / gr.Column have no label parameter. Section titles passed
        # positionally land in layout parameters: Row's leading parameter is
        # `variant`, and Column's is `scale`, which clashes with the explicit
        # scale=... keyword. The titles are therefore kept as comments only.
        with gr.Row():  # Map details
            with gr.Column(scale=4):  # Map
                image_input = gr.ImageEditor(label="Upload or Drag Map Here", type='filepath')
                with gr.Row():  # Questions the captions should address
                    what = gr.Checkbox(label="What")
                    where = gr.Checkbox(label="Where")
                    when = gr.Checkbox(label="When")
                    why = gr.Checkbox(label="Why")
            with gr.Column(scale=1):  # Selections
                with gr.Row():  # Storytelling model
                    gr.Markdown("Select the GPT model for storytelling:")
                    # One checkbox ticks all storytelling models; unticking it clears them.
                    story_all = gr.Checkbox(label="Select/Deselect All")
                    story_35 = gr.Checkbox(label="Storytelling using GPT-3.5-turbo")
                    story_4o = gr.Checkbox(label="Storytelling using GPT-4o")
                    story_4omini = gr.Checkbox(label="Storytelling using GPT-4o-mini")
                    story_all.change(update_checkboxes, inputs=[story_all, story_35, story_4o, story_4omini],
                                     outputs=[story_35, story_4o, story_4omini])
                with gr.Row():  # Method comparison
                    gr.Markdown("Select the GPT model for comparison:")
                    # One checkbox ticks all comparison models; unticking it clears them.
                    compare_all = gr.Checkbox(label="Select/Deselect All")
                    compare_4 = gr.Checkbox(label="Compare with GPT-4-turbo")
                    compare_4o = gr.Checkbox(label="Compare with GPT-4o")
                    compare_4omini = gr.Checkbox(label="Compare with GPT-4o-mini")
                    compare_all.change(update_checkboxes, inputs=[compare_all, compare_4, compare_4o, compare_4omini],
                                       outputs=[compare_4, compare_4o, compare_4omini])
                submit_button = gr.Button("Submit")
            with gr.Column(scale=4):  # Map captions
                with gr.Row():  # Our method combined with GPT models
                    output_text_our = gr.Textbox(label="Caption generated by our method")
                with gr.Row():  # Vision-enabled GPT models
                    output_text_gpt = gr.Textbox(label="Caption generated by GPT models")
        # Define the interaction
        submit_button.click(
            fn=map_interface,
            inputs=[image_input, what, where, when, why,
                    story_35, story_4o, story_4omini,
                    compare_4, compare_4o, compare_4omini],
            outputs=[output_text_our, output_text_gpt],
        )
    with gr.Tab("README"):
        gr.Markdown("""
# Historical Map Storytelling Tool
Welcome to the Historical Map Storytelling Tool! This demo application utilizes the fine-tuned `CLIP` models and `OpenAI`'s advanced GPT models to analyze uploaded map images and generate relevant descriptions as storytelling.
## Features
- **Map Type Recognition**: The app can identify the type of the historical map uploaded (e.g., topographic or pictorial).
- **Map Details Extraction**: Based on the type of the map, the app will identify specific details (such as region, theme, date, etc.).
- **Intelligent Description Generation**: Uses state-of-the-art GPT models to generate descriptions of the map based on identified information.
- **Comparison with Vision-Enabled GPT Models**: Offers comparison with descriptions directly generated by vision-enabled GPT models.
## How to Use
1. **Upload a Map**: Upload a map image by clicking or dragging.
2. **Select Details**: Choose the map details you wish to analyze (e.g., location, time, purpose).
3. **Choose the GPT Model**: Select the GPT model you want to use for generating descriptions.
4. **Option to Compare with Vision-Enabled GPT Models**: Additionally, you can choose to compare the results with descriptions directly generated by vision-enabled GPT models.
5. **Generate Description**: Click the "Submit" button and wait for the system to process and generate a description of the map.
## Technical Background
This tool combines cutting-edge technologies in image recognition and natural language processing to provide accurate historical map analysis and description generation.
## Notes
- Ensure that the uploaded map is reasonably clear to facilitate recognition by the system.
- Ensure the image size is not too large; large images may exceed the token limits of certain models' APIs and result in errors.
- The generation of descriptions may take a few seconds to process.
""")

if __name__ == "__main__":
    demo.launch(share=True)