# CatMood / app.py
# catninja123's picture
# Upload app.py with huggingface_hub
# b9a3426 verified
"""
CatMood - 猫咪情绪识别与多模态健康数据众筹平台
Phase 1: 面部情绪识别(立即可用)
Phase 2-4: 体态数据、行为日志、叫声录音(众筹解锁)
"""
import gradio as gr
import torch
import json
import os
import csv
import uuid
import shutil
from datetime import datetime
from pathlib import Path
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# ============================================================
# Configuration
# ============================================================
# Identifier handed to `from_pretrained` when the classifier is loaded.
MODEL_ID = "semihdervis/cat-emotion-classifier"

# Root folder for collected data, plus one sub-folder per data category.
DATA_DIR = Path("collected_data")
DATA_DIR.mkdir(exist_ok=True)
for _subdir in ("photos", "audio", "body_metrics", "behavior_logs"):
    (DATA_DIR / _subdir).mkdir(exist_ok=True)

# Crowdsourcing thresholds: number of records per category shown as the
# 100% mark on the progress bars.
THRESHOLDS = dict(
    photos=5000,
    body_metrics=1000,
    behavior_logs=500,
    audio=2000,
    multimodal=300,
)

# Emotion labels: (English label, Chinese label, emoji). Both lookup tables
# below are derived from this single table.
_EMOTION_TABLE = (
    ("Angry", "愤怒", "😾"),
    ("Disgusted", "厌恶", "🙀"),
    ("Happy", "开心", "😸"),
    ("Normal", "平静", "😺"),
    ("Sad", "悲伤", "😿"),
    ("Scared", "害怕", "🙀"),
    ("Surprised", "惊讶", "😲"),
)
# English label -> Chinese label.
EMOTION_CN = {label: chinese for label, chinese, _ in _EMOTION_TABLE}
# English label -> emoji.
EMOTION_EMOJI = {label: emoji for label, _, emoji in _EMOTION_TABLE}
# ============================================================
# Model loading
# ============================================================
print("Loading CatMood model...")
# nn.Module.eval() returns the module itself, so loading the weights and
# switching to inference mode can be chained.
model = ViTForImageClassification.from_pretrained(MODEL_ID).eval()
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
_num_emotions = model.config.num_labels
print(f"Model loaded: {MODEL_ID} ({_num_emotions} emotions)")
# ============================================================
# Data counting
# ============================================================
def count_data(category, data_dir=None):
    """Count the records logged so far for one data category.

    Reads ``<data_dir>/<category>_log.csv`` and returns the number of CSV
    records after the header row.

    Args:
        category: Category name used as the log-file prefix, e.g. "photos".
        data_dir: Optional folder containing the log files. Defaults to the
            module-level ``DATA_DIR``.

    Returns:
        The number of data records, or 0 when the log file does not exist or
        holds only a header.
    """
    base_dir = DATA_DIR if data_dir is None else Path(data_dir)
    csv_path = base_dir / f"{category}_log.csv"
    try:
        # Parse with the csv module instead of counting physical lines: a
        # quoted field may contain newlines and must still count as ONE
        # record. errors="replace" keeps counting robust even if a log was
        # written with a different text encoding.
        with open(csv_path, "r", newline="", encoding="utf-8", errors="replace") as f:
            n_rows = sum(1 for row in csv.reader(f) if row)
    except FileNotFoundError:
        # No submissions yet for this category.
        return 0
    # The first row is the header.
    return max(0, n_rows - 1)
