import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Release any cached GPU memory
torch.cuda.empty_cache()
gc.collect()

# Limit the CUDA allocator's split size to reduce memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Model repository on the Hugging Face Hub
model_name = "yxccai/text-style-converter"

# Globals holding the lazily loaded tokenizer and model
tokenizer = None
model = None
def load_model():
    """Lazily load the tokenizer and model on first use."""
    global tokenizer, model
    if tokenizer is None or model is None:
        try:
            print("正在加载tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=False  # use the slow tokenizer to reduce memory usage
            )
            print("正在加载模型...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,   # half precision
                device_map="cpu",            # force CPU execution
                low_cpu_mem_usage=True,      # enable low-memory loading
                trust_remote_code=True,
                load_in_8bit=False,          # no 8-bit quantization on CPU
                offload_folder="./offload",  # folder for weight offloading
            )
            # Make sure a pad token is defined
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            print("模型加载完成!")
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            return False
    return True
def convert_text_style(input_text):
    """Convert formal, technical Chinese text into a colloquial style."""
    if not input_text.strip():
        return "请输入要转换的文本"
    # Make sure the model is loaded
    if not load_model():
        return "模型加载失败,请稍后重试"
    try:
        # Prompt instructing the model to rewrite the formal input in a colloquial style
        prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
### 输入文本:
{input_text}
### 输出文本:
"""
        # Tokenize the prompt
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=1024,  # cap the input length
            truncation=True,
            padding=True
        )
        # Generate without tracking gradients to save memory
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=300,  # keep generations short
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                no_repeat_ngram_size=2
            )
        # Decode and keep only the text generated after the "### 输出文本:" marker
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "### 输出文本:" in full_response:
            response = full_response.split("### 输出文本:")[-1].strip()
        else:
            response = full_response[len(prompt):].strip()
        # Release temporary tensors
        del inputs, outputs
        torch.cuda.empty_cache()
        gc.collect()
        return response if response else "抱歉,未能生成有效回答"
    except Exception as e:
        return f"生成过程中出现错误: {str(e)}"
# Build the Gradio interface
iface = gr.Interface(
    fn=convert_text_style,
    inputs=gr.Textbox(
        label="输入文本",
        placeholder="请输入需要转换为口语化的书面文本...",
        lines=3
    ),
    outputs=gr.Textbox(
        label="输出文本",
        lines=3
    ),
    title="中文文本风格转换API",
    description="将书面化、技术性文本转换为自然、口语化表达",
    examples=[
        ["乙醇的检测方法包括酸碱度检查。"],
        ["本品为薄膜衣片,除去包衣后显橙红色。"]
    ],
    cache_examples=False,   # do not cache example outputs
    flagging_mode="never"   # replaces the deprecated allow_flagging argument
)
# Launch the app (enable_queue and max_threads are no longer accepted by launch())
if __name__ == "__main__":
    print("正在启动应用...")
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False
    )
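
Once the app is running, the interface can also be exercised programmatically. The snippet below is a minimal sketch, assuming the server is reachable at http://localhost:7860 and exposes Gradio's default /predict endpoint; swap in the deployed Space's URL as needed.

from gradio_client import Client

# Hypothetical endpoint; replace with the actual Space URL once deployed
client = Client("http://localhost:7860/")
result = client.predict(
    "乙醇的检测方法包括酸碱度检查。",  # input_text
    api_name="/predict"
)
print(result)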