Spaces:
Running
Running
File size: 25,429 Bytes
8587b71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 |
import json
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List

import requests
from loguru import logger
from openai import OpenAI

# import tiktoken
class BaseGenerator:
    """Common scaffolding for LLM narration generators.

    Subclasses implement ``_generate`` (issue the provider-specific API call)
    and usually override ``_process_response`` (extract text from the raw
    response object).
    """

    def __init__(self, model_name: str, api_key: str, prompt: str):
        self.model_name = model_name
        self.api_key = api_key
        self.base_prompt = prompt
        self.conversation_history: List[Dict] = []
        # Number of trailing characters of the previous chunk kept as context
        # so consecutive chunks read continuously.
        self.chunk_overlap = 50
        self.last_chunk_ending = ""
        # Shared sampling defaults; subclasses layer provider-specific keys on top.
        self.default_params = {
            "temperature": 0.7,
            "max_tokens": 500,
            "top_p": 0.9,
            "frequency_penalty": 0.3,
            "presence_penalty": 0.5
        }

    def _try_generate(self, messages: list, params: dict = None) -> str:
        """Call ``_generate`` with up to 3 attempts.

        Re-raises the last exception when every attempt fails; intermediate
        failures are logged as warnings.
        """
        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                response = self._generate(messages, params or self.default_params)
                return self._process_response(response)
            except Exception as e:
                if attempt == max_attempts - 1:
                    raise
                logger.warning(f"Generation attempt {attempt + 1} failed: {str(e)}")
        return ""

    def _generate(self, messages: list, params: dict) -> Any:
        """Issue the provider-specific API call. Must be overridden."""
        raise NotImplementedError

    def _process_response(self, response: Any) -> str:
        """Extract plain text from a raw response; identity by default."""
        return response

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Generate one narration chunk for a scene, chaining onto the
        previous chunk's ending so consecutive chunks flow naturally.

        Raises whatever ``_try_generate`` ultimately raises after retries.
        """
        prompt = f"""{self.base_prompt}
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述:{scene_description}
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
不要出现除了文案以外的其他任何内容;
严格字数要求:{word_count}字,允许误差±5字。"""
        messages = [
            {"role": "system", "content": self.base_prompt},
            {"role": "user", "content": prompt}
        ]
        try:
            generated_script = self._try_generate(messages, self.default_params)
            # Remember the tail of this chunk as context for the next one.
            if generated_script:
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script
            return generated_script
        except Exception as e:
            logger.error(f"Script generation failed: {str(e)}")
            raise
class OpenAIGenerator(BaseGenerator):
    """OpenAI API generator implementation."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(api_key=api_key, base_url=base_url or "https://api.openai.com/v1")
        self.max_tokens = 5000
        # OpenAI-specific request parameters layered on the shared defaults.
        self.default_params = {
            **self.default_params,
            "stream": False,
            "user": "script_generator"
        }
        # tiktoken-based token counting is disabled (see commented import at
        # file top); _count_tokens falls back to a character heuristic when
        # self.encoding is absent.

    def _generate(self, messages: list, params: dict) -> Any:
        """Send a chat-completion request to the OpenAI endpoint."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI generation error: {str(e)}")
            raise

    def _process_response(self, response: Any) -> str:
        """Return the stripped text of the first choice; raise on empty response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from OpenAI API")
        return response.choices[0].message.content.strip()

    def _count_tokens(self, messages: list) -> int:
        """Estimate the token count of a chat payload.

        Fix: the original unconditionally used ``self.encoding``, which is
        never set because the tiktoken initialization in ``__init__`` is
        commented out — every call raised AttributeError. When no encoder is
        configured we now fall back to a rough ~4-characters-per-token
        estimate instead of crashing.
        """
        encoding = getattr(self, "encoding", None)

        def encoded_len(text: str) -> int:
            # One-line purpose: token length of `text` via tiktoken, or a
            # coarse character-based approximation when unavailable.
            if encoding is not None:
                return len(encoding.encode(text))
            return max(1, len(text) // 4)

        num_tokens = 0
        for message in messages:
            num_tokens += 3  # per-message framing overhead
            for key, value in message.items():
                num_tokens += encoded_len(str(value))
                if key == "role":
                    num_tokens += 1
        num_tokens += 3  # assistant reply priming overhead
        return num_tokens
class GeminiGenerator(BaseGenerator):
    """Native Gemini API generator (REST ``generateContent`` endpoint).

    Fix: this class previously had no ``_generate``/``_process_response`` at
    all, so the 'gemini' route raised NotImplementedError. The native-API
    implementations (whose docstrings explicitly say "原生Gemini API") were
    mis-indented into GeminiOpenAIGenerator; they belong here.
    """

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
        super().__init__(model_name, api_key, prompt)
        self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
        self.client = None  # native REST calls go through `requests`, no SDK client
        # Native Gemini generationConfig uses camelCase parameter names.
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "topP": self.default_params["top_p"],
            "topK": 40,
            "maxOutputTokens": 4000,
            "candidateCount": 1,
            "stopSequences": []
        }

    def _generate(self, messages: list, params: dict) -> Any:
        """POST to the native ``generateContent`` endpoint with retries.

        Retries (up to 3 attempts) on 429 rate limiting, transient network
        errors, and empty candidate lists; raises on 400/403/other HTTP
        errors and on safety-filter blocks.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # The native API takes a single text prompt; flatten the chat
                # messages into one newline-joined string.
                prompt = "\n".join([m["content"] for m in messages])
                request_data = {
                    "contents": [{
                        "parts": [{"text": prompt}]
                    }],
                    "generationConfig": params,
                    # Relax configurable safety categories; responses can
                    # still be blocked (finishReason == "SAFETY" below).
                    "safetySettings": [
                        {
                            "category": "HARM_CATEGORY_HARASSMENT",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_HATE_SPEECH",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                            "threshold": "BLOCK_NONE"
                        }
                    ]
                }
                url = f"{self.base_url}/models/{self.model_name}:generateContent"
                response = requests.post(
                    url,
                    json=request_data,
                    headers={
                        "Content-Type": "application/json",
                        "x-goog-api-key": self.api_key
                    },
                    timeout=120
                )
                if response.status_code == 429:
                    # Rate limited: back off longer on the first hit.
                    wait_time = 65 if attempt == 0 else 30
                    logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...")
                    time.sleep(wait_time)
                    continue
                if response.status_code == 400:
                    raise Exception(f"请求参数错误: {response.text}")
                elif response.status_code == 403:
                    raise Exception(f"API密钥无效或权限不足: {response.text}")
                elif response.status_code != 200:
                    raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}")
                response_data = response.json()
                # An empty candidate list usually means a safety block.
                if "candidates" not in response_data or not response_data["candidates"]:
                    if attempt < max_retries - 1:
                        logger.warning("原生Gemini API 返回无效响应,等待30秒后重试...")
                        time.sleep(30)
                        continue
                    raise Exception("原生Gemini API返回无效响应,可能触发了安全过滤")
                candidate = response_data["candidates"][0]
                if candidate.get("finishReason") == "SAFETY":
                    raise Exception("内容被Gemini安全过滤器阻止")

                class CompatibleResponse:
                    """Minimal wrapper exposing ``.text`` like SDK responses."""

                    def __init__(self, data):
                        self.data = data
                        candidate = data["candidates"][0]
                        self.text = ""
                        if "content" in candidate and "parts" in candidate["content"]:
                            for part in candidate["content"]["parts"]:
                                if "text" in part:
                                    self.text += part["text"]

                return CompatibleResponse(response_data)
            except requests.exceptions.RequestException as e:
                if attempt < max_retries - 1:
                    logger.warning(f"网络请求失败,等待30秒后重试: {str(e)}")
                    time.sleep(30)
                    continue
                logger.error(f"原生Gemini API请求失败: {str(e)}")
                raise
            except Exception as e:
                if attempt < max_retries - 1 and "429" in str(e):
                    logger.warning("原生Gemini API 触发限流,等待65秒后重试...")
                    time.sleep(65)
                    continue
                logger.error(f"原生Gemini 生成文案错误: {str(e)}")
                raise

    def _process_response(self, response: Any) -> str:
        """Return the stripped text of a native Gemini response."""
        if not response or not response.text:
            raise ValueError("原生Gemini API返回无效响应")
        return response.text.strip()
class GeminiOpenAIGenerator(BaseGenerator):
    """OpenAI-compatible Gemini proxy generator.

    Fix: this class previously contained a second, native-REST
    ``_generate``/``_process_response`` pair (their docstrings say
    "原生Gemini API" — they were written for GeminiGenerator). Because Python
    keeps only the last definition of a name in a class body, that pair
    shadowed the OpenAI-compatible implementations below and the OpenAI
    client created in ``__init__`` was never used. The misplaced duplicates
    are removed so this class does what its docstring and constructor say.
    """

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
        super().__init__(model_name, api_key, prompt)
        # A proxy endpoint is mandatory for the OpenAI-compatible mode.
        if not base_url:
            raise ValueError("OpenAI兼容的Gemini代理必须提供base_url")
        self.base_url = base_url.rstrip('/')
        # Talk to the proxy through the OpenAI-compatible interface.
        from openai import OpenAI
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        # OpenAI-compatible request parameters.
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "max_tokens": 4000,
            "stream": False
        }

    def _generate(self, messages: list, params: dict) -> Any:
        """Send a chat-completion request through the OpenAI-compatible proxy."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}")
            raise

    def _process_response(self, response: Any) -> str:
        """Return the stripped text of the first choice; raise on empty response."""
        if not response or not response.choices:
            raise ValueError("OpenAI兼容Gemini代理返回无效响应")
        return response.choices[0].message.content.strip()
