|
|
import base64
import gc
import io
import json
import logging
import os
import time
import warnings
from typing import Optional

import psutil
import torch

# Silence transformers' own logging; this must run before transformers is imported.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
|
|
|
|
|
|
|
|
try:
    # Optional heavy dependencies: the API still starts without them, and the
    # analysis endpoints then answer with 503 / success=False instead.
    from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, GenerationConfig
    from qwen_vl_utils import process_vision_info
except (ImportError, OSError, RuntimeError) as e:
    # ImportError: a package is missing. OSError / RuntimeError: a package is
    # installed but fails to load (for example a native library that cannot be
    # loaded, or a torch/torchvision version mismatch). Degrade the same way
    # instead of crashing the whole service at import time.
    QWEN_VL_AVAILABLE = False
    IMPORT_ERROR = str(e)

    # Placeholders so module-level references still resolve. The visible
    # callers check QWEN_VL_AVAILABLE before touching any of these.
    class Qwen2VLForConditionalGeneration:
        pass

    class AutoTokenizer:
        pass

    class AutoProcessor:
        pass

    class GenerationConfig:
        pass

    def process_vision_info(*args, **kwargs):
        """Stand-in for qwen_vl_utils.process_vision_info; always raises."""
        raise ImportError("qwen_vl_utils not available")
else:
    QWEN_VL_AVAILABLE = True
    IMPORT_ERROR = None
|
|
|
|
|
|
|
|
# Root logging at INFO (a no-op if the root logger already has handlers),
# plus the module logger used throughout this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class AnalyzeRequest(BaseModel):
    """JSON request body for the /analyze and /analyze-simple endpoints."""

    # Base64-encoded image data
    image: str
    # Question / instruction sent to the model along with the image
    prompt: str = "请描述这张图片的内容"


class AnalyzeResponse(BaseModel):
    """Standardized analysis result returned by /analyze and /analyze-simple."""

    success: bool
    prompt: str
    response: str
    # Seconds spent handling the request
    processing_time: float
    # These must be Optional, not just default to None. FastAPI dumps the
    # returned model and validates it against response_model. Under Pydantic v2,
    # `str = None` accepts the default but rejects an explicit None, so a
    # successful /analyze response (error unset -> None) would fail validation.
    image_info: Optional[dict] = None
    error: Optional[str] = None
|
|
|
|
|
app = FastAPI(title="Qwen-VL PicExam API", description="基于 Qwen2-VL-2B-Instruct 的图像理解 API")

# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; revisit if cookies or auth are ever added to this service.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Serve files from ./static (the /web page returns static/index.html).
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
|
# Shared inference objects, filled in by load_model() at startup.
# They stay None until loading succeeds; the endpoints check for that.
model = processor = tokenizer = None
|
|
|
|
|
def load_model():
    """Load Qwen2-VL-2B-Instruct for CPU inference into the module globals.

    Populates ``processor``, ``tokenizer`` and ``model`` in that order and logs
    system memory after each stage (the weights are the largest allocation).

    Returns:
        bool: True when every component loaded. False when the Qwen-VL
        dependencies are unavailable or any loading step raised; the error is
        logged with a traceback.
    """
    global model, processor, tokenizer

    def _memory_line(label):
        """Format current system memory usage under the given log label."""
        snapshot = psutil.virtual_memory()
        return f"{label}: {snapshot.percent:.1f}%, 可用: {snapshot.available / 1024**3:.2f}GB"

    if not QWEN_VL_AVAILABLE:
        logger.error(f"Qwen-VL 依赖不可用: {IMPORT_ERROR}")
        logger.error("请安装缺失的依赖: pip install torchvision qwen-vl-utils")
        return False

    try:
        logger.info("开始加载 Qwen2-VL-2B-Instruct 模型...")
        logger.info(_memory_line("模型加载前内存"))

        repo_id = "Qwen/Qwen2-VL-2B-Instruct"

        # The allocator setting only affects CUDA; tokenizer parallelism is
        # disabled to avoid the tokenizers fork warning / extra threads.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        logger.info("正在加载 Processor...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        logger.info(_memory_line("Processor 加载后内存"))

        logger.info("正在加载 Tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        logger.info(_memory_line("Tokenizer 加载后内存"))

        logger.info("正在加载核心模型... (此步骤最消耗内存)")
        # NOTE(review): float16 weights on CPU halve memory, but some PyTorch
        # builds run Half kernels slowly on CPU or lack them; confirm against
        # the deployed torch version.
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            use_cache=False,
            attn_implementation="eager",
        )
        logger.info(_memory_line("核心模型加载后内存"))

        model.eval()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("模型加载成功!")
        logger.info(f"模型参数数量: {sum(param.numel() for param in model.parameters()) / 1e6:.1f}M")
        return True

    except Exception as exc:
        logger.error(f"模型加载失败: {str(exc)}", exc_info=True)
        logger.error(_memory_line("失败时内存状态"))
        return False
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts.

    A failed load is only logged. The service keeps running, and the analysis
    endpoints report that the model is not loaded.
    """
    if load_model():
        return
    logger.error("模型加载失败,应用可能无法正常工作")
|
|
|
|
|
@app.get("/")
def api_documentation():
    """Return a JSON overview of the service: status, endpoints and usage examples.

    Every route registered in this module is listed, including the dependency
    check (/dependencies) and the debug-only mock endpoint (/analyze-simple).
    """
    return {
        "service": "Qwen-VL PicExam API",
        "description": "基于 Qwen2-VL-2B-Instruct 的图像理解 API,支持 16GB 内存 + CPU 推理",
        "version": "1.0.0",
        "model": "Qwen2-VL-2B-Instruct",
        "status": {
            "service": "running",
            "model_loaded": model is not None,
            "inference_mode": "CPU",
            "dependencies_available": QWEN_VL_AVAILABLE,
            "import_error": IMPORT_ERROR if not QWEN_VL_AVAILABLE else None
        },
        "endpoints": {
            "GET /": {
                "description": "获取 API 文档和端点信息",
                "response": "JSON 格式的 API 说明"
            },
            "GET /health": {
                "description": "健康检查接口",
                "response": "服务状态信息"
            },
            "GET /dependencies": {
                "description": "检查依赖安装状态",
                "response": "基础依赖与 Qwen-VL 依赖的可用性及安装命令"
            },
            "GET /web": {
                "description": "Web 界面",
                "response": "HTML 页面,提供图形化操作界面"
            },
            "POST /analyze_image": {
                "description": "分析上传的图片文件",
                "parameters": {
                    "image": "图片文件 (multipart/form-data)",
                    "question": "关于图片的问题 (可选,默认为描述图片内容)"
                },
                "example": "curl -X POST '/analyze_image' -F 'image=@photo.jpg' -F 'question=这张图片中有什么?'"
            },
            "POST /analyze_image_base64": {
                "description": "分析 base64 编码的图片",
                "parameters": {
                    "image_base64": "base64 编码的图片数据",
                    "question": "关于图片的问题 (可选)"
                },
                "example": "curl -X POST '/analyze_image_base64' -F 'image_base64=...' -F 'question=描述这张图片'"
            },
            "POST /analyze": {
                "description": "简化的图片分析接口 (JSON 格式)",
                "parameters": {
                    "image": "base64 编码的图片数据",
                    "prompt": "提示词/问题"
                },
                "example": "curl -X POST '/analyze' -H 'Content-Type: application/json' -d '{\"image\":\"data:image/jpeg;base64,...\",\"prompt\":\"描述图片\"}'"
            },
            "POST /analyze-simple": {
                "description": "调试用接口,不调用模型,返回模拟结果",
                "parameters": {
                    "image": "base64 编码的图片数据",
                    "prompt": "提示词/问题"
                }
            },
            "GET /memory_status": {
                "description": "获取内存使用状态",
                "response": "系统内存和模型内存使用情况"
            },
            "POST /clear_cache": {
                "description": "清理内存缓存",
                "response": "缓存清理结果"
            }
        },
        "usage_examples": {
            "curl_file_upload": "curl -X POST 'http://localhost:7860/analyze_image' -F 'image=@your_image.jpg' -F 'question=请描述这张图片'",
            "curl_base64": "curl -X POST 'http://localhost:7860/analyze_image_base64' -F 'image_base64=...' -F 'question=这张图片中有什么?'",
            "curl_json": "curl -X POST 'http://localhost:7860/analyze' -H 'Content-Type: application/json' -d '{\"image\":\"...\",\"prompt\":\"请详细描述这张图片的内容\"}'"
        },
        "supported_formats": ["JPEG", "PNG", "WebP", "BMP", "GIF"],
        "memory_requirements": "16GB RAM recommended for optimal performance",
        "inference_time": "5-15 seconds per image (depends on CPU)",
        "documentation": "Visit /docs for interactive API documentation"
    }
|
|
|
|
|
@app.get("/health")
def health_check():
    """Lightweight liveness probe.

    Always reports "healthy". It also says whether the model and the Qwen-VL
    dependencies are available, and gives the current Unix timestamp.
    """
    payload = {"status": "healthy", "service": "Qwen-VL PicExam API"}
    payload["model_loaded"] = model is not None
    payload["dependencies_available"] = QWEN_VL_AVAILABLE
    payload["timestamp"] = time.time()
    return payload
|
|
|
|
|
@app.get("/dependencies")
def check_dependencies():
    """Report which basic and Qwen-VL dependencies can currently be imported.

    Each dependency is probed with a real import. A package that is installed
    but fails to load is therefore reported as missing.

    Returns:
        dict: per-dependency availability maps, readiness flags, the missing
        dependency keys, and ``pip install`` commands built from the PyPI
        distribution names (e.g. ``Pillow`` for ``PIL``).
    """
    import importlib

    # Maps the key reported to clients to (module to import, PyPI distribution
    # name). The PyPI names differ from the import names for PIL, multipart and
    # qwen_vl_utils. Putting the import name into "pip install" either fails
    # (PIL) or installs an unrelated project (multipart).
    basic_specs = {
        "fastapi": ("fastapi", "fastapi"),
        "uvicorn": ("uvicorn", "uvicorn"),
        "multipart": ("multipart", "python-multipart"),
        "PIL": ("PIL.Image", "Pillow"),
        "pydantic": ("pydantic", "pydantic"),
    }
    qwen_specs = {
        "torch": ("torch", "torch"),
        "torchvision": ("torchvision", "torchvision"),
        "transformers": ("transformers", "transformers"),
        "qwen_vl_utils": ("qwen_vl_utils", "qwen-vl-utils"),
    }

    def _probe(specs):
        """Return {key: importable?} for every entry of *specs*."""
        status = {}
        for key, (module_name, _pip_name) in specs.items():
            try:
                importlib.import_module(module_name)
            except (ImportError, OSError, RuntimeError):
                # OSError / RuntimeError: installed but broken (e.g. a native
                # library that fails to load); report rather than return a 500.
                status[key] = False
            else:
                status[key] = True
        return status

    def _install_command(missing, specs, all_installed_message):
        """Build a pip command for the *missing* keys, or the all-clear message."""
        if not missing:
            return all_installed_message
        return f"pip install {' '.join(specs[dep][1] for dep in missing)}"

    basic_dependencies = _probe(basic_specs)
    qwen_dependencies = _probe(qwen_specs)

    missing_basic = [dep for dep, available in basic_dependencies.items() if not available]
    missing_qwen = [dep for dep, available in qwen_dependencies.items() if not available]

    return {
        "basic_dependencies": basic_dependencies,
        "qwen_dependencies": qwen_dependencies,
        "basic_ready": len(missing_basic) == 0,
        "qwen_ready": len(missing_qwen) == 0,
        "missing_basic": missing_basic,
        "missing_qwen": missing_qwen,
        "installation_commands": {
            "basic": _install_command(missing_basic, basic_specs, "All basic dependencies installed"),
            "qwen": _install_command(missing_qwen, qwen_specs, "All Qwen-VL dependencies installed")
        },
        "qwen_vl_available": QWEN_VL_AVAILABLE,
        "import_error": IMPORT_ERROR if not QWEN_VL_AVAILABLE else None,
        "note": "基础依赖用于 API 运行,Qwen-VL 依赖用于实际的图像分析功能"
    }
|
|
|
|
|
@app.post("/analyze_image")
async def analyze_image(
    image: UploadFile = File(...),
    question: str = Form("请描述这张图片的内容")
):
    """Answer a question about an uploaded image file.

    Args:
        image: Uploaded image file (multipart/form-data).
        question: Question about the image; defaults to asking for a description.

    Returns:
        On success, a dict with ``success``, ``question``, ``answer`` and
        ``image_info`` (filename, ``WxH`` size, mode after RGB conversion).
        Otherwise a JSONResponse with one of these statuses:
        - 503: the dependencies or the model are unavailable.
        - 400: the upload cannot be decoded as an image.
        - 500: inference failed.
    """
    if not QWEN_VL_AVAILABLE:
        return JSONResponse(
            status_code=503,
            content={
                "error": "Qwen-VL 依赖不可用",
                "details": IMPORT_ERROR,
                "solution": "请安装缺失的依赖: pip install torchvision qwen-vl-utils av opencv-python-headless"
            }
        )

    if model is None or processor is None:
        return JSONResponse(
            status_code=503,
            content={"error": "模型未加载,请稍后重试"}
        )

    try:
        image_bytes = await image.read()

        # Decode eagerly so corrupt or non-image uploads become a client error
        # (400) instead of surfacing later inside the processor as a 500.
        try:
            pil_image = Image.open(io.BytesIO(image_bytes))
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            else:
                # Image.open is lazy; force the pixel data to be decoded now.
                pil_image.load()
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            logger.warning(f"无法解析上传的图片: {str(e)}")
            return JSONResponse(
                status_code=400,
                content={"error": f"无法解析图片: {str(e)}"}
            )

        # Single-turn chat message carrying the image followed by the question
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_image},
                    {"type": "text", "text": question},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # NOTE(review): generation runs synchronously inside this async handler,
        # so the event loop is blocked (other requests wait) until it finishes.
        with torch.no_grad():
            generation_config = GenerationConfig(
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True
            )
            generated_ids = model.generate(**inputs, generation_config=generation_config)

            # generate() returns prompt + completion; keep only the new tokens
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

        return {
            "success": True,
            "question": question,
            "answer": output_text,
            "image_info": {
                "filename": image.filename,
                "size": f"{pil_image.size[0]}x{pil_image.size[1]}",
                "mode": pil_image.mode
            }
        }

    except Exception as e:
        # logger.exception keeps the traceback for diagnosing inference failures
        logger.exception(f"图片分析失败: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"图片分析失败: {str(e)}"}
        )
|
|
|
|
|
@app.post("/analyze_image_base64")
async def analyze_image_base64(
    image_base64: str = Form(...),
    question: str = Form("请描述这张图片的内容")
):
    """Answer a question about a base64-encoded image sent as a form field.

    Args:
        image_base64: Base64 image data, optionally prefixed with a
            ``data:image/...;base64,`` data-URL header.
        question: Question about the image.

    Returns:
        On success, a dict with ``success``, ``question``, ``answer`` and
        ``image_info`` (``WxH`` size, mode after RGB conversion). Otherwise a
        JSONResponse with one of these statuses:
        - 503: the dependencies or the model are unavailable.
        - 400: the payload is not valid base64 or not a decodable image.
        - 500: inference failed.
    """
    if not QWEN_VL_AVAILABLE:
        return JSONResponse(
            status_code=503,
            content={
                "error": "Qwen-VL 依赖不可用",
                "details": IMPORT_ERROR,
                "solution": "请安装缺失的依赖: pip install torchvision qwen-vl-utils av opencv-python-headless"
            }
        )

    if model is None or processor is None:
        return JSONResponse(
            status_code=503,
            content={"error": "模型未加载,请稍后重试"}
        )

    try:
        # Bad base64 (binascii.Error is a ValueError) or undecodable image bytes
        # are client errors, reported as 400 rather than 500.
        try:
            payload = image_base64
            if payload.startswith('data:image'):
                # Keep everything after the first comma of "data:image/<fmt>;base64,<data>".
                # partition() never raises, unlike split(',')[1] on a comma-less value.
                payload = payload.partition(',')[2]

            image_bytes = base64.b64decode(payload)
            pil_image = Image.open(io.BytesIO(image_bytes))
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            else:
                # Image.open is lazy; force the pixel data to be decoded now.
                pil_image.load()
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            logger.warning(f"无法解析 base64 图片: {str(e)}")
            return JSONResponse(
                status_code=400,
                content={"error": f"无法解析图片: {str(e)}"}
            )

        # Single-turn chat message carrying the image followed by the question
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_image},
                    {"type": "text", "text": question},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # NOTE(review): generation runs synchronously inside this async handler,
        # so the event loop is blocked (other requests wait) until it finishes.
        with torch.no_grad():
            generation_config = GenerationConfig(
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True
            )
            generated_ids = model.generate(**inputs, generation_config=generation_config)

            # generate() returns prompt + completion; keep only the new tokens
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

        return {
            "success": True,
            "question": question,
            "answer": output_text,
            "image_info": {
                "size": f"{pil_image.size[0]}x{pil_image.size[1]}",
                "mode": pil_image.mode
            }
        }

    except Exception as e:
        # logger.exception keeps the traceback for diagnosing inference failures
        logger.exception(f"图片分析失败: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"图片分析失败: {str(e)}"}
        )
|
|
|
|
|
@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze_simple(request: AnalyzeRequest):
    """Simplified JSON image-analysis endpoint.

    Accepts a base64 image, optionally prefixed with a
    ``data:image/...;base64,`` header, plus a prompt. Every outcome is returned
    as an ``AnalyzeResponse``; failures set ``success=False`` and ``error``
    instead of using an HTTP error status.

    The image is converted to RGB. It is then downscaled so its longest side is
    at most 800 px before inference, which bounds memory use on CPU.
    """
    logger.info(f"收到图片分析请求: prompt='{request.prompt}', image_length={len(request.image) if request.image else 0}")

    if not QWEN_VL_AVAILABLE:
        logger.error(f"Qwen-VL 依赖不可用: {IMPORT_ERROR}")
        return AnalyzeResponse(
            success=False,
            prompt=request.prompt,
            response="",
            processing_time=0,
            error=f"Qwen-VL 依赖不可用: {IMPORT_ERROR}. 请安装: pip install torchvision qwen-vl-utils av opencv-python-headless"
        )

    if model is None or processor is None:
        logger.error("模型未加载")
        return AnalyzeResponse(
            success=False,
            prompt=request.prompt,
            response="",
            processing_time=0,
            error="模型未加载,请稍后重试"
        )

    start_time = time.time()
    logger.info("开始处理图片分析请求...")

    try:
        memory = psutil.virtual_memory()
        logger.info(f"当前内存使用: {memory.percent:.1f}%, 可用: {memory.available / 1024**3:.2f}GB")
        if memory.percent > 90:
            logger.warning("内存使用率过高,可能影响生成性能")

        logger.info("开始处理 base64 图片数据...")
        image_data = request.image
        if image_data.startswith('data:image'):
            # Keep everything after the first comma of "data:image/<fmt>;base64,<data>".
            # partition() never raises, unlike split(',')[1] on a comma-less value.
            image_data = image_data.partition(',')[2]
            logger.info("移除了 data URL 前缀")

        logger.info(f"解码 base64 数据,长度: {len(image_data)}")
        image_bytes = base64.b64decode(image_data)
        logger.info(f"解码后字节数: {len(image_bytes)}")

        pil_image = Image.open(io.BytesIO(image_bytes))
        # convert() and resize() return new images whose .format is None, so
        # capture the source format before either runs.
        source_format = pil_image.format
        logger.info(f"图片加载成功: {pil_image.size}, 模式: {pil_image.mode}")

        # Convert before resizing. Pillow always resizes "1"/"P" mode images
        # with NEAREST, which would silently ignore the LANCZOS filter below.
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
            logger.info("图片已转换为 RGB 模式")

        max_size = 800
        longest_side = max(pil_image.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            # max(1, ...) keeps both sides non-zero for extreme aspect ratios
            new_size = (
                max(1, int(pil_image.size[0] * ratio)),
                max(1, int(pil_image.size[1] * ratio)),
            )
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
            logger.info(f"图片已缩放至: {pil_image.size}")

        logger.info("准备模型输入消息...")
        # Single-turn chat message carrying the image followed by the prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_image},
                    {"type": "text", "text": request.prompt},
                ],
            }
        ]

        logger.info("应用聊天模板...")
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        logger.info("处理视觉信息...")
        image_inputs, video_inputs = process_vision_info(messages)
        logger.info("处理器编码输入...")
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        logger.info(f"输入处理完成,input_ids shape: {inputs.input_ids.shape if hasattr(inputs, 'input_ids') else 'N/A'}")

        logger.info("开始模型生成...")
        # NOTE(review): generation runs synchronously inside this async handler,
        # so the event loop is blocked (other requests wait) until it finishes.
        with torch.no_grad():
            logger.info("使用简化的生成参数...")
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )
            logger.info(f"生成完成,输出 shape: {generated_ids.shape}")

            logger.info("开始解码生成的文本...")
            # generate() returns prompt + completion; keep only the new tokens
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            logger.info(f"修剪后的 token 数量: {[len(ids) for ids in generated_ids_trimmed]}")

            output_text = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

        # Drop tensor references promptly to lower peak memory before returning
        del generated_ids, generated_ids_trimmed, inputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        processing_time = time.time() - start_time
        logger.info(f"分析完成,处理时间: {processing_time:.2f}秒")
        logger.info(f"生成的文本长度: {len(output_text)} 字符")
        logger.info(f"生成的文本预览: {output_text[:100]}...")

        final_memory = psutil.virtual_memory()
        logger.info(f"分析后内存使用: {final_memory.percent:.1f}%")

        return AnalyzeResponse(
            success=True,
            prompt=request.prompt,
            response=output_text,
            processing_time=processing_time,
            image_info={
                "size": f"{pil_image.size[0]}x{pil_image.size[1]}",
                "mode": pil_image.mode,
                "format": source_format or "Unknown"
            }
        )

    except Exception as e:
        processing_time = time.time() - start_time
        # logger.exception keeps the traceback for diagnosing failures
        logger.exception(f"图片分析失败: {str(e)}")
        return AnalyzeResponse(
            success=False,
            prompt=request.prompt,
            response="",
            processing_time=processing_time,
            error=f"图片分析失败: {str(e)}"
        )
|
|
|
|
|
@app.get("/memory_status")
def get_memory_status():
    """Report system, process and (when available) CUDA memory usage plus model status.

    Returns:
        dict with ``system_memory``, ``process_memory``, ``torch_memory``,
        ``model_status``, ``memory_analysis`` and ``recommendations`` sections.
        ``recommendations.suggested_actions`` contains only the actions whose
        threshold is currently exceeded. If collecting the figures raises, a
        500 JSONResponse is returned instead.
    """
    try:
        memory = psutil.virtual_memory()
        gib = 1024**3

        torch_memory = {}
        if torch.cuda.is_available():
            torch_memory = {
                "cuda_allocated": torch.cuda.memory_allocated() / gib,
                "cuda_reserved": torch.cuda.memory_reserved() / gib,
                "cuda_max_allocated": torch.cuda.max_memory_allocated() / gib,
            }

        process = psutil.Process()
        process_memory = process.memory_info()

        # Emit only the actions that apply instead of null placeholders.
        suggested_actions = [
            action
            for action, triggered in (
                ("清理缓存 (/clear_cache)", memory.percent > 85),
                ("重启服务", memory.percent > 95),
                ("减小图片尺寸", memory.percent > 80),
            )
            if triggered
        ]

        return {
            "system_memory": {
                "total_gb": round(memory.total / gib, 2),
                "available_gb": round(memory.available / gib, 2),
                "used_gb": round(memory.used / gib, 2),
                "free_gb": round(memory.free / gib, 2),
                "percent": round(memory.percent, 1),
                # buffers/cached are platform-specific psutil fields; default to 0
                "buffers_gb": round(getattr(memory, 'buffers', 0) / gib, 2),
                "cached_gb": round(getattr(memory, 'cached', 0) / gib, 2)
            },
            "process_memory": {
                "rss_gb": round(process_memory.rss / gib, 2),
                "vms_gb": round(process_memory.vms / gib, 2),
                "percent": round(process.memory_percent(), 2)
            },
            "torch_memory": torch_memory,
            "model_status": {
                "model_loaded": model is not None,
                "processor_loaded": processor is not None,
                "qwen_vl_available": QWEN_VL_AVAILABLE
            },
            "memory_analysis": {
                "total_physical_memory": f"{memory.total / gib:.2f} GB",
                "model_estimated_usage": "~8-10 GB (Qwen2-VL-2B)",
                "available_for_inference": f"{memory.available / gib:.2f} GB",
                "memory_pressure": "High" if memory.percent > 90 else "Medium" if memory.percent > 75 else "Low",
                "huggingface_spaces_limit": "可能受到容器内存限制"
            },
            "recommendations": {
                "memory_usage_ok": memory.percent < 85,
                "available_for_inference": memory.available / gib > 2.0,
                "should_clear_cache": memory.percent > 90,
                "can_process_images": memory.available > 2 * gib,
                "suggested_actions": suggested_actions
            }
        }
    except Exception as e:
        logger.exception(f"获取内存状态失败: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"获取内存状态失败: {str(e)}"}
        )
|
|
|
|
|
@app.post("/clear_cache")
def clear_cache():
    """Run three garbage-collection passes and empty the CUDA cache if present.

    Loaded model weights are not released. The reported change is measured from
    system-wide available memory (psutil.virtual_memory), so activity in other
    processes also shows up in it.

    Returns:
        dict with ``success``, ``message`` and a ``details`` breakdown, or a 500
        JSONResponse if cleanup raises.
    """
    try:
        logger.info("开始强制清理内存缓存...")

        before = psutil.virtual_memory()
        logger.info(f"清理前内存使用: {before.percent:.1f}%")

        collected_per_pass = []
        for attempt in range(1, 4):
            count = gc.collect()
            collected_per_pass.append(count)
            logger.info(f"垃圾回收第{attempt}次: 清理了 {count} 个对象")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("已清理 CUDA 缓存")

        after = psutil.virtual_memory()
        logger.info(f"清理后内存使用: {after.percent:.1f}%")

        freed_gb = (after.available - before.available) / 1024**3
        if freed_gb > 0:
            improvement = f"释放了 {freed_gb:.2f} GB 内存"
        else:
            improvement = "内存使用无明显变化"

        details = {
            "objects_collected": sum(collected_per_pass),
            "memory_before_percent": round(before.percent, 1),
            "memory_after_percent": round(after.percent, 1),
            "freed_memory_gb": round(freed_gb, 2),
            "improvement": improvement
        }
        return {"success": True, "message": "强制缓存清理完成", "details": details}
    except Exception as exc:
        logger.error(f"缓存清理失败: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"缓存清理失败: {str(exc)}"}
        )
|
|
|
|
|
@app.get("/web")
def web_interface():
    """Serve the web UI page, static/index.html, resolved against the working directory."""
    index_page = "static/index.html"
    return FileResponse(index_page)
|
|
|
|
|
|
|
|
|
|
|
@app.post("/analyze-simple")
async def analyze_simple_test(request: AnalyzeRequest):
    """Debug endpoint that checks readiness but never runs the model.

    Returns a failure AnalyzeResponse when the dependencies or the model are
    unavailable. Otherwise it returns a fixed mock response, so the request
    path can be exercised without waiting on inference.
    """
    image_length = len(request.image) if request.image else 0
    logger.info(f"收到简化测试请求: prompt='{request.prompt}', image_length={image_length}")

    failure_reason = None
    if not QWEN_VL_AVAILABLE:
        logger.error("Qwen-VL 不可用")
        failure_reason = "Qwen-VL 依赖不可用"
    elif model is None or processor is None:
        logger.error("模型未加载")
        failure_reason = "模型未加载"

    if failure_reason is not None:
        return AnalyzeResponse(
            success=False,
            prompt=request.prompt,
            response="",
            processing_time=0,
            error=failure_reason
        )

    mock_image_info = {"size": "测试", "mode": "RGB", "format": "PNG"}
    return AnalyzeResponse(
        success=True,
        prompt=request.prompt,
        response="这是一个测试响应。模型已加载并可以接收请求,但为了避免卡住,这里返回模拟结果。",
        processing_time=1.0,
        image_info=mock_image_info
    )
|
|
|