Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import clip | |
| import numpy as np | |
| import random | |
| import os | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| from gtts import gTTS | |
| import uuid | |
| import time | |
| import tempfile | |
| from huggingface_hub import hf_hub_download | |
# --- Runtime configuration ---------------------------------------------------
# Hub access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Prefer the GPU whenever torch reports one, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Both YOLO checkpoints are fetched from the same Hub repository at import
# time; only the file name differs between the two downloads.
_HUB_DOWNLOAD_KWARGS = {
    "repo_id": "dzmu/dripai-models",
    "token": HF_TOKEN,
    "repo_type": "model",
}
YOLO_PERSON_MODEL_PATH = hf_hub_download(filename="yolov8n.pt", **_HUB_DOWNLOAD_KWARGS)
YOLO_FASHION_MODEL_PATH = hf_hub_download(filename="best.pt", **_HUB_DOWNLOAD_KWARGS)

# Name of the CLIP variant passed to clip.load().
CLIP_MODEL_NAME = "ViT-B/32"

# Confidence Thresholds
YOLO_PERSON_CONF_THRESHOLD = 0.4
YOLO_FASHION_CONF_THRESHOLD = 0.4
YOLO_FASHION_HIGH_CONF_THRESHOLD = 0.6

# Class id -> garment name used to label detections from the fashion model.
FASHION_CLASSES = {
    0: 'long sleeve top',
    1: 'skirt',
    2: 'trousers',
    3: 'short sleeve top',
    4: 'long sleeve outwear',
    5: 'short sleeve dress',
    6: 'shorts',
    7: 'vest dress',
    8: 'sling dress',
    9: 'vest',
    10: 'long sleeve dress',
    11: 'sling',
    12: 'short sleeve outwear',
}

print(f"Defined {len(FASHION_CLASSES)} fashion categories for {YOLO_FASHION_MODEL_PATH}")
print(f"Using device: {DEVICE}")
# Load the three models once, at import time. Every load is best-effort: a
# failure is reported on stdout and the affected globals are bound to None so
# that the module always defines these names (previously a failed load left
# them undefined and the problem resurfaced later as a NameError).
try:
    clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME, device=DEVICE)
    print(f"CLIP model ({CLIP_MODEL_NAME}) loaded successfully.")
except Exception as e:
    clip_model = clip_preprocess = None
    print(f"Error loading CLIP model: {e}")

try:
    yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH)
    print(f"YOLO person detection model ({YOLO_PERSON_MODEL_PATH}) loaded successfully.")
except Exception as e:
    yolo_person_model = None
    print(f"Error loading YOLO person model: {e}")

try:
    yolo_fashion_model = YOLO(YOLO_FASHION_MODEL_PATH)  # No .to(DEVICE) needed here
    print(f"YOLO fashion detection model ({YOLO_FASHION_MODEL_PATH}) loaded successfully.")
except Exception as e:
    yolo_fashion_model = None
    print(f"Error loading YOLO fashion model: {e}")
# Text prompts the image is compared against, grouped by style verdict.
# The key order matters: all_prompts is built by walking this dict in order,
# and the scoring code slices the probabilities using the per-key lengths.
style_prompts = {
    'drippy': [
        "High-fashion runway look with designer labels",
        "Coordinated color palette with premium fabrics",
        "Trend-forward streetwear with luxury accessories",
        "Tailored silhouettes with intentional layering",
        "Seasonal trend elements executed flawlessly",
        "Cohesive outfit with statement pieces",
        "Celebrity red carpet-level styling",
        # A comma was missing after the next entry, which silently fused it
        # with the following string into a single prompt.
        "Bold pattern mixing that works harmoniously",
        "Well-put-together casual outfit",
        "Confident personal style",
        "Good color coordination",
        "Flattering fit and proportions",
        "Appropriate for the occasion",
        "Clean and well-maintained clothing",
        "Interesting texture combinations",
        "Balanced silhouette",
        "Thoughtful accessory choices",
        "Modern but not overly trendy",
    ],
    'mid': [
        "Basic wardrobe staples without flair",
        "Safe color combinations with minimal accessories",
        "Mass-market fast fashion pieces",
        "Functional over fashionable aesthetic",
        "Trend-adjacent but poorly executed",
        "Mismatched proportions in clothing",
        "Overly casual for the occasion",
        "Decade-outdated trend revival",
    ],
    'not_drippy': [
        "Severely mismatched color schemes",
        "Pilled/stained/faded clothing items",
        "Ill-fitting garments in multiple pieces",
        "Clashing patterns without cohesion",
        "Inappropriate footwear for outfit context",
        "Overly literal costume-like styling",
        "Multiple competing trends in one outfit",
        "Wrinkled/unkempt fabric presentation",
    ],
}

