Mavhas committed on
Commit
d1af892
·
verified ·
1 Parent(s): 6fbdde6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -62
app.py CHANGED
@@ -1,62 +1,26 @@
import streamlit as st
from diffusers import StableDiffusionPipeline
from PIL import Image
import torch
import io
import base64


# Model loading (cached once per server process via st.cache_resource).
@st.cache_resource
def load_model():
    """Load the Stable Diffusion v1.5 pipeline.

    Uses fp16 weights and places the pipeline on GPU when one is
    available, otherwise on CPU.  NOTE(review): fp16 on CPU can be very
    slow or unsupported for some ops — confirm the CPU fallback is a
    real deployment target.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to(device)
    return pipe


pipe = load_model()

st.title("AI Image Generator")

prompt = st.text_area("Enter your prompt:", height=150)
num_images = st.slider("Number of Images", 1, 4, 1)

if st.button("Generate"):
    if prompt:
        with st.spinner("Generating images..."):
            # BUG FIX: StableDiffusionPipeline.__call__ takes
            # `num_images_per_prompt`, not `num_images` — the original
            # keyword was not a valid pipeline argument, so requesting
            # multiple images could not work.
            images = pipe(prompt, num_images_per_prompt=num_images).images

        for i, image in enumerate(images):
            col1, col2 = st.columns([1, 2])

            with col1:
                st.image(image, caption=f"Image {i+1}")

            with col2:
                # Encode the PIL image as a base64 PNG so it can be
                # embedded in a data-URI download link.
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                href = f'<a href="data:image/png;base64,{img_str}" download="image_{i+1}.png">Download Image {i+1}</a>'
                st.markdown(href, unsafe_allow_html=True)

    else:
        st.warning("Please enter a prompt.")

# Styling (optional)
st.markdown("""
<style>
body {
    font-family: sans-serif;
}
.stButton>button {
    background-color: #4CAF50;
    border: none;
    color: white;
    padding: 10px 20px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 5px;
}
</style>
""", unsafe_allow_html=True)
 
# NOTE(review): imageio / imageio_ffmpeg appear unused, but presumably
# they are imported so the ffmpeg backend needed by export_to_video is
# guaranteed to be installed — confirm before removing.
import imageio
import imageio_ffmpeg
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Instantiate the pre-trained Mochi text-to-video pipeline in bf16,
# letting accelerate spread the weights across available devices.
pipeline = MochiPipeline.from_pretrained(
    "MISHANM/video_generation",
    # variant="bf16",
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

# Tile the VAE during decoding to reduce peak memory usage.
pipeline.enable_vae_tiling()

# Generation settings: the text prompt and the clip length in frames.
prompt = "A cow drinking water on the surface of Mars."
num_frames = 20

# Run the pipeline; `.frames` is batched, so take the first clip.
output = pipeline(prompt, num_frames=num_frames)
clip = output.frames[0]

# Encode the frames to an MP4 file at 30 fps.
export_to_video(clip, "video.mp4", fps=30)

print("Video generation complete. Saved as 'video.mp4'.")