def get_progress_html():
    """Render the crowdsourcing progress panel as one HTML string.

    For every tracked category the current record count (``count_data``) is
    compared with its entry in ``THRESHOLDS`` and drawn as a labelled
    progress bar. Multi-line HTML literals are kept flush-left so the emitted
    markup carries no source indentation.

    Returns:
        The HTML markup for the whole panel.
    """
    # (category key, display name, description of what the data unlocks)
    tracked = (
        ("photos", "面部照片", "用于训练增强版情绪识别模型(7类 → 更精准)"),
        ("body_metrics", "体态数据", "解锁 FBMI 体脂计算 + 健康风险评估"),
        ("behavior_logs", "行为日志", "解锁行为异常预警功能"),
        ("audio", "叫声录音", "解锁叫声情绪分析"),
        ("multimodal", "多模态对齐数据", "解锁三合一综合健康评估"),
    )
    parts = [
        '<div style="padding: 20px; font-family: system-ui, -apple-system, sans-serif;">',
        '<h2 style="text-align: center; color: #1a1a2e; margin-bottom: 8px;">Data Crowdsourcing Progress</h2>',
        '<p style="text-align: center; color: #666; font-size: 14px; margin-bottom: 24px;">每一份数据都在推动猫咪健康 AI 的进步</p>',
    ]
    for key, name, desc in tracked:
        current = count_data(key)
        threshold = THRESHOLDS[key]
        # Percentage of the threshold reached, capped at 100.
        pct = min(100, (current / threshold) * 100)
        if pct >= 100:
            status, bar_color = "UNLOCKED", "#10b981"
        else:
            status = f"{pct:.0f}%"
            bar_color = "#f59e0b" if pct >= 50 else "#6366f1"
        # The status text always uses the same colour as the bar.
        status_color = bar_color
        parts.append(f'''
<div style="margin-bottom: 20px; background: #f8fafc; border-radius: 12px; padding: 16px;">
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 6px;">
<span style="font-weight: 600; color: #1a1a2e; font-size: 15px;">{name}</span>
<span style="font-weight: 700; color: {status_color}; font-size: 14px;">{status}</span>
</div>
<div style="font-size: 12px; color: #888; margin-bottom: 8px;">{desc}</div>
<div style="background: #e2e8f0; border-radius: 999px; height: 10px; overflow: hidden;">
<div style="background: {bar_color}; height: 100%; width: {pct}%; border-radius: 999px; transition: width 0.5s;"></div>
</div>
<div style="font-size: 12px; color: #aaa; margin-top: 4px; text-align: right;">{current} / {threshold}</div>
</div>
''')
    parts.append('</div>')
    return "".join(parts)
# ============================================================
# Core feature: emotion recognition
# ============================================================
def predict_emotion(image):
    """Classify the emotion shown in a cat face photo.

    Args:
        image: The uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(summary, results)`` pair: ``summary`` is a Markdown string with
        the top emotion, its confidence and a short piece of advice;
        ``results`` maps each display label to its probability. When no image
        is given, a prompt message and None are returned.
    """
    if image is None:
        return "Please upload a cat face photo.", None

    def display_name(label):
        # "<emoji> <Chinese label> (<English label>)"; labels without a
        # mapping fall back to the raw label and an empty emoji.
        return f"{EMOTION_EMOJI.get(label, '')} {EMOTION_CN.get(label, label)} ({label})"

    batch = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits
    # Softmax over the class axis; [0] selects the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    id2label = model.config.id2label
    results = {
        display_name(id2label[idx]): float(prob)
        for idx, prob in enumerate(probs)
    }

    top_idx = torch.argmax(probs).item()
    top_label = id2label[top_idx]
    top_prob = float(probs[top_idx])

    advice = {
        "Happy": "> Your cat looks happy and content! Great job as a cat parent.",
        "Normal": "> Your cat appears calm and relaxed. Everything seems fine.",
        "Sad": "> Your cat might be feeling down. Consider some extra playtime or cuddles.",
        "Angry": "> Your cat seems irritated. Give them some space and check for stressors.",
        "Scared": "> Your cat appears frightened. Try to identify and remove the source of fear.",
        "Surprised": "> Your cat looks surprised! Something caught their attention.",
        "Disgusted": "> Your cat seems displeased. Check their food, litter, or environment.",
    }
    summary = "".join([
        f"## {display_name(top_label)}\n\n",
        f"**Confidence: {top_prob:.1%}**\n\n",
        advice.get(top_label, ""),
    ])
    return summary, results