# Short keyword-style prompts appended to the two extreme categories.
style_prompts['drippy'] += [
    "coordinated color scheme",
    "perfect garment proportions",
    "current fashion trends",
    "high-quality materials",
    "complementary accessories",
]
style_prompts['not_drippy'] += [
    "clashing colors",
    "poor fit proportions",
    "outdated trends",
    "cheap-looking fabrics",
    "missing accessories",
]

# Item names used to guess which garment to mention in the response.
clothing_prompts = [
    "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
    "dress", "skirt", "pants", "jeans", "trousers", "shorts",
    "sneakers", "boots", "heels", "sandals",
    "cap", "hat", "scarf", "gloves", "bag", "accessory", "tank-top", "haircut",
]

# Flat prompt list: every style prompt first (in dict order), then the
# clothing prompts. style_prompts_end_index marks where the clothing part begins.
all_prompts = []
for cat_prompts in style_prompts.values():
    all_prompts.extend(cat_prompts)
style_prompts_end_index = len(all_prompts)
all_prompts.extend(clothing_prompts)
# Spoken/written verdicts per category. A template may contain an "{item}"
# placeholder that is filled with the garment name chosen for the response;
# templates without the placeholder are used as-is.
response_templates = dict(
    drippy=[
        "You're Drippy, bruh – fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
        "Your {item} just walked a Paris runway.",
        "That {item}? Straight from the future.",
        "You just turned a sidewalk into a runway.",
    ],
    mid=[
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. {item} is holding you back.",
        "We can do better come on now",
        "I don't think you want it enough",
        "You're teetering on drip... fix the {item}.",
        "You're in the gray zone. That {item} ain't helping.",
    ],
    not_drippy=[
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
        "Your drip is trash, try again.",
        "This ain't it chief. The overall style needs work.",
        "Didn't need an AI to tell you to go back to the wardrobe",
        "Someone call the fashion police. That {item} needs arresting.",
        "Your outfit just gave me a 404 error.",
        "That {item} made my GPU overheat in shame.",
    ],
)

# Internal category key -> label shown in the result panel.
CATEGORY_LABEL_MAP = {
    "drippy": "drippy",
    "mid": "mid",
    "not_drippy": "trash",
}
def format_detected_items(item_list):
    """Render item names as one HTML paragraph.

    Args:
        item_list: Sequence of item-name strings; may be empty or ``None``.

    Returns:
        ``"<p class='result-items'>Detected items: a, b</p>"`` style markup,
        or an empty string when there is nothing to list.
    """
    if item_list:
        joined_names = ", ".join(item_list)
        return f"<p class='result-items'>Detected items: {joined_names}</p>"
    return ""
def get_top_clip_clothing(probs, n=1):
    """Return the ``n`` most probable clothing prompts, best first.

    Args:
        probs: Per-prompt probabilities aligned with ``all_prompts``; the
            entries from ``style_prompts_end_index`` onward are read as the
            clothing-prompt probabilities.
        n: How many items to return (capped at ``len(clothing_prompts)``).

    Returns:
        A list of ``(clothing_prompt, probability)`` tuples ordered from most
        to least probable, or an empty list when ``n`` is zero or negative.
    """
    clothing_probs = probs[style_prompts_end_index:]
    actual_n = min(n, len(clothing_prompts))
    # Guard is required: a slice of [-0:] would select the whole array.
    if actual_n <= 0:
        return []
    # argsort is ascending, so the tail holds the winners; flip it to get
    # the highest probability first.
    best_first = np.argsort(clothing_probs)[-actual_n:][::-1]
    return [(clothing_prompts[i], clothing_probs[i]) for i in best_first]
