TDN-M committed on
Commit
08bfaed
·
verified ·
1 Parent(s): 7ba6276

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -124
app.py CHANGED
@@ -1,135 +1,95 @@
1
- import os
2
- import gradio as gr
3
- from groq import Groq
4
- from huggingface_hub import InferenceClient
5
- from moviepy.editor import VideoFileClip
6
  import numpy as np
7
- from io import BytesIO
 
 
8
 
9
# Store API keys in the environment, prompting interactively when missing.
# NOTE(review): input() will block forever in a non-interactive deployment
# (e.g. a hosted Space) — consider requiring the variables to be pre-set.
if 'GROQ_API_KEY' not in os.environ:
    os.environ['GROQ_API_KEY'] = input('Nhập khóa API Groq của bạn: ')
if 'HF_TOKEN' not in os.environ:
    os.environ['HF_TOKEN'] = input('Nhập Hugging Face API Token của bạn: ')

# Maximum accepted video upload size, in megabytes.
MAX_VIDEO_SIZE_MB = 35
 
 
 
 
 
 
 
 
 
16
 
17
def call_groq_api(prompt, model_name="meta-llama/llama-4-scout-17b-16e-instruct", max_tokens=2048):
    """Send a chat-completion request to the Groq API and return the reply text.

    `prompt` is a list of chat messages ({"role": ..., "content": ...});
    `model_name` selects the Groq model and `max_tokens` caps the reply length.
    """
    groq_client = Groq(api_key=os.environ.get('GROQ_API_KEY'))
    completion = groq_client.chat.completions.create(
        messages=prompt,
        model=model_name,
        max_tokens=max_tokens,
    )
    return completion.choices[0].message.content
 
26
 
27
# Extract the audio track of a video into a temporary MP3 file.
def extract_audio_from_video(video_path):
    """Write the video's audio track to "temp_audio.mp3" and return that path.

    Raises Exception with a descriptive (Vietnamese) message on any failure.
    """
    try:
        video = VideoFileClip(video_path)
        try:
            audio_path = "temp_audio.mp3"
            video.audio.write_audiofile(audio_path)
        finally:
            # Close the clip even when writing fails, so the underlying
            # file/reader handles are released (the original leaked here).
            video.close()
        return audio_path
    except Exception as e:
        raise Exception(f"Lỗi khi trích xuất audio: {str(e)}")
37
 
38
# Extract evenly spaced preview frames from a video.
def extract_thumbnails(video_path, num_frames=6):
    """Sample `num_frames` frames at even time offsets from the video.

    Returns a list of frame arrays. Raises Exception with a descriptive
    (Vietnamese) message on failure.
    """
    try:
        video = VideoFileClip(video_path)
        try:
            step = video.duration / num_frames
            thumbnails = [video.get_frame(i * step) for i in range(num_frames)]
        finally:
            # Release the reader even when a frame grab fails
            # (the original skipped close() on the error path).
            video.close()
        return thumbnails
    except Exception as e:
        raise Exception(f"Lỗi khi trích xuất thumbnails: {str(e)}")
53
 
54
def transcribe_audio(audio_path):
    """Transcribe an audio file to text via the Hugging Face Inference API.

    Uses the "openai/whisper-tiny" model; raises Exception with a descriptive
    (Vietnamese) message on failure.
    """
    try:
        hf_client = InferenceClient(api_key=os.environ.get('HF_TOKEN'))
        with open(audio_path, "rb") as audio_file:
            result = hf_client.automatic_speech_recognition(
                audio_file,
                model="openai/whisper-tiny"
            )
        return result['text']
    except Exception as e:
        raise Exception(f"Lỗi khi chuyển audio thành văn bản: {str(e)}")
66
 
