| import os |
| import torch |
| import io |
| import shutil |
| from fastapi import FastAPI, File, UploadFile |
| from transformers import AutoProcessor, AutoModelForCausalLM |
| from ultralytics import YOLO |
| from PIL import Image |
| import uvicorn |
|
|
| |
# --- Application and model setup (runs once at import time) ---

app = FastAPI(title="YOLO + GIT Captioning API")

# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Path to the user's fine-tuned YOLO weights; falls back to the stock
# yolov8n checkpoint below if this file is missing or unloadable.
MY_MODEL_PATH = 'best.pt'

print("🔄 جاري تحميل الموديلات... يرجى الانتظار")

# Try the custom detector first; any failure (missing file, corrupt
# weights, version mismatch) downgrades to the pretrained yolov8n model
# so the API still starts.
try:
    detection_model = YOLO(MY_MODEL_PATH)
    print("✅ تم تحميل موديل YOLO الخاص بك بنجاح")
except Exception as e:
    print(f"⚠️ فشل تحميل best.pt، سيتم استخدام الموديل الافتراضي: {e}")
    detection_model = YOLO("yolov8n.pt")

# Microsoft GIT (GenerativeImage2Text) base model for image captioning.
# The processor handles image preprocessing + token decoding; the model
# is moved to the selected device once, at startup.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
caption_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(device)
|
|
@app.get("/")
def home():
    """Health-check endpoint; points callers at the interactive docs."""
    payload = {
        "status": "Online",
        "instruction": "Add /docs to the URL to test the model",
    }
    return payload
|
|
| |
|
|
def _caption(image):
    """Generate a short English caption for a PIL image with the GIT model.

    Runs under ``torch.inference_mode()`` so no autograd graph is built
    during generation (saves memory on every request).
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = caption_model.generate(
            pixel_values=inputs.pixel_values, max_length=40
        )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Detect objects in an uploaded image and caption each detection.

    Pipeline:
      1. Decode the upload into an RGB PIL image.
      2. Run YOLO detection (confidence threshold 0.20).
      3. Crop each detected box and caption the crop with GIT.
      4. If nothing is detected, caption the whole image instead.

    Returns a JSON object with either per-object results or a single
    general description.
    """
    data = await file.read()
    original_image = Image.open(io.BytesIO(data)).convert("RGB")

    # Low confidence threshold (0.20) deliberately favors recall.
    results = detection_model(original_image, conf=0.20)
    integrated_results = []

    for r in results:
        for i, box in enumerate(r.boxes):
            label = r.names[int(box.cls)]
            conf_score = float(box.conf[0])
            coords = box.xyxy[0].tolist()  # [x1, y1, x2, y2] in pixels

            # Caption only the detected region for a per-object description.
            cropped_img = original_image.crop(
                (coords[0], coords[1], coords[2], coords[3])
            )
            detailed_desc = _caption(cropped_img)

            integrated_results.append({
                "object_id": i + 1,
                "label": label,
                "confidence": f"{conf_score:.2f}",
                "description": detailed_desc,
            })

    # Fallback: no YOLO detections — describe the full image instead.
    if not integrated_results:
        general_desc = _caption(original_image)
        return {
            "message": "No specific objects detected by YOLO. Here is a general description.",
            "general_description": general_desc,
        }

    return {
        "detected_count": len(integrated_results),
        "results": integrated_results,
    }
|
|
| |
# Entry point for running the API directly (e.g. `python app.py`).
# Bug fix: the original checked `name` instead of `__name__`, which raises
# NameError when the script is executed and silently never starts the
# server when imported.
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces; 7860 is the conventional HF Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|