def wrapped_analyze(input_img):
    """Forward an uploaded image to ``analyze_outfit``.

    ``analyze_outfit`` accepts only the image and reads the models, prompts
    and thresholds from module globals, so they must not be passed along;
    doing so made every call fail with a ``TypeError``.

    Args:
        input_img: The image value to analyse, passed through unchanged.

    Returns:
        Whatever ``analyze_outfit(input_img)`` returns.
    """
    return analyze_outfit(input_img)
def analyze_outfit(input_img):
    """Rate the outfit in an image and build HTML, speech and text feedback.

    Steps: crop to the most confident person box (when one is found), look
    for a garment with the fashion YOLO model, score the crop against every
    prompt in ``all_prompts`` with CLIP, then pick a response template for the
    resulting category and synthesise it with gTTS.

    Args:
        input_img: A PIL image, a path string readable by ``Image.open``, or
            ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple ``(html, audio_path, message)``. ``audio_path`` is ``None``
        whenever an error prevented audio from being produced; in that case
        ``message`` describes the problem.
    """
    if isinstance(input_img, str):
        try:
            input_img = Image.open(input_img)
        except Exception as e:
            return (f"<p style='color: #FF5555;'>Error loading image: {str(e)}</p>",
                    None, "Image loading error")
    if input_img is None:
        return ("<p style='color: #FF5555; text-align: center;'>Please upload an image.</p>",
                None, "Error: No image provided.")
    img = input_img.convert("RGB").copy()

    # --- Stage 1: crop to the most confident person detection -------------
    person_results = yolo_person_model(img, verbose=False, conf=YOLO_PERSON_CONF_THRESHOLD)
    boxes = person_results[0].boxes.xyxy.cpu().numpy()
    classes = person_results[0].boxes.cls.cpu().numpy()
    confidences = person_results[0].boxes.conf.cpu().numpy()
    # Class id 0 is what this code treats as "person".
    person_indices = np.where(classes == 0)[0]
    cropped_img = img
    if len(person_indices) > 0:
        max_conf_person_idx = person_indices[np.argmax(confidences[person_indices])]
        x1, y1, x2, y2 = map(int, boxes[max_conf_person_idx])
        # Clamp the box to the image so crop() never receives out-of-range coordinates.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img.width, x2), min(img.height, y2)
        if x1 < x2 and y1 < y2:
            cropped_img = img.crop((x1, y1, x2, y2))
            print(f"Person detected and cropped: Box {x1, y1, x2, y2}")
        else:
            print("Warning: Invalid person bounding box after clipping. Using full image.")
    else:
        print("No person detected by yolo_person_model. Analyzing full image.")

    # --- Stage 2: best garment according to the fashion model -------------
    # Runs whether or not a person was found; any failure here is logged and
    # the analysis continues without fashion-model input.
    detected_fashion_item_name = None
    detected_fashion_item_conf = 0.0
    try:
        fashion_results = yolo_fashion_model(cropped_img, verbose=False, conf=YOLO_FASHION_CONF_THRESHOLD)
        fashion_classes = fashion_results[0].boxes.cls.cpu().numpy().astype(int)
        fashion_confidences = fashion_results[0].boxes.conf.cpu().numpy()
        if len(fashion_classes) > 0:
            best_fashion_idx = np.argmax(fashion_confidences)
            detected_class_id = fashion_classes[best_fashion_idx]
            detected_fashion_item_conf = fashion_confidences[best_fashion_idx]
            if detected_class_id in FASHION_CLASSES:
                detected_fashion_item_name = FASHION_CLASSES[detected_class_id]
                print(f"Fashion model detected: '{detected_fashion_item_name}' "
                      f"with confidence {detected_fashion_item_conf:.2f}")
            else:
                print(f"Warning: Detected fashion class ID {detected_class_id} not in FASHION_CLASSES map.")
        else:
            print("No fashion items detected above threshold by yolo_fashion_model.")
    except Exception as e:
        print(f"Error during YOLO fashion model analysis: {e}")

    # --- Stage 3: CLIP style score and clothing guess ---------------------
    clip_detected_item = "look"  # Default fallback item name
    clip_detected_item_prob = 0.0
    try:
        image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
        text_tokens = clip.tokenize(all_prompts).to(DEVICE)
        with torch.no_grad():
            logits, _ = clip_model(image_tensor, text_tokens)
        all_probs = logits.softmax(dim=-1).cpu().numpy()[0]

        # all_probs follows the order of all_prompts: drippy, mid, not_drippy
        # style prompts, then the clothing prompts.
        drip_len = len(style_prompts['drippy'])
        mid_len = len(style_prompts['mid'])
        drip_score = np.mean(all_probs[0:drip_len])
        mid_score = np.mean(all_probs[drip_len:drip_len + mid_len])
        not_score = np.mean(all_probs[drip_len + mid_len:style_prompts_end_index])

        # NOTE(review): the softmax above spans every prompt (style and
        # clothing), so each category mean is at most 1 / len(category). With
        # the current prompt lists that appears to cap final_score below the
        # 20 / 50 thresholds used here, i.e. every image would land in
        # 'not_drippy'. Confirm the intended normalisation before tuning.
        raw_weighted_score = (drip_score * 1) + (mid_score * 0.5) + (not_score * 0.1)
        final_score = raw_weighted_score * 100 + 10
        final_score = min(max(final_score, 0), 100)
        if final_score >= 50:
            category_key = 'drippy'
            score_label = "Drip Score"
        elif final_score >= 20:
            category_key = 'mid'
            score_label = "Mid Score"
        else:
            category_key = 'not_drippy'
            score_label = "Trash Score"
        category_label = CATEGORY_LABEL_MAP[category_key]
        percentage_score_str = f"{final_score:.0f}%"
        print(f"Style analysis: Category={category_label}, Score = {score_label}={percentage_score_str} (Raw Score: {final_score:.4f})")

        # Get top clothing item from CLIP
        top_3_clip_items = get_top_clip_clothing(all_probs, n=3)
        if top_3_clip_items:
            detected_items_str = ", ".join([f"{item[0]} ({item[1]*100:.1f}%)" for item in top_3_clip_items])
            print(f"I think I detected: {detected_items_str}")
            clip_detected_item, clip_detected_item_prob = top_3_clip_items[0]
        else:
            print("I couldn't confidently identify specific clothing items via CLIP.")
            clip_detected_item = "piece"  # Use a different fallback if CLIP fails
            clip_detected_item_prob = 0.0
    except Exception as e:
        print(f"Error during CLIP analysis: {e}")
        return ("<p style='color: #FF5555;'>Error during CLIP analysis.</p>",
                None, f"Analysis Error: {e}")

    # --- Stage 4: choose the item name to mention -------------------------
    if detected_fashion_item_name and detected_fashion_item_conf >= YOLO_FASHION_HIGH_CONF_THRESHOLD:
        # Priority 1: High-confidence fashion model detection
        final_clothing_item = detected_fashion_item_name
        print(f"Using highly confident fashion model item: '{final_clothing_item}'")
    elif detected_fashion_item_name and detected_fashion_item_conf >= YOLO_FASHION_CONF_THRESHOLD:
        # Priority 2: Medium-confidence fashion model detection (still preferred over CLIP)
        final_clothing_item = detected_fashion_item_name
        print(f"Using medium confidence fashion model item: '{final_clothing_item}'")
    elif clip_detected_item and clip_detected_item_prob > 0.05:
        # Priority 3: CLIP detection, if its probability is somewhat reasonable
        final_clothing_item = clip_detected_item
        print(f"Using CLIP detected item: '{final_clothing_item}'")
    else:
        # Priority 4: no confident detection from either model -> generic term
        final_clothing_item = random.choice(["fit", "look", "style", "vibe"])
        print(f"Using generic fallback item: '{final_clothing_item}'")

    # --- Stage 5: response text, speech and result panel ------------------
    try:
        response_pool = response_templates[category_key]
        chosen_template = random.choice(response_pool)
        # Only templates that contain the placeholder are formatted.
        response_text = chosen_template.format(item=final_clothing_item) if '{item}' in chosen_template else chosen_template
        tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
        tts = gTTS(text=response_text, lang='en', tld='com', slow=False)
        tts.save(tts_path)
        print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
        category_html = f"""
        <div class='results-container'>
        <h2 class='result-category'>RATING: {category_label.upper()}</h2>
        <p class='result-score'>Overall Fit Rating: {percentage_score_str}/100%</p>
        <p class='result-confidence'>Detected Item: {final_clothing_item.title()}</p>
        </div>
        """
        return category_html, tts_path, response_text
    except Exception as e:
        print(f"Error during response/TTS generation: {e}")
        # final_score is already on a 0-100 scale, so percentage_score_str
        # from the CLIP stage is reused here rather than rescaling the score
        # a second time (which used to inflate it a hundredfold).
        category_html = f"""
        <div class='results-container'>
        <h2 class='result-category'>Result: {category_label.upper()}</h2>
        <p class='result-score'>Overall Fit Rating: {percentage_score_str}/100</p>
        <p class='result-confidence'>Detected Item: {final_clothing_item.title()}</p>
        <p class='result-error' style='color: #FFAAAA; font-size: 0.9em;'>Error generating audio/full response.</p>
        </div>
        """
        # Still provide category info, but indicate TTS/response error
        return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
