import os import torch import psutil import gc import time import json import warnings # 设置环境变量和警告过滤 os.environ["TRANSFORMERS_VERBOSITY"] = "error" from fastapi import FastAPI, File, UploadFile, Form, Request from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from PIL import Image import io import base64 import logging # 尝试导入 Qwen-VL 相关模块,如果失败则提供错误信息 try: from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, GenerationConfig from qwen_vl_utils import process_vision_info QWEN_VL_AVAILABLE = True IMPORT_ERROR = None except ImportError as e: QWEN_VL_AVAILABLE = False IMPORT_ERROR = str(e) # 创建占位符类以避免运行时错误 class Qwen2VLForConditionalGeneration: pass class AutoTokenizer: pass class AutoProcessor: pass class GenerationConfig: pass def process_vision_info(*args, **kwargs): raise ImportError("qwen_vl_utils not available") # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 数据模型 class AnalyzeRequest(BaseModel): image: str # base64 编码的图片 prompt: str = "请描述这张图片的内容" # 提示词/问题 class AnalyzeResponse(BaseModel): success: bool prompt: str response: str processing_time: float image_info: dict = None error: str = None app = FastAPI(title="Qwen-VL PicExam API", description="基于 Qwen2-VL-2B-Instruct 的图像理解 API") # 添加 CORS 中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], # 在生产环境中应该限制为特定域名 allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 挂载静态文件 app.mount("/static", StaticFiles(directory="static"), name="static") # 全局变量存储模型和处理器 model = None processor = None tokenizer = None def load_model(): """加载 Qwen2-VL-2B-Instruct 模型(CPU 版本,适合 16GB 内存)""" global model, processor, tokenizer # 检查依赖是否可用 if not QWEN_VL_AVAILABLE: logger.error(f"Qwen-VL 依赖不可用: {IMPORT_ERROR}") logger.error("请安装缺失的依赖: pip install torchvision qwen-vl-utils") return False try: logger.info("开始加载 Qwen2-VL-2B-Instruct 模型...") memory = psutil.virtual_memory() logger.info(f"模型加载前内存: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") model_name = "Qwen/Qwen2-VL-2B-Instruct" # 设置环境变量优化内存使用 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" os.environ["TOKENIZERS_PARALLELISM"] = "false" # 避免分词器并行导致的内存问题 # 加载处理器和分词器 logger.info("正在加载 Processor...") processor = AutoProcessor.from_pretrained( model_name, trust_remote_code=True ) memory = psutil.virtual_memory() logger.info(f"Processor 加载后内存: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") logger.info("正在加载 Tokenizer...") tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True ) memory = psutil.virtual_memory() logger.info(f"Tokenizer 加载后内存: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") # 加载模型到 CPU,使用内存优化配置 logger.info("正在加载核心模型... (此步骤最消耗内存)") model = Qwen2VLForConditionalGeneration.from_pretrained( model_name, torch_dtype=torch.float16, # 使用 float16 减少内存使用 device_map="cpu", # 强制使用 CPU low_cpu_mem_usage=True, # 低内存使用模式 trust_remote_code=True, # 额外的内存优化选项 use_cache=False, # 禁用 KV 缓存以节省内存 attn_implementation="eager", # 使用 eager attention 实现 ) memory = psutil.virtual_memory() logger.info(f"核心模型加载后内存: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") # 设置为评估模式 model.eval() # 清理不必要的内存 if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("模型加载成功!") logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") return True except Exception as e: logger.error(f"模型加载失败: {str(e)}", exc_info=True) # 记录完整的堆栈跟踪 memory = psutil.virtual_memory() logger.error(f"失败时内存状态: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") return False # 启动时加载模型 @app.on_event("startup") async def startup_event(): """应用启动时加载模型""" success = load_model() if not success: logger.error("模型加载失败,应用可能无法正常工作") @app.get("/") def api_documentation(): """API 文档和端点说明""" return { "service": "Qwen-VL PicExam API", "description": "基于 Qwen2-VL-2B-Instruct 的图像理解 API,支持 16GB 内存 + CPU 推理", "version": "1.0.0", "model": "Qwen2-VL-2B-Instruct", "status": { "service": "running", "model_loaded": model is not None, "inference_mode": "CPU", "dependencies_available": QWEN_VL_AVAILABLE, "import_error": IMPORT_ERROR if not QWEN_VL_AVAILABLE else None }, "endpoints": { "GET /": { "description": "获取 API 文档和端点信息", "response": "JSON 格式的 API 说明" }, "GET /health": { "description": "健康检查接口", "response": "服务状态信息" }, "GET /web": { "description": "Web 界面", "response": "HTML 页面,提供图形化操作界面" }, "POST /analyze_image": { "description": "分析上传的图片文件", "parameters": { "image": "图片文件 (multipart/form-data)", "question": "关于图片的问题 (可选,默认为描述图片内容)" }, "example": "curl -X POST '/analyze_image' -F 'image=@photo.jpg' -F 'question=这张图片中有什么?'" }, "POST /analyze_image_base64": { "description": "分析 base64 编码的图片", "parameters": { "image_base64": "base64 编码的图片数据", "question": "关于图片的问题 (可选)" }, "example": "curl -X POST '/analyze_image_base64' -F 'image_base64=...' -F 'question=描述这张图片'" }, "POST /analyze": { "description": "简化的图片分析接口 (JSON 格式)", "parameters": { "image": "base64 编码的图片数据", "prompt": "提示词/问题" }, "example": "curl -X POST '/analyze' -H 'Content-Type: application/json' -d '{\"image\":\"data:image/jpeg;base64,...\",\"prompt\":\"描述图片\"}'" }, "GET /memory_status": { "description": "获取内存使用状态", "response": "系统内存和模型内存使用情况" }, "POST /clear_cache": { "description": "清理内存缓存", "response": "缓存清理结果" } }, "usage_examples": { "curl_file_upload": "curl -X POST 'http://localhost:7860/analyze_image' -F 'image=@your_image.jpg' -F 'question=请描述这张图片'", "curl_base64": "curl -X POST 'http://localhost:7860/analyze_image_base64' -F 'image_base64=...' -F 'question=这张图片中有什么?'", "curl_json": "curl -X POST 'http://localhost:7860/analyze' -H 'Content-Type: application/json' -d '{\"image\":\"...\",\"prompt\":\"请详细描述这张图片的内容\"}'" }, "supported_formats": ["JPEG", "PNG", "WebP", "BMP", "GIF"], "memory_requirements": "16GB RAM recommended for optimal performance", "inference_time": "5-15 seconds per image (depends on CPU)", "documentation": "Visit /docs for interactive API documentation" } @app.get("/health") def health_check(): """简单的健康检查接口""" return { "status": "healthy", "service": "Qwen-VL PicExam API", "model_loaded": model is not None, "dependencies_available": QWEN_VL_AVAILABLE, "timestamp": time.time() } @app.get("/dependencies") def check_dependencies(): """检查依赖状态""" # 基础依赖检查 basic_dependencies = { "fastapi": False, "uvicorn": False, "multipart": False, "PIL": False, "pydantic": False } # Qwen-VL 依赖检查 qwen_dependencies = { "torch": False, "torchvision": False, "transformers": False, "qwen_vl_utils": False } # 检查基础依赖 try: import fastapi basic_dependencies["fastapi"] = True except ImportError: pass try: import uvicorn basic_dependencies["uvicorn"] = True except ImportError: pass try: import multipart basic_dependencies["multipart"] = True except ImportError: pass try: from PIL import Image basic_dependencies["PIL"] = True except ImportError: pass try: import pydantic basic_dependencies["pydantic"] = True except ImportError: pass # 检查 Qwen-VL 依赖 try: import torch qwen_dependencies["torch"] = True except ImportError: pass try: import torchvision qwen_dependencies["torchvision"] = True except ImportError: pass try: import transformers qwen_dependencies["transformers"] = True except ImportError: pass try: import qwen_vl_utils qwen_dependencies["qwen_vl_utils"] = True except ImportError: pass # 计算缺失的依赖 missing_basic = [dep for dep, available in basic_dependencies.items() if not available] missing_qwen = [dep for dep, available in qwen_dependencies.items() if not available] return { "basic_dependencies": basic_dependencies, "qwen_dependencies": qwen_dependencies, "basic_ready": len(missing_basic) == 0, "qwen_ready": len(missing_qwen) == 0, "missing_basic": missing_basic, "missing_qwen": missing_qwen, "installation_commands": { "basic": f"pip install {' '.join(missing_basic)}" if missing_basic else "All basic dependencies installed", "qwen": f"pip install {' '.join(missing_qwen)}" if missing_qwen else "All Qwen-VL dependencies installed" }, "qwen_vl_available": QWEN_VL_AVAILABLE, "import_error": IMPORT_ERROR if not QWEN_VL_AVAILABLE else None, "note": "基础依赖用于 API 运行,Qwen-VL 依赖用于实际的图像分析功能" } @app.post("/analyze_image") async def analyze_image( image: UploadFile = File(...), question: str = Form("请描述这张图片的内容") ): """ 分析上传的图片并回答问题 Args: image: 上传的图片文件 question: 关于图片的问题(默认为描述图片内容) Returns: JSON 响应包含分析结果 """ if not QWEN_VL_AVAILABLE: return JSONResponse( status_code=503, content={ "error": "Qwen-VL 依赖不可用", "details": IMPORT_ERROR, "solution": "请安装缺失的依赖: pip install torchvision qwen-vl-utils av opencv-python-headless" } ) if model is None or processor is None: return JSONResponse( status_code=503, content={"error": "模型未加载,请稍后重试"} ) try: # 读取图片 image_bytes = await image.read() pil_image = Image.open(io.BytesIO(image_bytes)) # 确保图片是 RGB 格式 if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # 准备消息格式 messages = [ { "role": "user", "content": [ { "type": "image", "image": pil_image, }, {"type": "text", "text": question}, ], } ] # 处理输入 text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) # 生成回答 with torch.no_grad(): # 使用 GenerationConfig 来避免警告并确保参数正确 generation_config = GenerationConfig( max_new_tokens=512, do_sample=False, pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True ) generated_ids = model.generate(**inputs, generation_config=generation_config) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] return { "success": True, "question": question, "answer": output_text, "image_info": { "filename": image.filename, "size": f"{pil_image.size[0]}x{pil_image.size[1]}", "mode": pil_image.mode } } except Exception as e: logger.error(f"图片分析失败: {str(e)}") return JSONResponse( status_code=500, content={"error": f"图片分析失败: {str(e)}"} ) @app.post("/analyze_image_base64") async def analyze_image_base64( image_base64: str = Form(...), question: str = Form("请描述这张图片的内容") ): """ 分析 base64 编码的图片并回答问题 Args: image_base64: base64 编码的图片数据 question: 关于图片的问题 Returns: JSON 响应包含分析结果 """ if not QWEN_VL_AVAILABLE: return JSONResponse( status_code=503, content={ "error": "Qwen-VL 依赖不可用", "details": IMPORT_ERROR, "solution": "请安装缺失的依赖: pip install torchvision qwen-vl-utils av opencv-python-headless" } ) if model is None or processor is None: return JSONResponse( status_code=503, content={"error": "模型未加载,请稍后重试"} ) try: # 解码 base64 图片 if image_base64.startswith('data:image'): # 移除 data:image/xxx;base64, 前缀 image_base64 = image_base64.split(',')[1] image_bytes = base64.b64decode(image_base64) pil_image = Image.open(io.BytesIO(image_bytes)) # 确保图片是 RGB 格式 if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # 准备消息格式 messages = [ { "role": "user", "content": [ { "type": "image", "image": pil_image, }, {"type": "text", "text": question}, ], } ] # 处理输入 text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) # 生成回答 with torch.no_grad(): # 使用 GenerationConfig 来避免警告并确保参数正确 generation_config = GenerationConfig( max_new_tokens=512, do_sample=False, pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True ) generated_ids = model.generate(**inputs, generation_config=generation_config) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] return { "success": True, "question": question, "answer": output_text, "image_info": { "size": f"{pil_image.size[0]}x{pil_image.size[1]}", "mode": pil_image.mode } } except Exception as e: logger.error(f"图片分析失败: {str(e)}") return JSONResponse( status_code=500, content={"error": f"图片分析失败: {str(e)}"} ) @app.post("/analyze", response_model=AnalyzeResponse) async def analyze_simple(request: AnalyzeRequest): """ 简化的图片分析接口 (JSON 格式) 接收 JSON 格式的请求,包含 base64 图片和提示词 返回标准化的分析结果 """ logger.info(f"收到图片分析请求: prompt='{request.prompt}', image_length={len(request.image) if request.image else 0}") if not QWEN_VL_AVAILABLE: logger.error(f"Qwen-VL 依赖不可用: {IMPORT_ERROR}") return AnalyzeResponse( success=False, prompt=request.prompt, response="", processing_time=0, error=f"Qwen-VL 依赖不可用: {IMPORT_ERROR}. 请安装: pip install torchvision qwen-vl-utils av opencv-python-headless" ) if model is None or processor is None: logger.error("模型未加载") return AnalyzeResponse( success=False, prompt=request.prompt, response="", processing_time=0, error="模型未加载,请稍后重试" ) start_time = time.time() logger.info("开始处理图片分析请求...") try: # 检查内存状态 memory = psutil.virtual_memory() logger.info(f"当前内存使用: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB") if memory.percent > 90: logger.warning("内存使用率过高,可能影响生成性能") # 处理 base64 图片 logger.info("开始处理 base64 图片数据...") image_data = request.image if image_data.startswith('data:image'): # 移除 data:image/xxx;base64, 前缀 image_data = image_data.split(',')[1] logger.info("移除了 data URL 前缀") logger.info(f"解码 base64 数据,长度: {len(image_data)}") image_bytes = base64.b64decode(image_data) logger.info(f"解码后字节数: {len(image_bytes)}") pil_image = Image.open(io.BytesIO(image_bytes)) logger.info(f"图片加载成功: {pil_image.size}, 模式: {pil_image.mode}") # 如果图片太大,进行缩放 max_size = 800 if max(pil_image.size) > max_size: ratio = max_size / max(pil_image.size) new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio)) pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) logger.info(f"图片已缩放至: {pil_image.size}") # 确保图片是 RGB 格式 if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') logger.info("图片已转换为 RGB 模式") # 准备消息格式 logger.info("准备模型输入消息...") messages = [ { "role": "user", "content": [ { "type": "image", "image": pil_image, }, {"type": "text", "text": request.prompt}, ], } ] # 处理输入 logger.info("应用聊天模板...") text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) logger.info("处理视觉信息...") image_inputs, video_inputs = process_vision_info(messages) logger.info("处理器编码输入...") inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) logger.info(f"输入处理完成,input_ids shape: {inputs.input_ids.shape if hasattr(inputs, 'input_ids') else 'N/A'}") # 生成回答 logger.info("开始模型生成...") with torch.no_grad(): # 使用最简单的生成参数,避免复杂配置 logger.info("使用简化的生成参数...") generated_ids = model.generate( **inputs, max_new_tokens=256, # 减少生成长度 do_sample=False, pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id ) logger.info(f"生成完成,输出 shape: {generated_ids.shape}") logger.info("开始解码生成的文本...") generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] logger.info(f"修剪后的 token 数量: {[len(ids) for ids in generated_ids_trimmed]}") output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] # 清理内存 del generated_ids, generated_ids_trimmed, inputs if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() processing_time = time.time() - start_time logger.info(f"分析完成,处理时间: {processing_time:.2f}秒") logger.info(f"生成的文本长度: {len(output_text)} 字符") logger.info(f"生成的文本预览: {output_text[:100]}...") # 检查最终内存状态 final_memory = psutil.virtual_memory() logger.info(f"分析后内存使用: {final_memory.percent:.1f}%") return AnalyzeResponse( success=True, prompt=request.prompt, response=output_text, processing_time=processing_time, image_info={ "size": f"{pil_image.size[0]}x{pil_image.size[1]}", "mode": pil_image.mode, "format": pil_image.format or "Unknown" } ) except Exception as e: processing_time = time.time() - start_time logger.error(f"图片分析失败: {str(e)}") return AnalyzeResponse( success=False, prompt=request.prompt, response="", processing_time=processing_time, error=f"图片分析失败: {str(e)}" ) @app.get("/memory_status") def get_memory_status(): """获取当前内存使用状态""" try: # 系统内存信息 memory = psutil.virtual_memory() # PyTorch 内存信息(如果使用 CUDA) torch_memory = {} if torch.cuda.is_available(): torch_memory = { "cuda_allocated": torch.cuda.memory_allocated() / 1024**3, # GB "cuda_reserved": torch.cuda.memory_reserved() / 1024**3, # GB "cuda_max_allocated": torch.cuda.max_memory_allocated() / 1024**3, # GB } # 进程内存信息 process = psutil.Process() process_memory = process.memory_info() return { "system_memory": { "total_gb": round(memory.total / 1024**3, 2), "available_gb": round(memory.available / 1024**3, 2), "used_gb": round(memory.used / 1024**3, 2), "free_gb": round(memory.free / 1024**3, 2), "percent": round(memory.percent, 1), "buffers_gb": round(getattr(memory, 'buffers', 0) / 1024**3, 2), "cached_gb": round(getattr(memory, 'cached', 0) / 1024**3, 2) }, "process_memory": { "rss_gb": round(process_memory.rss / 1024**3, 2), # 实际物理内存 "vms_gb": round(process_memory.vms / 1024**3, 2), # 虚拟内存 "percent": round(process.memory_percent(), 2) }, "torch_memory": torch_memory, "model_status": { "model_loaded": model is not None, "processor_loaded": processor is not None, "qwen_vl_available": QWEN_VL_AVAILABLE }, "memory_analysis": { "total_physical_memory": f"{memory.total / 1024**3:.2f} GB", "model_estimated_usage": "~8-10 GB (Qwen2-VL-2B)", "available_for_inference": f"{memory.available / 1024**3:.2f} GB", "memory_pressure": "High" if memory.percent > 90 else "Medium" if memory.percent > 75 else "Low", "huggingface_spaces_limit": "可能受到容器内存限制" }, "recommendations": { "memory_usage_ok": memory.percent < 85, "available_for_inference": memory.available / 1024**3 > 2.0, "should_clear_cache": memory.percent > 90, "can_process_images": memory.available > 2 * 1024**3, "suggested_actions": [ "清理缓存 (/clear_cache)" if memory.percent > 85 else None, "重启服务" if memory.percent > 95 else None, "减小图片尺寸" if memory.percent > 80 else None ] } } except Exception as e: return JSONResponse( status_code=500, content={"error": f"获取内存状态失败: {str(e)}"} ) @app.post("/clear_cache") def clear_cache(): """强制清理内存缓存""" try: logger.info("开始强制清理内存缓存...") # 获取清理前的内存状态 memory_before = psutil.virtual_memory() logger.info(f"清理前内存使用: {memory_before.percent:.1f}%") # 多次垃圾回收 total_collected = 0 for i in range(3): collected = gc.collect() total_collected += collected logger.info(f"垃圾回收第{i+1}次: 清理了 {collected} 个对象") # 清理 PyTorch 缓存 if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("已清理 CUDA 缓存") # 获取清理后的内存状态 memory_after = psutil.virtual_memory() logger.info(f"清理后内存使用: {memory_after.percent:.1f}%") freed_gb = (memory_after.available - memory_before.available) / 1024**3 return { "success": True, "message": "强制缓存清理完成", "details": { "objects_collected": total_collected, "memory_before_percent": round(memory_before.percent, 1), "memory_after_percent": round(memory_after.percent, 1), "freed_memory_gb": round(freed_gb, 2), "improvement": f"释放了 {freed_gb:.2f} GB 内存" if freed_gb > 0 else "内存使用无明显变化" } } except Exception as e: logger.error(f"缓存清理失败: {str(e)}") return JSONResponse( status_code=500, content={"error": f"缓存清理失败: {str(e)}"} ) @app.get("/web") def web_interface(): """返回 Web 界面""" return FileResponse("static/index.html") @app.post("/analyze-simple") async def analyze_simple_test(request: AnalyzeRequest): """ 简化的测试分析接口 - 用于调试 """ logger.info(f"收到简化测试请求: prompt='{request.prompt}', image_length={len(request.image) if request.image else 0}") if not QWEN_VL_AVAILABLE: logger.error("Qwen-VL 不可用") return AnalyzeResponse( success=False, prompt=request.prompt, response="", processing_time=0, error="Qwen-VL 依赖不可用" ) if model is None or processor is None: logger.error("模型未加载") return AnalyzeResponse( success=False, prompt=request.prompt, response="", processing_time=0, error="模型未加载" ) # 返回模拟结果用于测试 return AnalyzeResponse( success=True, prompt=request.prompt, response="这是一个测试响应。模型已加载并可以接收请求,但为了避免卡住,这里返回模拟结果。", processing_time=1.0, image_info={ "size": "测试", "mode": "RGB", "format": "PNG" } )