TangYiJay committed on
Commit 366963e · verified · 1 Parent(s): 04ec2fd
Files changed (1)
  1. app.py +47 -24
app.py CHANGED
@@ -1,40 +1,63 @@
- from transformers import AutoProcessor, AutoModelForVision2Seq
  from PIL import Image
  import gradio as gr
  import torch

- MODEL_ID = "HuggingFaceM4/idefics2-8b"

- # Force CPU mode
- device = "cpu"

- # Load the model and processor (disable float16 to avoid CPU errors)
- processor = AutoProcessor.from_pretrained(MODEL_ID)
- model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.float32, device_map=None)
- model.to(device)

- def analyze_images(base_img, target_img, user_prompt):
      if base_img is None or target_img is None:
-         return "Please upload both a base image and a target image."
-
-     images = [base_img, target_img]
-     prompt = f"Ignore the first image (base image). Analyze the second image: {user_prompt}"
-
-     inputs = processor(images=images, text=prompt, return_tensors="pt").to(device)
-     output = model.generate(**inputs, max_new_tokens=200)
-     result = processor.decode(output[0], skip_special_tokens=True)
-     return result

  demo = gr.Interface(
-     fn=analyze_images,
      inputs=[
          gr.Image(type="pil", label="Base Image"),
-         gr.Image(type="pil", label="Target Image"),
-         gr.Textbox(label="Prompt", placeholder="Describe what to analyze...")
      ],
-     outputs=gr.Textbox(label="Model Output"),
-     title="Image Comparison (IDEFICS2-8B, CPU Mode)",
-     description="Upload two images. The model will ignore the base image and analyze the second according to your prompt."
  )

  if __name__ == "__main__":
+ from transformers import CLIPProcessor, CLIPModel
  from PIL import Image
  import gradio as gr
  import torch
+ import numpy as np

+ MODEL_ID = "openai/clip-vit-base-patch32"

+ # Load model & processor
+ model = CLIPModel.from_pretrained(MODEL_ID)
+ processor = CLIPProcessor.from_pretrained(MODEL_ID)

+ # Candidate material labels
+ LABELS = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]

+ def get_image_embedding(image):
+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         embedding = model.get_image_features(**inputs)
+     embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
+     return embedding.cpu().numpy()
+
+ def classify_material(base_img, target_img):
      if base_img is None or target_img is None:
+         return "Please upload both base and target images."
+
+     # Compute embeddings
+     base_emb = get_image_embedding(base_img)
+     target_emb = get_image_embedding(target_img)
+
+     # Difference score
+     diff = np.linalg.norm(target_emb - base_emb)
+
+     # Text embeddings for all labels
+     text_inputs = processor(text=LABELS, return_tensors="pt", padding=True)
+     with torch.no_grad():
+         text_emb = model.get_text_features(**text_inputs)
+     text_emb = text_emb / text_emb.norm(p=2, dim=-1, keepdim=True)
+
+     # Compute similarity with target image
+     img_inputs = processor(images=target_img, return_tensors="pt")
+     with torch.no_grad():
+         img_feat = model.get_image_features(**img_inputs)
+     img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
+
+     sims = torch.matmul(img_feat, text_emb.T).squeeze(0)
+     best_idx = torch.argmax(sims).item()
+     best_label = LABELS[best_idx]
+
+     return f"Detected material: {best_label}\nDifference from base: {diff:.4f}"

  demo = gr.Interface(
+     fn=classify_material,
      inputs=[
          gr.Image(type="pil", label="Base Image"),
+         gr.Image(type="pil", label="Target Image")
      ],
+     outputs=gr.Textbox(label="Detection Result"),
+     title="Material Classification (CLIP, CPU Mode)",
+     description="Upload a base image (background) and a target image (with object). The model detects what new material appears: plastic, metal, paper, cardboard, glass, or trash."
  )

  if __name__ == "__main__":
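
The added classify_material follows CLIP's standard zero-shot matching pattern: L2-normalize the image and text embeddings, take their dot products, and pick the highest-scoring label. A minimal sketch of the same pattern using the model's combined forward pass is below; the image path is a placeholder for illustration and is not part of the commit.

from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import torch

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
labels = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]

image = Image.open("target.jpg")  # placeholder path; any RGB image works
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)
# logits_per_image holds image-text similarity scores; softmax turns them into label probabilities
probs = outputs.logits_per_image.softmax(dim=-1).squeeze(0)
print({label: round(p.item(), 3) for label, p in zip(labels, probs)})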