class QwenGenerator(BaseGenerator):
    """Generator backed by Alibaba Cloud Qwen's OpenAI-compatible endpoint."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
        super().__init__(model_name, api_key, prompt)
        endpoint = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.client = OpenAI(api_key=api_key, base_url=endpoint)
        # Layer Qwen-specific request options on top of the shared defaults.
        qwen_overrides = {"stream": False, "user": "script_generator"}
        self.default_params = dict(self.default_params, **qwen_overrides)

    def _generate(self, messages: list, params: dict) -> any:
        """Issue a chat-completion request against the Qwen endpoint."""
        try:
            return self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params,
            )
        except Exception as e:
            logger.error(f"Qwen generation error: {str(e)}")
            raise

    def _process_response(self, response: any) -> str:
        """Extract the first choice's stripped text from a Qwen response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from Qwen API")
        first_choice = response.choices[0]
        return first_choice.message.content.strip()
class MoonshotGenerator(BaseGenerator):
    """Moonshot API generator implementation."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.moonshot.cn/v1"
        )
        # Moonshot-specific request parameters layered on the shared defaults.
        self.default_params = {
            **self.default_params,
            "stream": False,
            "stop": None,
            "user": "script_generator",
            "tools": None
        }

    def _generate(self, messages: list, params: dict) -> Any:
        """Moonshot chat completion with bounded retry on 429 rate limits.

        Fix: the original ``while True`` loop retried 429 errors forever,
        which could hang the pipeline indefinitely on a persistently
        rate-limited key. Retries are now capped; the final failure is
        logged and re-raised.
        """
        max_retries = 5  # at 65 s per wait, gives up after ~5 minutes
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    **params
                )
                return response
            except Exception as e:
                error_str = str(e)
                if "Error code: 429" in error_str and attempt < max_retries - 1:
                    logger.warning("Moonshot API 触发限流,等待65秒后重试...")
                    time.sleep(65)  # back off before retrying
                    continue
                logger.error(f"Moonshot generation error: {error_str}")
                raise

    def _process_response(self, response: Any) -> str:
        """Extract the first choice's stripped text; raise on empty response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from Moonshot API")
        return response.choices[0].message.content.strip()
