Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel | |
| import zipfile | |
| import os | |
| import csv | |
| from PIL import Image | |
# All inference in this script is placed on the CPU.
device = 'cpu'

# Checkpoint identifier shared by the model, image processor and tokenizer.
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"

# Load the tokenizer, the image processor and the captioning model from the
# same pretrained checkpoint; only the model needs to be moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
def predict(image, max_length=64, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: Image object exposing ``convert``; it is converted to RGB
            before being handed to the image processor.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens skipped.
    """
    # Normalise the colour mode, then turn the image into model inputs.
    rgb_image = image.convert('RGB')
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Only the first generated sequence is decoded.
    generated = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=num_beams,
    )
    first_sequence = generated[0]

    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def process_images(image_files):
    """Caption each image file and store the results in a CSV file.

    Files that cannot be opened or captioned are reported on stdout and
    left out of the CSV; they do not abort the run.

    Args:
        image_files: Iterable of image file paths.

    Returns:
        The path of the written CSV file, ``'image_captions.csv'``. It has
        an ``Image, Caption`` header followed by one row per captioned
        image (file base name, caption).
    """
    rows = []
    for path in image_files:
        try:
            # The image is only needed while the caption is being generated.
            with Image.open(path) as picture:
                text = predict(picture)
            rows.append((os.path.basename(path), text))
        except Exception as e:
            print(f"Skipping file {path}: {e}")

    # Header row and caption rows are written in a single call.
    output_path = 'image_captions.csv'
    with open(output_path, mode='w', newline='') as handle:
        csv.writer(handle).writerows([['Image', 'Caption'], *rows])
    return output_path
def process_zip_files(zip_file_paths):
    """Caption every image contained in the given zip archives.

    Each archive is extracted into its own temporary directory that is
    deleted once its files have been captioned. This keeps files from
    earlier archives (or earlier calls) from being captioned again and
    avoids leaving extracted uploads on disk.

    Extracted files that cannot be opened or captioned are reported on
    stdout and skipped.

    Args:
        zip_file_paths: Iterable of paths to zip archives.

    Returns:
        The path of the written CSV file, ``'zip_image_captions.csv'``. It
        has an ``Image Name, Caption`` header followed by one row per
        captioned image (file name, caption).

    Raises:
        zipfile.BadZipFile: If one of the paths is not a valid zip archive.
    """
    # Function-scope import: nothing else in the file needs tempfile.
    import tempfile

    captions = []
    for zip_file_path in zip_file_paths:
        # A fresh directory per archive isolates archives and requests.
        with tempfile.TemporaryDirectory(prefix='extracted_images_') as extract_dir:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

            # Walk everything the archive contained, including sub-folders.
            for root, _dirs, files in os.walk(extract_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        with Image.open(file_path) as img:
                            caption = predict(img)
                        captions.append((file, caption))
                    except Exception as e:
                        # Best effort: archives often hold non-image files.
                        print(f"Skipping file {file}: {e}")

    # Save the results to a CSV file
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
def _upload_paths(uploads):
    """Return the filesystem path of each uploaded item as a list."""
    paths = []
    for item in uploads:
        if isinstance(item, (str, os.PathLike)):
            # Already a path (including str subclasses).
            paths.append(os.fspath(item))
        else:
            # File-like upload wrapper: its path is stored in `.name`.
            paths.append(item.name)
    return paths


def gr_process(zip_files=None, image_files=None):
    """Caption uploaded zip archives or uploaded images.

    Both parameters are optional so the function can be called with only
    one kind of upload. When both are given, only the zip archives are
    processed.

    Args:
        zip_files: Uploaded zip archives, as path strings or objects with
            a ``name`` attribute holding the path. May be ``None`` or empty.
        image_files: Uploaded images, in the same forms. May be ``None``
            or empty.

    Returns:
        The CSV path returned by ``process_zip_files`` or
        ``process_images``.

    Raises:
        ValueError: If neither ``zip_files`` nor ``image_files`` contains
            any upload.
    """
    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    if zip_files:
        return process_zip_files(_upload_paths(zip_files))
    return process_images(_upload_paths(image_files))
def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    """Merge the data rows of two caption CSV files into one CSV file.

    The output always starts with an ``Image Name, Caption`` header. The
    first row of each input is treated as its header and dropped; all
    remaining rows are copied in order, ``file1`` before ``file2``.

    Args:
        file1: Path of the first CSV file. ``None``, an empty string or a
            path that does not exist is skipped.
        file2: Path of the second CSV file, handled the same way.
        output_file: Path of the combined CSV file to write.

    Returns:
        ``output_file``.
    """
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])
        for source in (file1, file2):
            # A missing input (e.g. captions not generated yet) is not an
            # error; os.path.exists(None) would raise, so test falsiness first.
            if not source or not os.path.exists(source):
                continue
            # newline='' lets the csv module handle quoted line breaks.
            with open(source, mode='r', newline='') as infile:
                reader = csv.reader(infile)
                next(reader, None)  # Skip header row; tolerate empty files
                writer.writerows(reader)
    return output_file
# Custom CSS handed to the Blocks app below.
css = '''
h1#title {
  text-align: center;
}
h3#header {
  text-align: center;
}
img#overview {
  max-width: 800px;
  max-height: 600px;
}
img#style-image {
  max-width: 1000px;
  max-height: 600px;
}
.gr-image {
  max-width: 150px; /* Set a small box for the image */
  max-height: 150px;
}
'''

demo = gr.Blocks(css=css)
with demo:
    gr.Markdown('''<h1 id="title">Image Caption 🖼️</h1>''')
    gr.Markdown('''Made by : No. Fa.''')

    with gr.Row():
        # Left column: uploads and the buttons that caption them.
        with gr.Column(scale=1):
            new_zip_files = gr.File(label="Upload Zip Files", type="filepath", file_count="multiple")
            generate_zip_captions_btn = gr.Button("Generate Zip Captions")
            new_image_files = gr.File(label="Upload Images", type="filepath", file_count="multiple")
            generate_image_captions_btn = gr.Button("Generate Image Captions")
        # Right column: downloadable CSV results.
        with gr.Column(scale=3):
            output_zip_file = gr.File(label="Download Zip Captions")
            output_image_file = gr.File(label="Download Image Captions")
            combined_file = gr.File(label="Download Combined Captions")
            combine_files_btn = gr.Button("Combine CSV Files")

    # Each button supplies exactly one upload component, so route its value
    # explicitly to the matching gr_process argument and pass None for the
    # other one.
    generate_zip_captions_btn.click(
        fn=lambda uploads: gr_process(uploads, None),
        inputs=new_zip_files,
        outputs=output_zip_file,
    )
    # The image button must read the image upload component itself.
    generate_image_captions_btn.click(
        fn=lambda uploads: gr_process(None, uploads),
        inputs=new_image_files,
        outputs=output_image_file,
    )
    combine_files_btn.click(
        fn=combine_csv_files,
        inputs=[output_zip_file, output_image_file],
        outputs=combined_file,
    )

demo.launch()