Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import clip | |
| import numpy as np | |
| import random | |
| import os | |
| from PIL import Image | |
| from ultralytics import YOLO # Needed for both person and fashion detection | |
| from gtts import gTTS | |
| import uuid | |
| import time | |
| import tempfile | |
def _safe_mean(values):
    """Return the mean of ``values`` as a float, or 0.0 for an empty slice.

    ``np.mean`` of an empty array returns NaN (plus a RuntimeWarning); a NaN
    score makes every comparison in ``_score_style`` false and silently pushes
    the result into the last category.
    """
    return float(np.mean(values)) if len(values) > 0 else 0.0


def _lookup_fashion_class(fashion_classes, class_id):
    """Map a fashion-model class id to its name, or return None if unknown.

    Accepts either a mapping keyed by int id (``{0: "shirt", ...}``) or a
    sequence indexed by id (``["shirt", ...]``). A plain ``id in classes``
    membership test only works for mappings: on a list it compares the id
    against the *names* and never matches.
    """
    from collections.abc import Mapping

    class_id = int(class_id)
    if isinstance(fashion_classes, Mapping):
        return fashion_classes.get(class_id)
    if 0 <= class_id < len(fashion_classes):
        return fashion_classes[class_id]
    return None


def _to_rgb_image(input_img):
    """Convert a file path, numpy array, or PIL image into a new RGB PIL image.

    File paths are opened in a ``with`` block so the underlying file handle is
    closed once the pixels have been converted (``convert`` loads the data).
    """
    if isinstance(input_img, (str, os.PathLike)):
        with Image.open(input_img) as opened:
            return opened.convert("RGB")
    if isinstance(input_img, np.ndarray):
        return Image.fromarray(input_img).convert("RGB")
    # PIL's convert() always returns a new image, so the caller's image is never mutated.
    return input_img.convert("RGB")


def _crop_to_main_person(img, yolo_person_model, conf_threshold):
    """Crop ``img`` to the most confident person box; fall back to the full image.

    Returns ``(image, person_detected)``. Model errors are printed and treated
    like "no person found" so the rest of the pipeline can still run.
    """
    try:
        results = yolo_person_model(img, verbose=False, conf=conf_threshold)
        boxes = results[0].boxes.xyxy.cpu().numpy()
        classes = results[0].boxes.cls.cpu().numpy()
        confidences = results[0].boxes.conf.cpu().numpy()
    except Exception as e:
        print(f"Error during YOLO person detection: {e}. Analyzing full image.")
        return img, False

    # Class 0 is "person" in the standard YOLOv8 (COCO) label ordering.
    person_indices = np.where(classes == 0)[0]
    if len(person_indices) == 0:
        print("No person detected by yolo_person_model. Analyzing full image.")
        return img, False

    best_idx = person_indices[np.argmax(confidences[person_indices])]
    x1, y1, x2, y2 = map(int, boxes[best_idx])
    # Clip the box to the image bounds before cropping.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(img.width, x2), min(img.height, y2)
    if x1 >= x2 or y1 >= y2:
        print("Warning: Invalid person bounding box after clipping. Using full image.")
        return img, False

    print(f"Person detected and cropped: Box {x1, y1, x2, y2}")
    return img.crop((x1, y1, x2, y2)), True


def _detect_fashion_item(image, yolo_fashion_model, fashion_classes, conf_threshold):
    """Return ``(name, confidence)`` of the most confident fashion detection.

    ``name`` is None when nothing is detected or the class id is unknown. It is
    also None when the model raises: errors are printed and swallowed because
    this signal is optional.
    """
    try:
        results = yolo_fashion_model(image, verbose=False, conf=conf_threshold)
        classes = results[0].boxes.cls.cpu().numpy().astype(int)
        confidences = results[0].boxes.conf.cpu().numpy()
        if len(classes) == 0:
            print("No fashion items detected above threshold by yolo_fashion_model.")
            return None, 0.0

        best_idx = int(np.argmax(confidences))
        class_id = classes[best_idx]
        confidence = float(confidences[best_idx])
        name = _lookup_fashion_class(fashion_classes, class_id)
        if name is None:
            print(f"Warning: Detected fashion class ID {class_id} not in FASHION_CLASSES map.")
        else:
            print(f"Fashion model detected: '{name}' with confidence {confidence:.2f}")
        return name, confidence
    except Exception as e:
        print(f"Error during YOLO fashion model analysis: {e}")
        return None, 0.0