# ============================================================
# Data collection
# ============================================================
def save_photo_data(image, user_emotion_label, context_note):
    """Store a contributed cat face photo and log it to ``photos_log.csv``.

    Args:
        image: The uploaded photo as a PIL image, or None.
        user_emotion_label: Emotion label chosen by the contributor.
        context_note: Free-text context supplied by the contributor.

    Returns:
        A ``(message, progress_html)`` pair for the UI.
    """
    if image is None:
        return "Please upload a photo first.", get_progress_html()
    if not isinstance(image, Image.Image):
        # Without this guard a CSV row would be logged (and counted) even
        # though no image file was written.
        return "Unsupported image input. Please upload the photo again.", get_progress_html()

    file_id = str(uuid.uuid4())[:8]
    timestamp = datetime.now().isoformat()
    filename = f"{file_id}.jpg"
    img_path = DATA_DIR / "photos" / filename

    # JPEG cannot store alpha or palette modes (e.g. RGBA, P); convert those
    # to RGB instead of failing on save.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")
    image.save(str(img_path), "JPEG", quality=90)

    csv_path = DATA_DIR / "photos_log.csv"
    write_header = not csv_path.exists()
    # Explicit UTF-8: context notes are free text and may be non-ASCII.
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["id", "timestamp", "filename", "user_emotion_label", "context_note"])
        writer.writerow([file_id, timestamp, filename, user_emotion_label, context_note])

    count = count_data("photos")
    return f"Photo saved successfully! (Total: {count} / {THRESHOLDS['photos']})", get_progress_html()
def save_body_metrics(weight, rib_cage, leg_length, age_years, age_months, breed, sex, neutered):
    """Compute the FBMI from body measurements and log the submission.

    Args:
        weight: Body weight in kg.
        rib_cage: Rib cage circumference in cm.
        leg_length: Lower back leg length in cm.
        age_years: Age in whole years (optional).
        age_months: Additional age in months (optional).
        breed: Breed name (optional).
        sex: Sex (optional).
        neutered: Neutered/spayed status (optional).

    Returns:
        A ``(message, fbmi_markdown, progress_html)`` triple for the UI. On
        invalid input the Markdown part is empty and nothing is logged.
    """
    if not weight or not rib_cage or not leg_length:
        return "Please fill in weight, rib cage circumference, and leg length.", "", get_progress_html()
    try:
        w = float(weight)
        rc = float(rib_cage)
        ll = float(leg_length)
    except (TypeError, ValueError):
        return "Please enter valid numbers.", "", get_progress_html()
    # Physical measurements must be positive; this comparison also rejects NaN.
    if not (w > 0 and rc > 0 and ll > 0):
        return "Please enter valid numbers.", "", get_progress_html()

    fbmi = ((rc / 0.7062) - ll) / 0.9156 - ll
    if fbmi < 15:
        fbmi_status = "Underweight (偏瘦)"
        fbmi_advice = "Your cat may be underweight. Consider consulting a vet about nutrition."
    elif fbmi <= 30:
        fbmi_status = "Normal (正常)"
        fbmi_advice = "Your cat's body fat percentage is in the healthy range."
    elif fbmi <= 42:
        fbmi_status = "Overweight (超重)"
        fbmi_advice = "Your cat is overweight. Consider portion control and more playtime."
    else:
        fbmi_status = "Obese (肥胖)"
        fbmi_advice = "Your cat is obese. Please consult a veterinarian for a weight management plan."

    # Blank lines separate the Markdown blocks: a "---" line directly under
    # text would turn that text into a heading instead of drawing a rule.
    fbmi_result = f"""## FBMI Result: {fbmi:.1f}%

**Status: {fbmi_status}**

{fbmi_advice}

---

*FBMI (Feline Body Mass Index) developed by Waltham Pet Nutrition Centre. Validated against DEXA scanning with <10% error (Witzel et al. 2014).*
"""

    file_id = str(uuid.uuid4())[:8]
    timestamp = datetime.now().isoformat()
    # Only a missing answer (None) becomes blank; a legitimate 0, such as
    # age_years == 0 for a kitten, must be stored as 0.
    optional = ["" if value is None else value
                for value in (age_years, age_months, breed, sex, neutered)]

    csv_path = DATA_DIR / "body_metrics_log.csv"
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["id", "timestamp", "weight_kg", "rib_cage_cm", "leg_length_cm",
                             "fbmi", "age_years", "age_months", "breed", "sex", "neutered"])
        writer.writerow([file_id, timestamp, w, rc, ll, f"{fbmi:.1f}", *optional])

    count = count_data("body_metrics")
    msg = f"Body metrics saved! (Total: {count} / {THRESHOLDS['body_metrics']})"
    return msg, fbmi_result, get_progress_html()
