Spaces:
Sleeping
Sleeping
| import requests | |
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| _UINT8_MAX_F = float(np.iinfo(np.uint8).max) | |
def load_image(img_url: str, timeout: float = 30.0):
    """Load an image as a float32 array of shape [height, width, 3] in [0, 1].

    Args:
        img_url: An ``http://`` or ``https://`` URL (fetched with ``requests``),
            or any path that ``tf.io.read_file`` can open.
        timeout: Seconds to wait for the HTTP request before giving up.
            Ignored for file paths.

    Returns:
        ``np.ndarray`` of dtype float32 and shape [height, width, 3]: the
        decoded 8-bit RGB pixels divided by ``_UINT8_MAX_F`` (255.0).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the request takes longer than ``timeout``.
        Presumably ``tf.errors.NotFoundError`` / ``InvalidArgumentError`` for a
        missing file or undecodable bytes (raised by TensorFlow; not verified here).
    """
    if img_url.startswith(("http://", "https://")):
        response = requests.get(img_url, timeout=timeout)
        # Fail loudly on an error page instead of handing HTML to the decoder.
        response.raise_for_status()
        image_data = response.content
    else:
        image_data = tf.io.read_file(img_url)
    # channels=3 forces RGB. expand_animations=False makes GIFs decode to their
    # first frame as a 3-D tensor, so the [height, width, 3] contract holds.
    image = tf.io.decode_image(image_data, channels=3, expand_animations=False)
    image_numpy = tf.cast(image, dtype=tf.float32).numpy()
    return image_numpy / _UINT8_MAX_F
# FILM frame-interpolation model, loaded once when this module is imported.
# NOTE(review): hub.load presumably raises on failure rather than returning
# None, so the `model is None` guards in the interpolate_* functions only fire
# if `model` is reassigned elsewhere — confirm intended failure handling.
_FILM_MODEL_HANDLE = "https://www.kaggle.com/models/google/film/TensorFlow2/film/1"
model = hub.load(_FILM_MODEL_HANDLE)
def interpolate_batch(img1, img2, batch_size=30):
    """Interpolate ``batch_size`` frames spaced from ``img1`` to ``img2``.

    Args:
        img1: Path or http(s) URL of the start frame (passed to ``load_image``).
        img2: Path or http(s) URL of the end frame. It is resized to
            ``img1``'s height and width when the two differ.
        batch_size: Number of frames to generate. Their time values run from
            0.0 to 1.0 inclusive (``np.linspace``), rounded to two decimals.

    Returns:
        The model's ``'image'`` output as a NumPy array (presumably
        [batch_size, H, W, 3] floats; TODO confirm against the model signature).
        Returns ``None`` if the model is missing, ``batch_size`` is not
        positive, an image is ``None``, or the model call raises.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None
    if batch_size < 1:
        print(f"batch_size must be at least 1, got {batch_size}. Exiting interpolation.")
        return None

    # One time value per output frame, rounded to two decimal places.
    time_values = np.round(np.linspace(0.0, 1.0, batch_size, dtype=np.float32), 2)
    # Shape [batch_size, 1]: one scalar per batch element, the same layout the
    # single-frame path uses ([1, 1]). The previous code repeated a
    # [1, batch_size] row image-height times, which never matched x0/x1.
    time_values = time_values.reshape(batch_size, 1)

    image1 = load_image(img1)
    image2 = load_image(img2)
    if image1 is None or image2 is None:
        print("One or both images failed to load. Exiting interpolation.")
        return None

    # x0 and x1 are paired element-wise, so they must share a spatial size.
    if image2.shape[:2] != image1.shape[:2]:
        image2 = tf.image.resize(image2, image1.shape[:2]).numpy()

    input_data = {
        'time': time_values,                                   # [batch_size, 1]
        'x0': np.repeat(image1[np.newaxis], batch_size, axis=0),  # [batch_size, H, W, 3]
        'x1': np.repeat(image2[np.newaxis], batch_size, axis=0),  # [batch_size, H, W, 3]
    }
    try:
        mid_frames = model(input_data)
        frames = mid_frames['image'].numpy()  # Get interpolated frames
        return frames
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None
def interpolate_single(img1, img2, t=0.5):
    """Interpolate one frame between two images at time ``t`` (default: midpoint).

    Args:
        img1: Path or http(s) URL of the start frame (passed to ``load_image``).
        img2: Path or http(s) URL of the end frame. It is resized to
            ``img1``'s height and width when the two differ.
        t: Interpolation time between ``img1`` (x0) and ``img2`` (x1),
            presumably in [0, 1]. Defaults to 0.5, the midpoint.

    Returns:
        The single frame from the model's ``'image'`` output as a NumPy array
        (presumably [H, W, 3] floats; TODO confirm). Returns ``None`` if the
        model is missing, an image is ``None``, or the model call raises.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None
    # Shape [1, 1]: one time value for a batch of one.
    time_value = np.array([[t]], dtype=np.float32)

    image1 = load_image(img1)
    image2 = load_image(img2)
    if image1 is None or image2 is None:
        print("One or both images failed to load. Exiting interpolation.")
        return None

    # x0 and x1 must share a spatial size. Only image2 can differ from the
    # target (image1's own size), so skip the no-op resize of image1.
    target_size = image1.shape[:2]
    if image2.shape[:2] != target_size:
        image2 = tf.image.resize(image2, target_size).numpy()

    input_data = {
        'time': time_value,                    # shape [1, 1]
        'x0': np.expand_dims(image1, axis=0),  # shape [1, H, W, 3]
        'x1': np.expand_dims(image2, axis=0),  # shape [1, H, W, 3]
    }
    try:
        mid_frame = model(input_data)
        frame = mid_frame['image'][0].numpy()  # Drop the batch dimension
        return frame
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None