Pafkun333 commited on
Commit
aeaf3f3
·
1 Parent(s): 3acf7f6

Commiting first one

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model_2.pth filter=lfs diff=lfs merge=lfs -text
37
+ examples/bleyla_new.jpg filter=lfs diff=lfs merge=lfs -text
38
+ examples/byjd_new.jpg filter=lfs diff=lfs merge=lfs -text
39
+ examples/falafelcho.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from gtts import gTTS
6
+ import os
7
+ import uuid
8
+ import random
9
+ import time
10
+
11
+ from model import load_face_classifier_model # Import the model loading function
12
+
13
+ # Define the same validation transform used during training
14
+ val_transform = transforms.Compose([
15
+ transforms.Resize(256),
16
+ transforms.CenterCrop(224),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ # Load the model using the function from model.py
22
+ model = load_face_classifier_model(model_path='model_2.pth', num_classes=5)
23
+
24
+
25
+ def cleanup_audio_files(directory=".", prefix="prediction_", max_age_seconds=30):
26
+ now = time.time()
27
+ for filename in os.listdir(directory):
28
+ if filename.startswith(prefix) and filename.endswith(".mp3"):
29
+ filepath = os.path.join(directory, filename)
30
+ file_age = now - os.path.getmtime(filepath)
31
+ if file_age > max_age_seconds:
32
+ try:
33
+ os.remove(filepath)
34
+ except Exception as e:
35
+ print(f"Error deleting {filename}: {e}")
36
+
37
+ def classify_face_with_audio_new(image: Image.Image):
38
+ """
39
+ Classifies a single image (captured from camera) using a trained model
40
+ and generates an audio file of the prediction.
41
+
42
+ Args:
43
+ image (PIL.Image.Image): The input image.
44
+
45
+ Returns:
46
+ tuple: A tuple containing the predicted class name (str)
47
+ and the path to the generated audio file (str).
48
+ """
49
+
50
+ byjd_audio = ["Не ме гледай! Дай ми пауч!", "Писи Писи, Мяу Мяу", "просто мяу",
51
+ "мррррррррррр"]
52
+ bleyla_audio = ["Плешкиииииитуууууууууууу", "Дай ми цун!", "Отивам при Вес Божа",
53
+ "А къде е прасетуу ?"]
54
+ jenny_audio = ["Офффф гладна съм!", "Здравейте, аз съм в овулация.", "Да пием кафе на 43.12 и да ядем шницел!",
55
+ "Офф бе Павееел!", "Обичам Дони Донсъна."]
56
+ sachu_audio = ["Мишо, ще ти счупя носа!", "Засъхнало аку на дупи на кучии.", "Чекии ли си правиш бе, педалче малко?",
57
+ "Обичам пръцкото на Сога!"]
58
+ falafel_audio = ["Дааарлинг, къде са ми чорапите?", "Маняк, измий си краката.", "Молим те, изкъпи се!",
59
+ "Обичам пръцкото на Жени!"]
60
+
61
+ if image is None:
62
+ return "Error: Could not capture image from webcam. Please try again.", None
63
+
64
+ # Ensure image is in RGB format and apply transform
65
+ image = image.convert("RGB")
66
+ image = val_transform(image).unsqueeze(0) # Add batch dimension
67
+
68
+ # Move the image to the device (assuming GPU is available)
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ image = image.to(device)
71
+ model.to(device) # Move the model to the device as well
72
+
73
+
74
+ # Perform inference
75
+ with torch.no_grad():
76
+ outputs = model(image)
77
+ # Get the predicted class index
78
+ _, predicted_idx = torch.max(outputs.data, 1)
79
+
80
+ # Get the predicted class name
81
+ class_names = ['bleyla', 'byjd', 'falafel', 'jenny', 'sachu']
82
+ predicted_class = class_names[predicted_idx.item()]
83
+
84
+ # Generate audio
85
+ if predicted_class == "falafel":
86
+ text_to_speak = random.choice(falafel_audio)
87
+ elif predicted_class == "sachu":
88
+ text_to_speak = random.choice(sachu_audio)
89
+ elif predicted_class == "jenny":
90
+ text_to_speak = random.choice(jenny_audio)
91
+ elif predicted_class == "bleyla":
92
+ text_to_speak = random.choice(bleyla_audio)
93
+ elif predicted_class == "byjd":
94
+ text_to_speak = random.choice(byjd_audio)
95
+ else:
96
+ text_to_speak = "Unknown class"
97
+
98
+ tts = gTTS(text=text_to_speak, lang='bg')
99
+ audio_file = f"prediction_{uuid.uuid4()}.mp3"
100
+ tts.save(audio_file)
101
+
102
+ # Ensure file cleanup
103
+ cleanup_audio_files()
104
+
105
+ return predicted_class, audio_file
106
+
107
+ # Create the Gradio interface
108
+ interface = gr.Interface(
109
+ fn=classify_face_with_audio_new,
110
+ inputs=gr.Image(type="pil", label="Upload an image or use your camera"),
111
+ outputs=[
112
+ gr.Textbox(label="Predicted Class"),
113
+ gr.Audio(label="Audio Pronunciation")
114
+ ],
115
+ title="Russian Monument Classifier",
116
+ description="Upload an image or use your camera to classify Russian Monument Citizens.",
117
+ examples=[["examples/bleyla_new.jpg"], ["examples/byjd_new.jpg"], ["examples/falafelcho.jpg"]] # Examples should be a list of lists
118
+ )
119
+
120
+ # Launch the interface
121
+ if __name__ == "__main__":
122
+ interface.launch()
examples/bleyla_new.jpg ADDED

Git LFS Details

  • SHA256: 9f94dddc2c6a3a23892b1b85bf34503f6a3e78213d701cca6475ee0a652e09ed
  • Pointer size: 131 Bytes
  • Size of remote file: 231 kB
examples/byjd_new.jpg ADDED

Git LFS Details

  • SHA256: ee0907b096fcc4abfd6e7deed5f39c74093223330513d9ff0df88f4334439d6e
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
examples/falafelcho.jpg ADDED

Git LFS Details

  • SHA256: 4b811fb513993892f2434af9dd8840b0ce8fe401edda857f1b0e5f1017a2dc19
  • Pointer size: 131 Bytes
  • Size of remote file: 248 kB
model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models import resnet18, ResNet18_Weights
3
+ from torch import nn
4
+
5
+ def load_face_classifier_model(model_path: str = 'model_2.pth', num_classes: int = 5):
6
+ """
7
+ Loads the pre-trained ResNet18 model, modifies the final layer,
8
+ loads the state dictionary, and sets the model to evaluation mode.
9
+
10
+ Args:
11
+ model_path (str): Path to the saved model state dictionary.
12
+ num_classes (int): Number of classes for the final linear layer.
13
+
14
+ Returns:
15
+ torch.nn.Module: The loaded model in evaluation mode.
16
+ """
17
+ # Load the pre-trained ResNet18 model with specified weights
18
+ weights = ResNet18_Weights.IMAGENET1K_V1
19
+ model = resnet18(weights=weights)
20
+
21
+ # Modify the final fully connected layer for the specified number of classes
22
+ num_ftrs = model.fc.in_features
23
+ model.fc = nn.Linear(num_ftrs, num_classes)
24
+
25
+ # Load the saved state dictionary
26
+ state_dict = torch.load(model_path, map_location=torch.device('cpu')) # Load to CPU
27
+
28
+ # Adjust keys to match the model (if necessary, based on how the model was saved)
29
+ # This adjustment is based on the observation from the previous failed attempt.
30
+ new_state_dict = {}
31
+ for k, v in state_dict.items():
32
+ if 'fc.1.' in k:
33
+ new_key = k.replace('fc.1.', 'fc.')
34
+ new_state_dict[new_key] = v
35
+ else:
36
+ new_state_dict[k] = v
37
+
38
+ model.load_state_dict(new_state_dict)
39
+
40
+ # Set the model to evaluation mode
41
+ model.eval()
42
+
43
+ return model
44
+
45
+ if __name__ == '__main__':
46
+ # Example usage (for testing)
47
+ loaded_model = load_face_classifier_model()
48
+ print("Model loaded successfully:")
49
+ print(loaded_model)
model_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84140f841ffe511330eb0a18b96bf665b341f7759176d8a6885787d7aa2e2a1d
3
+ size 44793355
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==3.1.4
2
+ torch==2.8.0
3
+ torchvision==0.23.0
4
+ Pillow
5
+ gtts