def save_behavior_log(respiratory_rate, water_intake, appetite, urination, activity,
                      grooming, vomiting_freq, additional_notes):
    """Log one behavior observation and derive preliminary health alerts.

    Args:
        respiratory_rate: Sleeping respiratory rate in breaths/min (optional).
        water_intake: "Normal", "Increased" or "Decreased" (optional).
        appetite: "Normal", "Increased" or "Decreased" (optional).
        urination: "Normal", "Increased" or "Decreased" (optional).
        activity: "Normal", "Increased" or "Decreased" (optional).
        grooming: "Normal", "Excessive" or "Decreased" (optional).
        vomiting_freq: Vomiting episodes per week (optional).
        additional_notes: Free-text notes (optional).

    Returns:
        A ``(message, alerts_markdown, progress_html)`` triple for the UI. An
        entirely empty form, or a non-numeric rate/frequency, is rejected and
        nothing is logged.
    """
    answers = (respiratory_rate, water_intake, appetite, urination, activity,
               grooming, vomiting_freq, additional_notes)
    # An empty form carries no information: logging it would inflate the
    # record count and report "all indicators normal" with nothing observed.
    if all(value is None or value == "" for value in answers):
        return "Please fill in at least one field before submitting.", "", get_progress_html()

    try:
        rr = None if respiratory_rate in (None, "") else float(respiratory_rate)
        vomit = None if vomiting_freq in (None, "") else float(vomiting_freq)
    except (TypeError, ValueError):
        return "Please enter valid numbers.", "", get_progress_html()

    file_id = str(uuid.uuid4())[:8]
    timestamp = datetime.now().isoformat()

    warnings = []
    if rr is not None and rr > 30:
        warnings.append("Respiratory rate >30/min during sleep may indicate heart disease. (Porciello et al. 2016)")
    if water_intake == "Increased":
        warnings.append("Increased water intake may indicate kidney disease, diabetes, or hyperthyroidism.")
    if appetite == "Increased" and activity == "Decreased":
        warnings.append("Increased appetite with decreased activity may suggest hyperthyroidism.")
    if appetite == "Decreased":
        warnings.append("Decreased appetite may indicate pain, kidney disease, or other conditions.")
    if grooming == "Excessive":
        warnings.append("Excessive grooming may indicate stress or skin conditions.")
    if grooming == "Decreased":
        warnings.append("Decreased grooming may indicate pain or arthritis.")
    if vomit is not None and vomit > 2:
        warnings.append("Frequent vomiting (>2x/week) may indicate gastrointestinal issues.")

    if warnings:
        warning_text = (
            "## Health Alerts\n\n"
            + "".join(f"- {alert}\n" for alert in warnings)
            + "\n> **Disclaimer**: These are preliminary indicators only. Please consult a veterinarian for professional diagnosis.\n"
        )
    else:
        warning_text = "## No Alerts\n\nAll behavioral indicators appear normal. Keep monitoring regularly!"

    # Only a missing answer (None) becomes blank; a legitimate 0, such as
    # zero vomiting episodes per week, must be stored as 0.
    row_values = ["" if value is None else value for value in answers]

    csv_path = DATA_DIR / "behavior_logs_log.csv"
    write_header = not csv_path.exists()
    # Explicit UTF-8: the notes field is free text and may be non-ASCII.
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["id", "timestamp", "respiratory_rate", "water_intake", "appetite",
                             "urination", "activity", "grooming", "vomiting_freq", "notes"])
        writer.writerow([file_id, timestamp, *row_values])

    count = count_data("behavior_logs")
    msg = f"Behavior log saved! (Total: {count} / {THRESHOLDS['behavior_logs']})"
    return msg, warning_text, get_progress_html()
def save_audio_data(audio_filepath, context_label):
    """Copy a contributed recording into the dataset and log it.

    Args:
        audio_filepath: Path of the recorded or uploaded audio file, or None.
        context_label: Context chosen by the contributor; logged as "Unknown"
            when empty.

    Returns:
        A ``(message, progress_html)`` pair for the UI.
    """
    if audio_filepath is None:
        return "Please record or upload audio first.", get_progress_html()

    source = Path(str(audio_filepath))
    # Keep the real extension rather than renaming every file to ".wav":
    # an uploaded file is not necessarily WAV, and a wrong extension
    # mislabels its contents. Fall back to ".wav" when there is no suffix.
    suffix = source.suffix.lower() or ".wav"

    file_id = str(uuid.uuid4())[:8]
    timestamp = datetime.now().isoformat()
    filename = f"{file_id}{suffix}"
    try:
        shutil.copy2(str(source), str(DATA_DIR / "audio" / filename))
    except OSError:
        # Source file missing or unreadable: report it instead of raising out
        # of the UI handler, and do not log a row for a file that was not kept.
        return "Could not read the audio file. Please record or upload it again.", get_progress_html()

    csv_path = DATA_DIR / "audio_log.csv"
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["id", "timestamp", "filename", "context_label"])
        writer.writerow([file_id, timestamp, filename, context_label or "Unknown"])

    count = count_data("audio")
    return f"Audio saved! (Total: {count} / {THRESHOLDS['audio']})", get_progress_html()