def _score_style(all_probs, style_prompts, style_prompts_end_index, drip_threshold):
    """Return ``(category_key, final_score, score_label)`` from CLIP probabilities.

    The style prompts occupy ``all_probs[:style_prompts_end_index]`` in the
    order drippy, mid, not-drippy. Each category's score is the mean
    probability over its prompts.
    """
    drip_len = len(style_prompts['drippy'])
    mid_len = len(style_prompts['mid'])
    drip_score = _safe_mean(all_probs[:drip_len])
    mid_score = _safe_mean(all_probs[drip_len:drip_len + mid_len])
    not_score = _safe_mean(all_probs[drip_len + mid_len:style_prompts_end_index])

    if drip_score > drip_threshold and drip_score > mid_score and drip_score > not_score:
        return 'drippy', drip_score, "Drip Score"
    if mid_score > not_score:
        return 'mid', mid_score, "Mid Score"
    return 'not_drippy', not_score, "Trash Score"


def _pick_clothing_item(fashion_item, fashion_conf, clip_item, clip_prob,
                        conf_threshold, high_conf_threshold, clip_min_prob):
    """Choose the clothing word for the reply.

    Sources are tried in priority order: a fashion-model detection, then the
    CLIP guess, then a random generic term.
    """
    if fashion_item and fashion_conf >= high_conf_threshold:
        print(f"Using highly confident fashion model item: '{fashion_item}'")
        return fashion_item
    if fashion_item and fashion_conf >= conf_threshold:
        # Medium-confidence fashion detections are still preferred over CLIP.
        print(f"Using medium confidence fashion model item: '{fashion_item}'")
        return fashion_item
    if clip_item and clip_prob > clip_min_prob:
        print(f"Using CLIP detected item: '{clip_item}'")
        return clip_item
    generic_item = random.choice(["fit", "look", "style", "vibe"])
    print(f"Using generic fallback item: '{generic_item}'")
    return generic_item