# Custom stylesheet handed to gr.Blocks(css=...) when the UI is built.
# It declares colour and size variables under :root, contains rules targeting
# the .gr-* and .gradio-container selectors, and defines the
# .results-container, .result-category, .result-score and .result-error
# classes referenced by the HTML fragments that analyze_outfit returns.
| custom_css = """:root { | |
| --primary-bg-color: #0D0D0D; | |
| --secondary-bg-color: #1A1A1A; | |
| --text-color: #FFFFFF; | |
| --accent-color: #7F5AF0; /* Electric purple */ | |
| --accent-hover: #C084FC; | |
| --success-color: #2CB67D; | |
| --error-color: #F25F4C; | |
| --border-color: #2F2F2F; | |
| --input-bg-color: #1A1A1A; | |
| --button-text-color: #FFFFFF; | |
| --body-text-size: 16px; | |
| } | |
| body, | |
| .gradio-container { | |
| background-color: var(--primary-bg-color) !important; | |
| color: var(--text-color) !important; | |
| font-family: 'Inter', 'Segoe UI', sans-serif; | |
| font-size: var(--body-text-size); | |
| } | |
| footer { | |
| display: none !important; | |
| } | |
| .gr-block { | |
| background-color: var(--secondary-bg-color) !important; | |
| border: 1px solid var(--border-color) !important; | |
| border-radius: 16px !important; | |
| padding: 20px !important; | |
| box-shadow: 0 0 15px rgba(127, 90, 240, 0.1); | |
| } | |
| .gr-input, | |
| .gr-output, | |
| .gr-textbox textarea, | |
| .gr-dropdown select, | |
| .gr-checkboxgroup input { | |
| background-color: var(--input-bg-color) !important; | |
| color: var(--text-color) !important; | |
| border: 1px solid var(--border-color) !important; | |
| border-radius: 10px !important; | |
| padding: 10px; | |
| } | |
| .gr-textbox textarea::placeholder { | |
| color: #888888 !important; | |
| } | |
| .gr-label span, | |
| .gr-label .label-text { | |
| color: var(--text-color) !important; | |
| font-weight: 500 !important; | |
| font-size: 0.95em !important; | |
| margin-bottom: 8px !important; | |
| } | |
| .gr-image { | |
| background-color: var(--primary-bg-color) !important; | |
| border: 2px dashed var(--border-color) !important; | |
| border-radius: 12px !important; | |
| overflow: hidden; | |
| } | |
| .gr-image img { | |
| border-radius: 10px !important; | |
| object-fit: contain; | |
| } | |
| .gr-image .no-image, | |
| .gr-image .upload-button { | |
| color: #AAAAAA !important; | |
| } | |
| .gr-audio > div:first-of-type { | |
| border: 1px solid var(--border-color) !important; | |
| background-color: var(--secondary-bg-color) !important; | |
| border-radius: 10px !important; | |
| padding: 12px !important; | |
| } | |
| .gr-audio audio { | |
| width: 100%; | |
| filter: invert(1) hue-rotate(180deg); | |
| } | |
| .gr-button { | |
| border: none !important; | |
| border-radius: 10px !important; | |
| transition: background-color 0.2s ease, transform 0.1s ease; | |
| font-weight: 600 !important; | |
| } | |
| .gr-button-primary { | |
| background-color: var(--accent-color) !important; | |
| color: var(--button-text-color) !important; | |
| font-size: 1.1em !important; | |
| padding: 14px 24px !important; | |
| letter-spacing: 0.5px; | |
| text-transform: uppercase; | |
| box-shadow: 0 0 10px rgba(127, 90, 240, 0.2); | |
| } | |
| .gr-button-primary:hover { | |
| background-color: var(--accent-hover) !important; | |
| transform: scale(1.03); | |
| box-shadow: 0 0 15px 3px rgba(192, 132, 252, 0.6); | |
| } | |
| .gr-button-primary:active { | |
| transform: scale(0.97); | |
| } | |
| h1, h2, h3 { | |
| color: var(--text-color) !important; | |
| font-weight: 600; | |
| letter-spacing: 0.5px; | |
| } | |
| .prose h1 { | |
| text-align: center; | |
| margin-bottom: 25px !important; | |
| font-size: 2em !important; | |
| text-transform: uppercase; | |
| letter-spacing: 1.5px; | |
| } | |
| .prose p { | |
| color: #CCCCCC !important; | |
| font-size: 0.95em; | |
| text-align: center; | |
| } | |
| .results-container { | |
| text-align: center; | |
| padding: 20px; | |
| border: 1px solid var(--accent-color); | |
| border-radius: 16px; | |
| background: linear-gradient(145deg, var(--secondary-bg-color), #2a2a2a); | |
| backdrop-filter: blur(10px); | |
| box-shadow: 0 8px 32px rgba(0, 0, 0, 0.37); | |
| } | |
| .result-category { | |
| color: var(--accent-color) !important; | |
| font-size: 1.6em; | |
| margin-bottom: 5px; | |
| font-weight: 700; | |
| text-transform: uppercase; | |
| letter-spacing: 1px; | |
| } | |
| .result-score { | |
| color: var(--text-color) !important; | |
| font-size: 1.2em; | |
| margin-bottom: 5px; | |
| } | |
| .result-error { | |
| color: var(--error-color) !important; | |
| font-size: 0.9em; | |
| margin-top: 5px; | |
| } | |
| .gradio-container { | |
| max-width: 850px !important; | |
| margin: auto !important; | |
| padding-top: 30px; | |
| } | |
| .gr-row { | |
| gap: 25px !important; | |
| }""" | |
# Gradio UI definition: an image input and an analyze button, plus an HTML
# panel, a read-only textbox and an autoplaying audio player for the results.
# Clicking the button calls analyze_outfit with the image and routes its three
# return values to the HTML panel, the audio player and the textbox, in that
# order.
# NOTE(review): the nesting of the Row/Column contexts cannot be read from
# this flattened listing - confirm the layout against the original file.
| with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_image = gr.Image( | |
| type="pil", | |
| #type='pil', | |
| label="Upload Your Fit", | |
| sources=['upload', 'webcam', 'clipboard'], | |
| height=400, | |
| show_label=False | |
| ) | |
| analyze_button = gr.Button("🔥 Analyze This Drip", variant="primary") | |
| with gr.Column(scale=1): | |
| category_html = gr.HTML() | |
| response_box = gr.Textbox( | |
| lines=2, | |
| label="Verbal Feedback", | |
| interactive=False, | |
| show_label=False, | |
| placeholder="Verbal feedback will show up here." | |
| ) | |
| audio_output = gr.Audio(label="Audio Feedback", autoplay=True) | |
| analyze_button.click( | |
| fn=analyze_outfit, | |
| inputs=[input_image], | |
| outputs=[category_html, audio_output, response_box] | |
| ) | |
if __name__ == "__main__":
    # Warn on stdout when the fashion checkpoint path does not exist on disk;
    # the app is launched either way.
    fashion_weights_present = os.path.exists(YOLO_FASHION_MODEL_PATH)
    if not fashion_weights_present:
        rule = '=' * 20
        print(f"\n{rule} WARNING {rule}")
        print(f"Fashion model file '{YOLO_FASHION_MODEL_PATH}' not found!")
        print("The app will run but fashion item detection will be skipped.")
        print('=' * 50 + "\n")
    # debug stays off for deployment; show_error surfaces exceptions in the UI.
    demo.launch(show_error=True, debug=False)