Fluospark128 committed on
Commit
b1f9cb7
·
verified ·
1 Parent(s): 206d25e

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +176 -0
README.md ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from transformers import GPT2Tokenizer, GPT2Model
4
+ from sklearn.preprocessing import MultiLabelBinarizer
5
+ from torch import nn
6
+ import torch
7
+ import openai
8
+ from collections import Counter
9
+ import nltk
10
+ from nltk.corpus import stopwords
11
+ from nltk.tokenize import word_tokenize
12
+
13
class GenreClassifier(nn.Module):
    """Multi-label genre classifier: a GPT-2 encoder with a linear head.

    Produces independent per-genre probabilities via a sigmoid, so several
    genres may score high simultaneously.
    """

    def __init__(self, num_genres=20):
        super().__init__()
        # GPT-2 base emits 768-dimensional hidden states per token.
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.dropout = nn.Dropout(0.1)
        self.genre_classifier = nn.Linear(768, num_genres)  # 768 is GPT2's hidden size
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return per-genre probabilities, shape (batch, num_genres)."""
        hidden_states = self.gpt2(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]
        # Mean-pool over the sequence axis, then regularize and classify.
        pooled = self.dropout(hidden_states.mean(dim=1))
        return self.sigmoid(self.genre_classifier(pooled))
27
+
28
class BookGenreAnalyzer:
    """Predict book genres by combining a local GPT-2 classifier with
    GPT-3 completions from the OpenAI API.

    NOTE: these methods perform network I/O at runtime (model download,
    NLTK data download, OpenAI calls) and require valid credentials.
    """

    def __init__(self, api_key):
        """Initialize the analyzer with OpenAI API key.

        Args:
            api_key: OpenAI API key used for completion/fine-tune calls.
        """
        # BUG FIX: the original did `self.openai.api_key = api_key`, which
        # raises AttributeError (the instance has no `openai` attribute).
        # Configure the module-level openai client instead.
        openai.api_key = api_key
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GenreClassifier()
        self.genre_labels = self._load_genre_labels()
        nltk.download('punkt')
        nltk.download('stopwords')
        self.stop_words = set(stopwords.words('english'))

    def _load_genre_labels(self):
        """Return the fixed list of 20 candidate genre labels.

        You would typically load these from a file or database.
        """
        return [
            "Fiction", "Non-fiction", "Mystery", "Romance", "Science Fiction",
            "Fantasy", "Thriller", "Horror", "Historical Fiction", "Biography",
            "Self-help", "Business", "Science", "Philosophy", "Poetry",
            "Drama", "Adventure", "Literary Fiction", "Young Adult", "Children's"
        ]

    def preprocess_text(self, text):
        """Lowercase, drop stop words, and encode *text* for GPT-2.

        Returns:
            A tokenizer encoding with `input_ids` and `attention_mask`
            tensors, padded/truncated to 1024 tokens.
        """
        # Tokenize and remove stop words before re-encoding with GPT-2's BPE.
        tokens = word_tokenize(text.lower())
        tokens = [t for t in tokens if t not in self.stop_words]

        encodings = self.tokenizer(
            ' '.join(tokens),
            truncation=True,
            max_length=1024,
            padding='max_length',
            return_tensors='pt'
        )
        return encodings

    def extract_features(self, text):
        """Run the local classifier on *text*; return genre probabilities."""
        encodings = self.preprocess_text(text)
        with torch.no_grad():  # inference only — no gradients needed
            features = self.model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )
        return features

    def fine_tune_with_gpt3(self, training_data):
        """Create an OpenAI fine-tuning job from (text, genres) pairs.

        Args:
            training_data: iterable of (book_text, list_of_genres) tuples.

        Returns:
            The OpenAI fine-tune response object, or None on failure.
        """
        # Prepare training data in the prompt/completion format expected
        # by the (legacy) OpenAI fine-tuning endpoint.
        formatted_data = []
        for book_text, genres in training_data:
            formatted_data.append({
                "prompt": f"Book text: {book_text[:1000]}...\nGenres:",
                "completion": f" {', '.join(genres)}"
            })

        # Create fine-tuning job
        try:
            response = openai.FineTune.create(
                training_file=self._upload_training_data(formatted_data),
                # NOTE(review): "gpt-3" is not a valid base-model name for the
                # legacy FineTune endpoint — confirm (e.g. "davinci").
                model="gpt-3",
                n_epochs=3,
                batch_size=4,
                learning_rate_multiplier=0.1
            )
            return response
        except Exception as e:
            print(f"Fine-tuning error: {e}")
            return None

    def _upload_training_data(self, formatted_data):
        """Write *formatted_data* as local JSONL and upload it to OpenAI.

        Returns:
            The uploaded file's id, for use as a fine-tune training file.
        """
        import json
        with open('training_data.jsonl', 'w') as f:
            for entry in formatted_data:
                json.dump(entry, f)
                f.write('\n')

        with open('training_data.jsonl', 'rb') as f:
            response = openai.File.create(
                file=f,
                purpose='fine-tune'
            )
        return response.id

    def analyze_book(self, book_text):
        """Analyze a book and return the top 20 (genre, score) pairs.

        Scores come from the local classifier; genres that GPT-3 also
        mentions get a 1.2x boost. Sorted by descending score.
        """
        # Get base predictions from our model.
        features = self.extract_features(book_text)
        predictions = features.numpy()[0]

        # Use GPT-3 to enhance predictions.
        try:
            response = openai.Completion.create(
                model="gpt-3",  # Use fine-tuned model ID if available
                prompt=f"Book text: {book_text[:1000]}...\nGenres:",
                max_tokens=100,
                temperature=0.3
            )
            gpt3_genres = response.choices[0].text.strip().split(', ')
        except Exception as e:
            # BUG FIX: was a bare `except:` that silently swallowed
            # everything (even KeyboardInterrupt); narrow it and report.
            print(f"GPT-3 completion error: {e}")
            gpt3_genres = []

        # BUG FIX: the original iterated `(genre, score)` tuples and did
        # `score *= 1.2`, which only rebinds the loop-local variable and
        # never updates the list. Apply the boost while building the list.
        genres_with_scores = [
            (genre, float(score) * (1.2 if genre in gpt3_genres else 1.0))
            for genre, score in zip(self.genre_labels, predictions)
        ]

        # Sort and return top 20.
        return sorted(genres_with_scores, key=lambda x: x[1], reverse=True)[:20]
145
+
146
+ # Example usage
147
# Example usage
def main():
    """Demo driver: classify one sample text, then start a fine-tune job."""
    # Initialize analyzer
    analyzer = BookGenreAnalyzer('your-api-key')

    # Example book text
    book_text = """
    [Your book text here]
    """

    # Get genre predictions and report them as percentages.
    predicted = analyzer.analyze_book(book_text)

    print("\nTop 20 Genres:")
    for genre, confidence in predicted:
        print(f"{genre}: {confidence:.2%}")

    # Example of fine-tuning with a couple of labeled books.
    training_data = [
        ("Book 1 text...", ["Mystery", "Thriller"]),
        ("Book 2 text...", ["Science Fiction", "Adventure"]),
    ]

    fine_tune_response = analyzer.fine_tune_with_gpt3(training_data)
    if fine_tune_response:
        print("\nFine-tuning job created successfully!")


if __name__ == "__main__":
    main()