Spaces:
Sleeping
Sleeping
| # ──────────────────────────────────────────────────────────────────────────────── | |
| # app.py (CPU-only 版:先加载 float32 基座 LLaMA-8B,再叠入 LoRA Adapter) | |
| # ──────────────────────────────────────────────────────────────────────────────── | |
| import gradio as gr | |
| import torch | |
| import gc | |
| import os | |
| from transformers import AutoTokenizer, LlamaForCausalLM | |
| from peft import PeftModel | |
| # ─────────────────────── 1. 释放可能的显存/内存 ─────────────────────── | |
| # 对于 CPU-only,可以留着,也不会报错 | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| # ─────────────────────── 2. 配置区域 ─────────────────────── | |
| # (A)Adapter 仓库 ID:LoRA 权重所在的 Hugging Face Repo | |
| # 这个仓库里只有 adapter_model.safetensors + adapter_config.json + tokenizer 文件 | |
| ADAPTER_REPO = "yxccai/text-style-converter" | |
| # (B)基座模型 ID(去掉了 -bnb-4bit 后缀,改用 float32 版) | |
| # 原 adapter_config.json 里提到的 "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit" | |
| # 在 CPU-only 环境下不能加载 4bit bitsandbytes,所以我们要改为: | |
| # "unsloth/deepseek-r1-distill-llama-8b" | |
| # 如果您本地没有这个仓库,可以换成“decapoda-research/llama-7b-hf”或其他您能在 CPU 上跑通的模型。 | |
| BASE_MODEL_ID = "unsloth/deepseek-r1-distill-llama-8b" | |
| # 全局变量:Tokenizer + Model | |
| tokenizer = None | |
| model = None | |
| # ─────────────────────── 3. 加载模型的函数 ─────────────────────── | |
| def load_model(): | |
| """ | |
| CPU-only 逻辑: | |
| 1. 先从 Adapter 仓库加载 Tokenizer(里面有 tokenizer.json 等文件)。 | |
| 2. 再用 LlamaForCausalLM 从 float32 版基座模型加载到 CPU。 | |
| 3. 然后用 PeftModel.from_pretrained(...) 将 LoRA Adapter 权重叠加到基座上。 | |
| """ | |
| global tokenizer, model | |
| # 如果 tokenizer/model 还未加载,则执行加载逻辑 | |
| if tokenizer is None or model is None: | |
| try: | |
| # ── 3.1 加载 Tokenizer ── | |
| print("正在加载 Tokenizer(来自 LoRA 仓库)…") | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| ADAPTER_REPO, | |
| trust_remote_code=True, | |
| use_fast=False, | |
| ) | |
| # 如果 pad_token 不存在,就用 eos_token 代替 | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # ── 3.2 加载基座模型(LLaMA float32 → CPU) ── | |
| print(f"正在加载基座模型:{BASE_MODEL_ID} (float32 → CPU)…") | |
| # 注意:这里用 torch_dtype=torch.float32, device_map="cpu"。如果 Model 太大、内存不足,会 OOM。 | |
| base_model = LlamaForCausalLM.from_pretrained( | |
| BASE_MODEL_ID, | |
| torch_dtype=torch.float32, | |
| device_map="cpu", | |
| low_cpu_mem_usage=True, # 尽量启用低内存占用模式 | |
| trust_remote_code=True, | |
| ) | |
| print("→ 基座模型加载完成。(注意检查是否被系统 OOM)") | |
| # ── 3.3 用 PeftModel 叠加 LoRA Adapter ── | |
| print(f"正在叠加 LoRA Adapter:{ADAPTER_REPO}…") | |
| model = PeftModel.from_pretrained( | |
| base_model, | |
| ADAPTER_REPO, | |
| device_map="cpu", # CPU-only 环境 | |
| torch_dtype=torch.float32, # 同样使用 float32 | |
| ) | |
| print("→ LoRA Adapter 已叠加成功。") | |
| # (可选)不想更新基座所有参数时,把 base_model 的参数都冻结: | |
| # model.eval() | |
| # for param in model.base_model.parameters(): | |
| # param.requires_grad = False | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| print(f"模型加载失败: {str(e)}") | |
| return False | |
| return True | |
| # ─────────────────────── 4. 文本生成函数 ─────────────────────── | |
| def convert_text_style(input_text: str) -> str: | |
| """ | |
| 输入一句书面化/技术性的中文,让模型把它转换成自然、口语化的表达方式。 | |
| """ | |
| if not input_text or input_text.strip() == "": | |
| return "请输入要转换的文本。" | |
| # 确保模型已加载 | |
| if not load_model(): | |
| return "模型加载失败,请稍后重试。" | |
| try: | |
| # 拼一个简单的 Prompt | |
| prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。 | |
| ### 输入文本: | |
| {input_text} | |
| ### 输出文本: | |
| """ | |
| # 分词 & 转 torch.Tensor | |
| inputs = tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| max_length=1024, | |
| truncation=True, | |
| padding=True, | |
| ) | |
| # 全部放到 CPU 上 | |
| inputs = {k: v.to("cpu") for k, v in inputs.items()} | |
| # 生成 | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| inputs["input_ids"], | |
| attention_mask=inputs["attention_mask"], | |
| max_new_tokens=256, | |
| temperature=0.7, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| no_repeat_ngram_size=2, | |
| num_return_sequences=1, | |
| ) | |
| # 解码并抽取结果 | |
| full_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| if "### 输出文本:" in full_text: | |
| return full_text.split("### 输出文本:")[-1].strip() | |
| return full_text[len(prompt) :].strip() | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return f"生成过程中出现错误: {str(e)}" | |
| # ─────────────────────── 5. Gradio 界面配置 ─────────────────────── | |
| iface = gr.Interface( | |
| fn=convert_text_style, | |
| inputs=gr.Textbox( | |
| label="输入文本", placeholder="请输入需要转换为口语化的书面文本...", lines=3 | |
| ), | |
| outputs=gr.Textbox(label="输出文本", lines=4), | |
| title="中文文本风格转换API", | |
| description="将书面化、技术性文本转换为自然、口语化表达", | |
| examples=[ | |
| ["乙醇的检测方法包括酸碱度检查。"], | |
| ["本品为薄膜衣片,除去包衣后显橙红色。"], | |
| ], | |
| cache_examples=False, | |
| flagging_mode="never", | |
| ) | |
| if __name__ == "__main__": | |
| print("启动 Gradio 应用…") | |
| # 纯 CPU 环境下,server_name 可以保持默认 "0.0.0.0",port 也是 7860 | |
| iface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False) | |