class DeepSeekGenerator(BaseGenerator):
    """Generator backed by DeepSeek's OpenAI-compatible endpoint."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
        super().__init__(model_name, api_key, prompt)
        endpoint = base_url or "https://api.deepseek.com"
        self.client = OpenAI(api_key=api_key, base_url=endpoint)
        # Extend the shared defaults with DeepSeek-specific request options.
        self.default_params = dict(
            self.default_params,
            stream=False,
            user="script_generator",
        )

    def _generate(self, messages: list, params: dict) -> any:
        """Issue a chat-completion request (e.g. deepseek-chat / deepseek-coder)."""
        try:
            return self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params,
            )
        except Exception as e:
            logger.error(f"DeepSeek generation error: {str(e)}")
            raise

    def _process_response(self, response: any) -> str:
        """Extract the first choice's stripped text from a DeepSeek response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from DeepSeek API")
        message = response.choices[0].message
        return message.content.strip()
class ScriptProcessor:
    """Orchestrates narration generation: selects an LLM generator from the
    model name, estimates per-segment word counts from timestamps, generates
    narration for every frame, and saves the result to a JSON file."""

    def __init__(self, model_name: str, api_key: str = None, base_url: str = None, prompt: str = None, video_theme: str = ""):
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.video_theme = video_theme
        self.prompt = prompt or self._get_default_prompt()
        # Pick the generator implementation by substring match on the model name.
        logger.info(f"文本 LLM 提供商: {model_name}")
        if 'gemini' in model_name.lower():
            self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url)
        elif 'qwen' in model_name.lower():
            self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
        elif 'moonshot' in model_name.lower():
            self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt, self.base_url)
        elif 'deepseek' in model_name.lower():
            self.generator = DeepSeekGenerator(model_name, self.api_key, self.prompt, self.base_url)
        else:
            # Fallback: treat any other model as OpenAI-compatible.
            self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt, self.base_url)

    def _get_default_prompt(self) -> str:
        """Default system prompt (Chinese, sent verbatim to the LLM)."""
        return f"""
你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让主题为 《{self.video_theme}》 的视频既有趣又富有传播力。
你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
目标受众:热爱生活、追求独特体验的18-35岁年轻人
文案风格:基于HKRR理论 + 段子手精神
主题:{self.video_theme}
【创作核心理念】
1. 敢于用"温和的违反"制造笑点,但不能过于冒犯
2. 巧妙运用中国式幽默,让观众会心一笑
3. 保持轻松愉快的叙事基调
【爆款内容四要素】
【快乐元素 Happy】
1. 用调侃的语气描述画面
2. 巧妙植入网络流行梗,增加内容的传播性
3. 适时自嘲,展现真实且有趣的一面
【知识价值 Knowledge】
1. 用段子手的方式解释专业知识
2. 在幽默中传递实用的生活常识
【情感共鸣 Resonance】
1. 描述"真实但夸张"的环境描述
2. 把对自然的感悟融入俏皮话中
3. 用接地气的表达方式拉近与观众距离
【节奏控制 Rhythm】
1. 像讲段子一样,注意铺垫和包袱的节奏
2. 确保每段都有笑点,但不强求
3. 段落结尾干净利落,不拖泥带水
【连贯性要求】
1. 新生成的内容必须自然衔接上一段文案的结尾
2. 使用恰当的连接词和过渡语,确保叙事流畅
3. 保持人物视角和语气的一致性
4. 避免重复上一段已经提到的信息
5. 确保情节的逻辑连续性
我会按顺序提供多段视频画面描述。请创作既搞笑又能火爆全网的口播文案。
记住:要敢于用"温和的违反"制造笑点,但要把握好尺度,让观众在轻松愉快中感受到乐趣。"""

    def calculate_duration_and_word_count(self, time_range: str) -> int:
        """
        Compute the duration of a time range and estimate a fitting word count.

        Args:
            time_range: range string of the form "HH:MM:SS,mmm-HH:MM:SS,mmm",
                e.g. "00:00:50,100-00:01:21,500"

        Returns:
            int: estimated character count, clamped to 10..500.

        Empirical rule (matching the division below): roughly one character
        per 0.4 seconds of speech, so 10 seconds ≈ 25 characters.
        """
        try:
            start_str, end_str = time_range.split('-')

            def time_to_seconds(time_str: str) -> float:
                """
                Convert "HH:MM:SS,mmm" to seconds with millisecond precision.

                Args:
                    time_str: e.g. "00:00:50,100" meaning 50.1 seconds

                Returns:
                    float: total seconds; 0.0 on a malformed string (logged).
                """
                try:
                    # Split off the millisecond part.
                    time_part, ms_part = time_str.split(',')
                    hours, minutes, seconds = map(int, time_part.split(':'))
                    milliseconds = int(ms_part)
                    # Convert to seconds.
                    total_seconds = (hours * 3600) + (minutes * 60) + seconds + (milliseconds / 1000)
                    return total_seconds
                except ValueError as e:
                    logger.warning(f"时间格式解析错误: {time_str}, error: {e}")
                    return 0.0

            # Seconds at the start and end of the range.
            start_seconds = time_to_seconds(start_str)
            end_seconds = time_to_seconds(end_str)
            # Duration in seconds.
            duration = end_seconds - start_seconds
            # Empirical formula: one character per 0.4 seconds of speech.
            word_count = int(duration / 0.4)
            # Clamp to a sane range (10-500 characters).
            word_count = max(10, min(word_count, 500))
            logger.debug(f"时间范围 {time_range} 的持续时间为 {duration:.3f}秒, 估算字数: {word_count}")
            return word_count
        except Exception as e:
            logger.warning(f"字数计算错误: {traceback.format_exc()}")
            return 100  # default word count when parsing fails

    def process_frames(self, frame_content_list: List[Dict]) -> List[Dict]:
        """Generate narration for every frame dict in place, then persist.

        Each frame must carry "timestamp" and "picture" keys; this adds
        "narration" and "OST" and returns the mutated list.
        """
        for frame_content in frame_content_list:
            word_count = self.calculate_duration_and_word_count(frame_content["timestamp"])
            script = self.generator.generate_script(frame_content["picture"], word_count)
            frame_content["narration"] = script
            # NOTE(review): 2 appears to be a narration-mode flag consumed
            # downstream — confirm its exact semantics with the caller.
            frame_content["OST"] = 2
            logger.info(f"时间范围: {frame_content['timestamp']}, 建议字数: {word_count}")
            logger.info(script)
        self._save_results(frame_content_list)
        return frame_content_list

    def _save_results(self, frame_content_list: List[Dict]):
        """Re-base each frame's timestamp onto a timeline starting at zero
        (keeping every segment's original duration) and dump the list as
        JSON under storage/json/."""
        try:
            def format_timestamp(seconds: float) -> str:
                """Format seconds as "HH:MM:SS,mmm"."""
                hours = int(seconds // 3600)
                minutes = int((seconds % 3600) // 60)
                seconds_remainder = seconds % 60
                whole_seconds = int(seconds_remainder)
                milliseconds = int((seconds_remainder - whole_seconds) * 1000)
                return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d},{milliseconds:03d}"

            # Running start time of the rebased timeline (seconds, with ms).
            current_time = 0.0
            for frame in frame_content_list:
                # Duration comes from the frame's original timestamp range.
                start_str, end_str = frame['timestamp'].split('-')

                def time_to_seconds(time_str: str) -> float:
                    """Parse "HH:MM:SS[,mmm]", "MM:SS" or "SS" into seconds;
                    returns 0.0 on failure (logged)."""
                    try:
                        if ',' in time_str:
                            time_part, ms_part = time_str.split(',')
                            ms = float(ms_part) / 1000
                        else:
                            time_part = time_str
                            ms = 0
                        parts = time_part.split(':')
                        if len(parts) == 3:  # HH:MM:SS
                            h, m, s = map(float, parts)
                            seconds = h * 3600 + m * 60 + s
                        elif len(parts) == 2:  # MM:SS
                            m, s = map(float, parts)
                            seconds = m * 60 + s
                        else:  # SS
                            seconds = float(parts[0])
                        return seconds + ms
                    except Exception as e:
                        logger.error(f"时间格式转换错误 {time_str}: {str(e)}")
                        return 0.0

                # Duration of the current segment.
                start_seconds = time_to_seconds(start_str)
                end_seconds = time_to_seconds(end_str)
                duration = end_seconds - start_seconds
                # Write the rebased timestamp.
                new_start = format_timestamp(current_time)
                new_end = format_timestamp(current_time + duration)
                frame['new_timestamp'] = f"{new_start}-{new_end}"
                # Advance the running timeline.
                current_time += duration
            # Persist the result.
            file_name = f"storage/json/step2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, 'w', encoding='utf-8') as file:
                json.dump(frame_content_list, file, ensure_ascii=False, indent=4)
            logger.info(f"保存脚本成功,总时长: {format_timestamp(current_time)}")
        except Exception as e:
            logger.error(f"保存结果时发生错误: {str(e)}\n{traceback.format_exc()}")
            raise
|