Final_App / app.py
ek-5's picture
Update app.py
46da2d1 verified
raw
history blame
3.38 kB
import os
import torch
import io
import shutil
from fastapi import FastAPI, File, UploadFile
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO
from PIL import Image
import uvicorn
# --- 1. إعداد التطبيق والموديلات ---
app = FastAPI(title="YOLO + GIT Captioning API")
# تحديد الجهاز (استخدام CPU للمساحات المجانية لضمان الاستقرار)
device = "cuda" if torch.cuda.is_available() else "cpu"
# مسار الموديل الذي رفعتِيه يدوياً في القائمة
MY_MODEL_PATH = 'best.pt'
print("🔄 جاري تحميل الموديلات... يرجى الانتظار")
# تحميل موديل YOLO الخاص بكِ
try:
detection_model = YOLO(MY_MODEL_PATH)
print("✅ تم تحميل موديل YOLO الخاص بك بنجاح")
except Exception as e:
print(f"⚠️ فشل تحميل best.pt، سيتم استخدام الموديل الافتراضي: {e}")
detection_model = YOLO("yolov8n.pt")
# تحميل موديل GIT-base (أخف وأسرع للمساحة المجانية)
processor = AutoProcessor.from_pretrained("microsoft/git-base")
caption_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(device)
@app.get("/")
def home():
return {"status": "Online", "instruction": "Add /docs to the URL to test the model"}
# --- 2. وظيفة المعالجة ---
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
data = await file.read()
original_image = Image.open(io.BytesIO(data)).convert("RGB")
# 1. الكشف باستخدام YOLO
results = detection_model(original_image, conf=0.20)
integrated_results = []
for r in results:
for i, box in enumerate(r.boxes):
label = r.names[int(box.cls)]
conf_score = float(box.conf[0])
coords = box.xyxy[0].tolist()
# 2. عملية القص (Cropping)
cropped_img = original_image.crop((coords[0], coords[1], coords[2], coords[3]))
# 3. وصف الجزء المقصوص عبر موديل GIT
inputs = processor(images=cropped_img, return_tensors="pt").to(device)
generated_ids = caption_model.generate(pixel_values=inputs.pixel_values, max_length=40)
detailed_desc = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
integrated_results.append({
"object_id": i + 1,
"label": label,
"confidence": f"{conf_score:.2f}",
"description": detailed_desc
})
if not integrated_results:
inputs = processor(images=original_image, return_tensors="pt").to(device)
generated_ids = caption_model.generate(pixel_values=inputs.pixel_values, max_length=40)
general_desc = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return {
"message": "No specific objects detected. General description provided.",
"general_description": general_desc
}
return {
"detected_count": len(integrated_results),
"results": integrated_results
}
# --- 3. تشغيل السيرفر (تم تصحيح الشرطات السفلية هنا) ---
if name == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)