67
def create_prompt(social_media_type, transcription_text):
    """Build the Groq chat-message list for generating a social-media post.

    Returns a two-element list: a system message with the content-expert
    instructions, then a user message carrying the platform and transcript.
    """
    return [
        {
            "role": "system",
            "content": (
                "Bạn là chuyên gia trong việc tạo nội dung mạng xã hội và tạo bài đăng hiệu quả dựa trên nội dung người dùng. "
                "Tuân thủ quy tắc và ràng buộc của nền tảng mạng xã hội."
            ),
        },
        {
            "role": "user",
            "content": f"Nền tảng: {social_media_type}\nNội dung văn bản: {transcription_text}",
        },
    ]
78
 
79
# Main pipeline: uploaded video -> audio -> transcript -> generated post.
def process_and_generate_post(video_file, social_media_type, progress=gr.Progress()):
    """Validate the uploaded video, then run the full generation pipeline.

    Returns a 4-tuple (post_text_or_error_message, thumbnails, transcription,
    audio_path); the last three are None when validation or processing fails.
    """
    progress(0, desc="Đang khởi tạo...")

    if not video_file:
        return "Vui lòng tải lên tệp video.", None, None, None

    try:
        # Reject uploads above the size limit (bytes -> MB).
        video_size = os.path.getsize(video_file) / (1024 * 1024)
        if video_size > MAX_VIDEO_SIZE_MB:
            return f"Tệp video lớn hơn {MAX_VIDEO_SIZE_MB} MB. Vui lòng tải lên tệp nhỏ hơn.", None, None, None

        progress(0.2, desc="Đang trích xuất audio")
        audio_path = extract_audio_from_video(video_file)

        progress(0.4, desc="Đang trích xuất thumbnails")
        thumbnails = extract_thumbnails(video_file)

        progress(0.6, desc="Đang chuyển audio thành văn bản")
        transcription_text = transcribe_audio(audio_path)

        progress(0.8, desc="Đang tạo bài đăng mạng xã hội")
        prompt = create_prompt(social_media_type, transcription_text)
        social_media_post = call_groq_api(prompt)

        # Delete the temporary audio file.
        # NOTE(review): audio_path is still returned below even though the
        # file was just removed — callers cannot open it; confirm intent.
        if os.path.exists(audio_path):
            os.remove(audio_path)

        return social_media_post, thumbnails, transcription_text, audio_path

    except Exception as e:
        return f"Đã xảy ra lỗi: {str(e)}", None, None, None
 
 
 
 
 
 
 
 
 
 
113
 
114
# Gradio UI for the post-generation pipeline.
def gradio_interface():
    """Assemble and return the Gradio Blocks UI (video upload + platform picker)."""
    with gr.Blocks(theme=gr.themes.Base()) as demo:
        gr.Markdown("### Công cụ Tạo Bài Đăng Mạng Xã Hội")
        video_input = gr.File(label="Tải lên Video", file_types=[".mp4", ".avi", ".mov"])
        social_media_type = gr.Radio(
            choices=["X (Twitter)", "Facebook", "LinkedIn", "Instagram"],
            value="X (Twitter)",
            label="Nền tảng"
        )
        generate_btn = gr.Button("Tạo Bài Đăng")
        output = gr.Textbox(label="Bài Đăng Đã Tạo")
        # NOTE(review): process_and_generate_post returns four values but only
        # one output component is wired here — confirm the thumbnails,
        # transcription and audio path are intentionally discarded.
        generate_btn.click(
            fn=process_and_generate_post,
            inputs=[video_input, social_media_type],
            outputs=[output]
        )
    return demo
132
# Launch the Gradio interface when run as a script.
if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(share=True)
 
1
+ import cv2
2
+ import mediapipe as mp
 
 
 
3
  import numpy as np
4
+ import gradio as gr
5
+ import base64
6
+ import time
7
 
8
+ # Initialize MediaPipe Selfie Segmentation
9
+ mp_selfie_segmentation = mp.solutions.selfie_segmentation
10
+ segmentation = mp_selfie_segmentation.SelfieSegmentation(model_selection=1)
 
 
11
 