# ============================================================
# Gradio interface
# ============================================================
# Stylesheet passed to gr.Blocks(css=...). It styles the two classes used by
# the page header: .main-title (gradient-filled title text) and .subtitle.
CUSTOM_CSS = """
.main-title {
text-align: center;
font-size: 2.2em;
font-weight: 800;
background: linear-gradient(135deg, #6366f1, #8b5cf6, #a855f7);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0;
}
.subtitle {
text-align: center;
color: #666;
font-size: 1.1em;
margin-top: 4px;
}
"""
# Application layout: a page header, six tabs (emotion recognition, four
# data-contribution forms, community progress) and a footer.
#
# The progress panels receive the function `get_progress_html` itself as their
# initial value rather than its result. A called value is computed once when
# the app starts, so every later page load would show stale counts; a callable
# value is re-evaluated by Gradio each time the page loads.
#
# Multi-line Markdown/HTML literals are kept flush-left so that no source
# indentation leaks into the rendered text.
with gr.Blocks(css=CUSTOM_CSS, title="CatMood - AI Cat Emotion & Health", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
<div style="text-align: center; padding: 20px 0 10px 0;">
<h1 class="main-title">CatMood</h1>
<p class="subtitle">AI-Powered Cat Emotion Recognition & Health Assessment</p>
<p style="color: #999; font-size: 13px;">Upload a cat face photo to detect emotion | Contribute data to unlock more features</p>
</div>
""")
    with gr.Tabs():
        # ============ Tab 1: Emotion recognition ============
        with gr.Tab("Emotion Recognition", id="emotion"):
            gr.Markdown("""
### How it works
Upload a clear photo of your cat's face. Our ViT-based model will analyze facial features
and predict the emotional state. The model recognizes 7 emotions: Happy, Normal, Sad, Angry,
Scared, Surprised, and Disgusted.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Upload Cat Face Photo", height=350)
                    predict_btn = gr.Button("Analyze Emotion", variant="primary", size="lg")
                with gr.Column(scale=1):
                    emotion_summary = gr.Markdown(label="Result")
                    emotion_probs = gr.Label(label="Emotion Probabilities", num_top_classes=7)
            predict_btn.click(
                fn=predict_emotion,
                inputs=[input_image],
                outputs=[emotion_summary, emotion_probs]
            )
        # ============ Tab 2: Data contribution - photos ============
        with gr.Tab("Contribute Photos", id="contribute_photos"):
            gr.Markdown("""
### Contribute Cat Face Photos
Help us build a better emotion recognition model! Upload photos of your cat with an emotion label.
Your data will be used to train improved models that benefit all cat owners.
**Privacy**: All data is anonymized. No personal information is collected.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    contrib_image = gr.Image(type="pil", label="Cat Face Photo", height=300)
                    contrib_emotion = gr.Dropdown(
                        choices=["Relaxed", "Happy", "Anxious", "Aggressive", "In Pain", "Unknown"],
                        label="What emotion do you think your cat is showing?",
                        value="Unknown"
                    )
                    contrib_context = gr.Textbox(
                        label="Context (optional)",
                        placeholder="e.g., Just finished eating, Playing with toy, At the vet..."
                    )
                    contrib_btn = gr.Button("Submit Photo", variant="primary")
                with gr.Column(scale=1):
                    contrib_msg = gr.Markdown()
                    contrib_progress = gr.HTML(value=get_progress_html)
            contrib_btn.click(
                fn=save_photo_data,
                inputs=[contrib_image, contrib_emotion, contrib_context],
                outputs=[contrib_msg, contrib_progress]
            )
        # ============ Tab 3: Body metrics ============
        with gr.Tab("Body Metrics", id="body_metrics"):
            gr.Markdown("""
### Cat Body Metrics & FBMI Calculator
Measure your cat's body metrics to calculate the **Feline Body Mass Index (FBMI)** -
a scientifically validated method to estimate body fat percentage.
**How to measure:**
1. **Rib cage circumference**: Use a soft tape measure around the chest at the 9th rib (just behind the front legs)
2. **Leg length**: Measure from the middle of the kneecap to the ankle on a back leg
3. **Weight**: Use a home scale (weigh yourself holding the cat, then subtract your weight)
*Based on research by Waltham Pet Nutrition Centre, validated against DEXA scanning (Witzel et al. 2014)*
""")
            with gr.Row():
                with gr.Column(scale=1):
                    bm_weight = gr.Number(label="Weight (kg)", precision=2)
                    bm_rib = gr.Number(label="Rib Cage Circumference (cm)", precision=1)
                    bm_leg = gr.Number(label="Lower Back Leg Length (cm)", precision=1)
                    with gr.Row():
                        bm_age_y = gr.Number(label="Age (years)", precision=0)
                        bm_age_m = gr.Number(label="Age (months)", precision=0)
                    bm_breed = gr.Dropdown(
                        choices=["Unknown/Mixed", "Persian", "Siamese", "Maine Coon", "British Shorthair",
                                 "Ragdoll", "Bengal", "Abyssinian", "Scottish Fold", "Sphynx",
                                 "Russian Blue", "Norwegian Forest", "Birman", "Oriental Shorthair", "Other"],
                        label="Breed", value="Unknown/Mixed"
                    )
                    with gr.Row():
                        bm_sex = gr.Radio(choices=["Male", "Female"], label="Sex")
                        bm_neutered = gr.Radio(choices=["Yes", "No", "Unknown"], label="Neutered/Spayed")
                    bm_btn = gr.Button("Calculate FBMI & Submit", variant="primary")
                with gr.Column(scale=1):
                    bm_msg = gr.Markdown()
                    bm_result = gr.Markdown()
                    bm_progress = gr.HTML(value=get_progress_html)
            bm_btn.click(
                fn=save_body_metrics,
                inputs=[bm_weight, bm_rib, bm_leg, bm_age_y, bm_age_m, bm_breed, bm_sex, bm_neutered],
                outputs=[bm_msg, bm_result, bm_progress]
            )
        # ============ Tab 4: Behavior log ============
        with gr.Tab("Behavior Log", id="behavior"):
            gr.Markdown("""
### Daily Behavior Observation Log
Record your cat's daily behaviors to help detect early signs of health issues.
Each indicator below has been scientifically linked to specific diseases.
**Tip**: The sleeping respiratory rate is the single most sensitive predictor of heart failure in cats
(Porciello et al. 2016). Count breaths for 15 seconds while your cat sleeps, then multiply by 4.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    bl_rr = gr.Number(label="Sleeping Respiratory Rate (breaths/min)",
                                      info="Count for 15 sec x 4. Normal: 15-30/min. >30 = alert",
                                      precision=0)
                    bl_water = gr.Radio(
                        choices=["Normal", "Increased", "Decreased"],
                        label="Water Intake Change",
                        info="Increased may indicate kidney disease, diabetes, or hyperthyroidism"
                    )
                    bl_appetite = gr.Radio(
                        choices=["Normal", "Increased", "Decreased"],
                        label="Appetite Change",
                        info="Increased + weight loss may suggest hyperthyroidism"
                    )
                    bl_urination = gr.Radio(
                        choices=["Normal", "Increased", "Decreased"],
                        label="Urination Frequency Change",
                        info="Increased may indicate kidney disease or diabetes"
                    )
                    bl_activity = gr.Radio(
                        choices=["Normal", "Increased", "Decreased"],
                        label="Activity Level Change",
                        info="Decreased may indicate pain/arthritis"
                    )
                    bl_grooming = gr.Radio(
                        choices=["Normal", "Excessive", "Decreased"],
                        label="Grooming Behavior",
                        info="Excessive may indicate stress/skin disease"
                    )
                    bl_vomit = gr.Number(label="Vomiting Frequency (times/week)", precision=0,
                                         info=">2 times/week may indicate GI issues")
                    bl_notes = gr.Textbox(label="Additional Notes (optional)",
                                          placeholder="Any other observations...")
                    bl_btn = gr.Button("Submit Behavior Log", variant="primary")
                with gr.Column(scale=1):
                    bl_msg = gr.Markdown()
                    bl_warnings = gr.Markdown()
                    bl_progress = gr.HTML(value=get_progress_html)
            bl_btn.click(
                fn=save_behavior_log,
                inputs=[bl_rr, bl_water, bl_appetite, bl_urination, bl_activity,
                        bl_grooming, bl_vomit, bl_notes],
                outputs=[bl_msg, bl_warnings, bl_progress]
            )
        # ============ Tab 5: Voice recording ============
        with gr.Tab("Voice Recording", id="voice"):
            gr.Markdown("""
### Cat Vocalization Recording
Record your cat's meows to help build a vocalization analysis model.
**Tips for good recordings:**
- Record in a quiet environment
- Get close to your cat (within 1 meter)
- Record for 10-30 seconds
- Note what your cat was doing/wanting at the time
*Research shows cat vocalizations carry information about emotional state and needs
(Scientific American, 2025)*
""")
            with gr.Row():
                with gr.Column(scale=1):
                    audio_input = gr.Audio(label="Record or Upload Cat Meow", type="filepath")
                    audio_context = gr.Dropdown(
                        choices=["Hungry", "Seeking Attention", "In Pain/Discomfort",
                                 "Content/Purring", "Greeting", "Angry/Hissing", "Unknown"],
                        label="What was the context?",
                        value="Unknown"
                    )
                    audio_btn = gr.Button("Submit Recording", variant="primary")
                with gr.Column(scale=1):
                    audio_msg = gr.Markdown()
                    audio_progress = gr.HTML(value=get_progress_html)
            audio_btn.click(
                fn=save_audio_data,
                inputs=[audio_input, audio_context],
                outputs=[audio_msg, audio_progress]
            )
        # ============ Tab 6: Crowdsourcing progress ============
        with gr.Tab("Community Progress", id="progress"):
            gr.Markdown("""
### Community Data Crowdsourcing Progress
Every piece of data you contribute helps us build better AI models for cat health assessment.
When enough data is collected, new features will be automatically unlocked for all users!
**How it works:**
1. **You contribute data** - photos, body metrics, behavior logs, or voice recordings
2. **We reach the threshold** - automatic model training is triggered
3. **New features unlock** - everyone benefits from improved AI
This is a community effort. Together, we're building the future of AI-powered pet healthcare.
""")
            progress_display = gr.HTML(value=get_progress_html)
            refresh_btn = gr.Button("Refresh Progress", variant="secondary")
            refresh_btn.click(fn=get_progress_html, outputs=[progress_display])
            gr.Markdown("""
---
### Scientific Foundation
Our data collection is guided by peer-reviewed veterinary research:
| Data Type | Scientific Basis | Key Reference |
|-----------|-----------------|---------------|
| Facial Photos | Feline Grimace Scale (5 facial action units) | Evangelista et al. 2019, *Scientific Reports* (cited 330x) |
| Body Metrics | FBMI body fat estimation | Witzel et al. 2014, *JAVMA* (cited 64x) |
| Body Condition | BCS-disease association (14 conditions) | Teng et al. 2018, *JSAP* (cited 92x) |
| Respiratory Rate | Heart failure prediction | Porciello et al. 2016, *Vet Journal* (cited 72x) |
| Behavior Changes | Disease early warning | Stelow 2020, *Vet Clinics* |
| Vocalizations | Emotion & need classification | CatMeows dataset; *Scientific American* 2025 |
---
*CatMood is an open research project. All models and anonymized datasets will be released
on Hugging Face for the benefit of the veterinary AI community.*
""")
    gr.Markdown("""
---
<div style="text-align: center; color: #999; font-size: 12px; padding: 10px;">
CatMood v0.1 | Powered by ViT + Hugging Face |
<a href="https://huggingface.co/semihdervis/cat-emotion-classifier">Model</a> |
<a href="https://www.felinegrimacescale.com/">Feline Grimace Scale</a>
<br>
<strong>Disclaimer</strong>: This tool is for educational and research purposes only.
It is not a substitute for professional veterinary advice.
</div>
""")
# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, without a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)