def analyze_outfit(input_img, yolo_person_model, yolo_fashion_model, clip_model, clip_preprocess,
                   all_prompts, style_prompts_end_index, FASHION_CLASSES, CATEGORY_LABEL_MAP,
                   response_templates, YOLO_PERSON_CONF_THRESHOLD, YOLO_FASHION_CONF_THRESHOLD,
                   YOLO_FASHION_HIGH_CONF_THRESHOLD, DEVICE, *, style_prompts=None,
                   drip_threshold=0.41, clip_min_prob=0.05):
    """Rate an outfit photo and produce an HTML summary plus a spoken reply.

    Pipeline:
        1. Crop to the most confident person (YOLO person model).
        2. Look for a named fashion item (YOLO fashion model).
        3. Score the crop against the CLIP style prompts.
        4. Pick a clothing word to mention.
        5. Fill a response template and synthesize it with gTTS.

    Args:
        input_img: File path (``str`` / ``os.PathLike``), numpy array, or PIL
            image. ``None`` yields the "please upload" error tuple.
        yolo_person_model: Callable YOLO model; class id 0 is treated as person.
        yolo_fashion_model: Callable YOLO model whose class ids are resolved
            through ``FASHION_CLASSES``.
        clip_model: CLIP model called as ``clip_model(image, text_tokens)``.
        clip_preprocess: CLIP image transform.
        all_prompts: All CLIP text prompts. The style prompts come first
            (drippy, mid, not-drippy). The rest presumably feed
            ``get_top_clip_clothing`` (TODO confirm).
        style_prompts_end_index: Index one past the last style prompt.
        FASHION_CLASSES: Mapping (or sequence) from fashion class id to name.
        CATEGORY_LABEL_MAP: Maps 'drippy' / 'mid' / 'not_drippy' to labels.
        response_templates: Maps the same keys to lists of reply templates;
            ``{item}`` is replaced with the chosen clothing word.
        YOLO_PERSON_CONF_THRESHOLD: Minimum confidence for person boxes.
        YOLO_FASHION_CONF_THRESHOLD: Minimum confidence for fashion boxes.
        YOLO_FASHION_HIGH_CONF_THRESHOLD: Confidence logged as "highly confident".
        DEVICE: Torch device for the CLIP tensors.
        style_prompts: Dict with 'drippy' and 'mid' prompt lists, used to slice
            the style probabilities. Defaults to the module-level
            ``style_prompts`` global.
        drip_threshold: Minimum mean "drippy" probability for that rating.
        clip_min_prob: Minimum CLIP probability for its item to be mentioned.

    Returns:
        ``(html, tts_path_or_None, text)``. On failure the HTML/text describe
        the error and the audio path is ``None``.
    """
    import html

    # 0) Normalize the input into a fresh RGB PIL image.
    if input_img is None:
        return ("<p style='color: #FF5555; text-align: center;'>Please upload an image.</p>",
                None, "Error: No image provided.")
    try:
        img = _to_rgb_image(input_img)
    except Exception as e:
        # Escape: the exception text may echo the (user-supplied) file path.
        return (f"<p style='color: #FF5555;'>Error loading image: {html.escape(str(e))}</p>",
                None, "Image loading error")

    # 1) + 2) Person crop and fashion-item detection; both degrade gracefully.
    cropped_img, _person_detected = _crop_to_main_person(
        img, yolo_person_model, YOLO_PERSON_CONF_THRESHOLD)
    fashion_item, fashion_conf = _detect_fashion_item(
        cropped_img, yolo_fashion_model, FASHION_CLASSES, YOLO_FASHION_CONF_THRESHOLD)

    # 3) CLIP style scoring + clothing guess. Without a rating there is nothing
    #    to report, so a failure here returns an error tuple.
    if style_prompts is None:
        # Fall back to the module-level ``style_prompts`` when not passed in.
        style_prompts = globals().get("style_prompts")
    try:
        image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
        text_tokens = clip.tokenize(all_prompts).to(DEVICE)
        with torch.no_grad():
            logits, _ = clip_model(image_tensor, text_tokens)
        all_probs = logits.softmax(dim=-1).cpu().numpy()[0]

        category_key, final_score, score_label = _score_style(
            all_probs, style_prompts, style_prompts_end_index, drip_threshold)
        category_label = CATEGORY_LABEL_MAP[category_key]
        # The softmax spans *all* prompts, so per-category means can be small
        # (e.g. "3%"); format as a whole-number percentage.
        percentage_score_str = f"{max(0, final_score * 100):.0f}%"
        print(f"Style analysis: Category={category_label}, Score = {score_label}={percentage_score_str} "
              f"(Raw Score: {final_score:.4f})")

        top_3_clip_items = get_top_clip_clothing(all_probs, n=3)
        if top_3_clip_items:
            detected_items_str = ", ".join(
                f"{item[0]} ({item[1] * 100:.1f}%)" for item in top_3_clip_items)
            print(f"I think I detected: {detected_items_str}")
            # Only the first (name, prob) pair feeds the response text.
            clip_item, clip_prob = top_3_clip_items[0]
        else:
            print("I couldn't confidently identify specific clothing items via CLIP.")
            clip_item, clip_prob = "piece", 0.0
    except Exception as e:
        print(f"Error during CLIP analysis: {e}")
        return ("<p style='color: #FF5555;'>Error during CLIP analysis.</p>",
                None, f"Analysis Error: {e}")

    # 4) Decide which clothing word the reply should mention.
    final_clothing_item = _pick_clothing_item(
        fashion_item, fashion_conf, clip_item, clip_prob,
        YOLO_FASHION_CONF_THRESHOLD, YOLO_FASHION_HIGH_CONF_THRESHOLD, clip_min_prob)

    # 5) Generate the reply text and its TTS audio.
    label_html = html.escape(str(category_label).upper())
    tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
    try:
        chosen_template = random.choice(response_templates[category_key])
        if '{item}' in chosen_template:
            response_text = chosen_template.format(item=final_clothing_item)
        else:
            response_text = chosen_template
        gTTS(text=response_text, lang='en', tld='com', slow=False).save(tts_path)
        print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
    except Exception as e:
        print(f"Error during response/TTS generation: {e}")
        # A failed save may leave an empty/partial mp3 behind; clean it up.
        try:
            os.remove(tts_path)
        except OSError:
            pass
        category_html = f"""
    <div class='results-container'>
        <h2 class='result-category'>Result: {label_html}</h2>
        <p class='result-score'>{score_label}: {percentage_score_str}</p>
        <p class='result-error' style='color: #FFAAAA; font-size: 0.9em;'>Error generating audio/full response.</p>
    </div>
    """
        # Still report the rating, but flag the response/TTS failure.
        return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."

    category_html = f"""
    <div class='results-container'>
        <h2 class='result-category'>RATING: {label_html}</h2>
        <p class='result-score'>{score_label}: {percentage_score_str}</p>
    </div>
    """
    return category_html, tts_path, response_text