12
+ # Global settings
13
+ settings = {
14
+ "seg_enabled": True,
15
+ "blur_bg": False,
16
+ "set_bg": False,
17
+ "set_color": False,
18
+ "bg_color": (0, 0, 0), # BGR
19
+ "blur_intensity": 15
20
+ }
21
+ bg_image = None
22
 
23
+ def process_frame(frame, seg_enabled, blur_bg, set_bg, set_color, bg_color, blur_intensity, custom_image=None):
24
+ global bg_image
25
+ settings.update({
26
+ "seg_enabled": seg_enabled,
27
+ "blur_bg": blur_bg,
28
+ "set_bg": set_bg,
29
+ "set_color": set_color,
30
+ "bg_color": tuple(map(int, bg_color.split(","))) if set_color else (0, 0, 0),
31
+ "blur_intensity": blur_intensity
32
+ })
33
 
34
+ if custom_image is not None and set_bg:
35
+ bg_image = custom_image
 
 
 
 
 
 
 
 
36
 
37
+ process_start = time.time()
38
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
39
+ result = segmentation.process(frame_rgb)
40
+ mask = result.segmentation_mask
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Create alpha mask
43
+ alpha = mask > 0.5
44
+ alpha = alpha.astype(np.uint8) * 255
45
+ alpha = cv2.merge([alpha, alpha, alpha])
 
 
 
 
 
 
 
 
46
 
47
+ output_frame = frame.copy()
 
 
 
 
 
 
 
 
 
 
48
 
49
+ if settings["seg_enabled"]:
50
+ if settings["blur_bg"]:
51
+ bg = cv2.resize(frame, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_LINEAR)
52
+ ksize = settings["blur_intensity"]
53
+ if ksize % 2 == 0:
54
+ ksize -= 1
55
+ bg = cv2.GaussianBlur(bg, (ksize, ksize), 0)
56
+ bg = cv2.resize(bg, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
57
+ output_frame = np.where(alpha == 255, frame, bg)
58
+ elif settings["set_bg"] and bg_image is not None:
59
+ if bg_image.shape[:2] != frame.shape[:2]:
60
+ bg_image = cv2.resize(bg_image, (frame.shape[1], frame.shape[0]))
61
+ output_frame = np.where(alpha == 255, frame, bg_image)
62
+ elif settings["set_color"]:
63
+ bg = np.full_like(frame, settings["bg_color"])
64
+ output_frame = np.where(alpha == 255, frame, bg)
65
+ else:
66
+ bg = np.zeros_like(frame)
67
+ output_frame = np.where(alpha == 255, frame, bg)
68
+
69
+ process_time = (time.time() - process_start) * 1000
70
+ return output_frame, f"{process_time:.2f} ms"
71
+
72
# Build the Gradio UI: live webcam input, the processed output image, and the
# background-mode controls, all streaming through process_frame.
with gr.Blocks() as demo:
    gr.Markdown("# AI Background Remover")
    with gr.Row():
        with gr.Column():
            live_feed = gr.Webcam(label="Live Video")
            processed_view = gr.Image(label="Processed Output")
            toggle_seg = gr.Checkbox(label="Enable Background Removal", value=True)
            toggle_blur = gr.Checkbox(label="Blur Background")
            toggle_custom_bg = gr.Checkbox(label="Custom Image Background")
            custom_bg_upload = gr.Image(label="Upload Custom Background")
            toggle_solid = gr.Checkbox(label="Solid Color Background")
            solid_color = gr.Textbox(label="Background Color (R,G,B)", value="0,0,0")
            blur_strength = gr.Slider(label="Blur Intensity", minimum=5, maximum=25, value=15, step=2)
            latency_box = gr.Textbox(label="Processing Time", value="0 ms")

    # Re-run process_frame for every streamed webcam frame.
    live_feed.stream(
        fn=process_frame,
        inputs=[live_feed, toggle_seg, toggle_blur, toggle_custom_bg,
                toggle_solid, solid_color, blur_strength, custom_bg_upload],
        outputs=[processed_view, latency_box],
    )
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  if __name__ == "__main__":
95
+ demo.launch()