Create app.py
app.py
ADDED
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
from diffusers import DiffusionPipeline, DDIMScheduler
import numpy as np
import imageio
import tempfile
import os
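# NOTE: this script assumes a CUDA GPU; the fp16 video pipeline below is moved
# to "cuda", and that call will fail on a CPU-only machine.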

# Load ASR model (open-source, not Whisper)
asr_processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
asr_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")

# Load image captioning model
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load the Zeroscope text-to-video pipeline. DiffusionPipeline resolves to the
# text-to-video class for this checkpoint; StableDiffusionPipeline is a
# text-to-image pipeline and cannot load it. The checkpoint does not publish a
# separate fp16 variant, so torch_dtype=torch.float16 downcasts at load time.
video_pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
).to("cuda")
video_pipe.scheduler = DDIMScheduler.from_config(video_pipe.scheduler.config)

# --- Helper functions ---
def transcribe_audio(audio_path):
    waveform, rate = torchaudio.load(audio_path)
    # Wav2Vec2 expects 16 kHz mono input; resample if the upload uses another rate
    if rate != 16000:
        waveform = torchaudio.functional.resample(waveform, rate, 16000)
        rate = 16000
    input_values = asr_processor(waveform[0], sampling_rate=rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0])
    return transcription
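# e.g. transcribe_audio("sample.wav") typically yields lowercase, unpunctuated
# text such as "hello everyone welcome back to my channel" for this checkpoint.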

def describe_image(image):
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    description = blip_processor.decode(out[0], skip_special_tokens=True)
    return description
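# e.g. a gym selfie might come back as something like "a man taking a selfie
# in a gym"; BLIP base produces short, literal captions of this kind.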

def build_prompt(image_desc, voice_text, influencer_task):
    return f"A cinematic video of {image_desc}. They are speaking about '{voice_text}'. Their daily routine: {influencer_task}."
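# e.g. build_prompt("a man in a gym", "morning motivation", "go to gym, film reels")
# -> "A cinematic video of a man in a gym. They are speaking about 'morning motivation'. Their daily routine: go to gym, film reels."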

def generate_video(prompt, job_id):
    # zeroscope_v2_XL is trained at 1024x576; 768x512 here trades some fidelity for VRAM
    frames = video_pipe(prompt, num_inference_steps=25, height=512, width=768, num_frames=24).frames
    temp_video_path = os.path.join(tempfile.gettempdir(), f"{job_id}_output.mp4")
    # NOTE: on recent diffusers releases .frames is batched (use frames[0]) and
    # diffusers.utils.export_to_video is the supported writer; the imageio call
    # below matches the older list-of-frames output this script was written against.
    imageio.mimsave(temp_video_path, [np.array(f) for f in frames], fps=8)
    return temp_video_path

# --- Gradio interface function ---
def process_inputs(user_image, voice, influencer_tasks, job_id):
    image_desc = describe_image(user_image)
    voice_text = transcribe_audio(voice)
    final_prompt = build_prompt(image_desc, voice_text, influencer_tasks)
    video_path = generate_video(final_prompt, job_id)
    return video_path, final_prompt

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧑‍🎤 Influencer Video Generator")

    with gr.Row():
        with gr.Column():
            user_image = gr.Image(label="Upload Your Image", type="pil")
            # Gradio 4.x expects sources=[...]; the 3.x keyword was source="upload"
            voice = gr.Audio(sources=["upload"], label="Upload Your Voice (WAV/MP3)", type="filepath")
            influencer_tasks = gr.Textbox(label="What does the influencer do daily?", placeholder="e.g., go to gym, film reels, drink coffee")
            job_id = gr.Textbox(label="Job ID", placeholder="e.g., JOB-12345")
            generate_btn = gr.Button("🎥 Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            prompt_display = gr.Textbox(label="Generated Prompt")

    generate_btn.click(fn=process_inputs,
                       inputs=[user_image, voice, influencer_tasks, job_id],
                       outputs=[output_video, prompt_display])

demo.launch()
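
# A plausible requirements.txt for this Space (untested assumption):
#   gradio, torch, torchaudio, transformers, diffusers, accelerate,
#   imageio, imageio-ffmpeg, Pillow
# (imageio needs its ffmpeg backend installed to write MP4 files.)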