ake178178 committed on
Commit
b614546
·
verified ·
1 Parent(s): 3396eae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -1,23 +1,37 @@
1
- import cv2
2
- from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
 
 
3
 
4
- faceCascade = cv2.CascadeClassifier(cv2.haarcascades+'haarcascade_frontalface_default.xml')
 
 
5
 
 
 
 
6
 
7
- class VideoTransformer(VideoTransformerBase):
8
- def __init__(self):
9
- self.i = 0
 
 
 
 
 
10
 
11
- def transform(self, frame):
12
- img = frame.to_ndarray(format="bgr24")
13
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
14
- faces = faceCascade.detectMultiScale(gray, 1.3, 5)
15
- i =self.i+1
16
- for (x, y, w, h) in faces:
17
- cv2.rectangle(img, (x, y), (x + w, y + h), (95, 207, 30), 3)
18
- cv2.rectangle(img, (x, y - 40), (x + w, y), (95, 207, 30), -1)
19
- cv2.putText(img, 'F-' + str(i), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
20
 
21
- return img
 
 
 
 
 
 
 
22
 
23
- webrtc_streamer(key="example", video_transformer_factory=VideoTransformer)
 
 
1
+ import streamlit as st
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from PIL import Image
4
+ import requests
5
 
6
# Load the model and processor.
#
# Streamlit re-executes this script from the top on every widget interaction
# (file upload, button click). Without caching, both `from_pretrained` calls
# would run again on every rerun, so the load is wrapped in a cached factory.
@st.cache_resource  # NOTE(review): needs Streamlit >= 1.18 - confirm the pinned version
def _load_trocr():
    """Return the ``(processor, model)`` pair for the handwritten TrOCR checkpoint.

    The pair is built on the first call; later script reruns receive the
    cached objects instead of loading the weights again.
    """
    checkpoint = "microsoft/trocr-base-handwritten"
    return (
        TrOCRProcessor.from_pretrained(checkpoint),
        VisionEncoderDecoderModel.from_pretrained(checkpoint),
    )


# Keep the original module-level names so `predict` continues to work unchanged.
processor, model = _load_trocr()
9
 
10
def load_image(image_file):
    """Open *image_file* with Pillow and return the resulting image object.

    Args:
        image_file: Whatever ``Image.open`` accepts; ``main`` passes the
            object returned by ``st.file_uploader``.

    Returns:
        The value returned by ``Image.open(image_file)``.
    """
    return Image.open(image_file)
13
 
14
def predict(image):
    """Run the TrOCR model on *image* and return the decoded text.

    Args:
        image: The picture to read. PIL images are normalised to 3-channel
            RGB first; any other input the processor accepts is passed
            through untouched.

    Returns:
        The decoded string for the first (and only) generated sequence, with
        special tokens stripped.
    """
    if isinstance(image, Image.Image):
        # The uploader accepts PNG, which may be RGBA, palette or grayscale.
        # The TrOCR usage examples feed RGB images, so normalise here instead
        # of letting the processor fail or mis-read the channel layout.
        has_alpha = image.mode in ("RGBA", "LA") or "transparency" in image.info
        if has_alpha:
            # A plain convert("RGB") just drops alpha, which can turn dark ink
            # on a transparent background into dark-on-dark. Composite onto
            # white so the writing stays visible.
            rgba = image.convert("RGBA")
            flattened = Image.new("RGB", rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.split()[-1])
            image = flattened
        elif image.mode != "RGB":
            image = image.convert("RGB")

    # Preprocess the image into the tensor the model consumes.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Generate the predicted token ids.
    output_ids = model.generate(pixel_values)
    # Convert the first generated sequence into text.
    return processor.decode(output_ids[0], skip_special_tokens=True)
22
 
23
def main():
    """Render the page: upload an image, preview it, and recognise it on demand.

    The flow is: show the title and uploader; once a file is present, load
    and display it; when the button is pressed, write the output of
    ``predict`` to the page.
    """
    # Title text reads "Image object recognition".
    # NOTE(review): the loaded checkpoint is a handwritten-text model, so the
    # title may be misleading - confirm the intended wording.
    st.title("图片物体识别")
    # Uploader label: "Upload an image".
    uploaded = st.file_uploader("上传一张图片", type=["jpg", "png", "jpeg"])
    if uploaded is None:
        return

    # Show the uploaded picture (caption: "Uploaded image").
    picture = load_image(uploaded)
    st.image(picture, caption="上传的图片", use_column_width=True)

    # Only run the model after the "Recognise image" button is pressed.
    if not st.button("识别图片"):
        return
    # Output prefix: "Recognition result: ".
    st.write(f"识别结果: {predict(picture)}")


if __name__ == "__main__":
    main()