KenLumod committed on
Commit
f17df59
·
verified ·
1 Parent(s): 02b4b9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -71
app.py CHANGED
@@ -1,77 +1,44 @@
 
 
1
  import torch
2
- import os
3
- from transformers import BertConfig, AutoModelForSequenceClassification, BertTokenizerFast
4
 
5
- def save_model(model, tokenizer, out_dir: str):
6
- """Saves model in Hugging Face-compatible format"""
7
- os.makedirs(out_dir, exist_ok=True)
8
-
9
- # 1. Save full model architecture
10
- model.bert.save_pretrained(out_dir)
11
-
12
- # 2. Save custom classifier weights with standard name
13
- torch.save(model.state_dict(), os.path.join(out_dir, "model_state.bin"))
14
-
15
- # 3. Create compatible config
16
- config = BertConfig.from_pretrained("bert-base-uncased")
17
- config.update({
18
- "num_labels": 2,
19
- "id2label": {0: "REAL", 1: "FAKE"},
20
- "hidden_dropout_prob": 0.1,
21
- "classifier_dropout": 0.1,
22
- "model_type": "bert-for-sequence-classification"
23
- })
24
- config.save_pretrained(out_dir)
25
-
26
- # 4. Save tokenizer
27
- tokenizer.save_pretrained(out_dir)
28
- print(f"✅ Model saved in HF format to {out_dir}")
29
 
30
- def load_model(model_dir: str, device=torch.device("cpu")):
31
- """Loads model using AutoModelForSequenceClassification"""
32
- from transformers import AutoModelForSequenceClassification
33
-
34
- # Load with custom config
35
- model = AutoModelForSequenceClassification.from_pretrained(
36
- model_dir,
37
- config=BertConfig.from_pretrained(model_dir)
38
- )
39
-
40
- # Load custom weights
41
- state_dict = torch.load(os.path.join(model_dir, "model_state.bin"))
42
- model.load_state_dict(state_dict)
43
-
44
- # Load tokenizer
45
- tokenizer = BertTokenizerFast.from_pretrained(model_dir)
46
-
47
- model.to(device).eval()
48
- print(f"✅ Model loaded from {model_dir}")
49
- return model, tokenizer
50
 
51
- # ------------------------
52
- # 7) Updated Inference
53
- # ------------------------
54
- def predict(texts, model, tokenizer, device, max_length=15):
55
- """Batch prediction with proper HF formatting"""
56
- encodings = tokenizer(
57
- texts,
58
- max_length=max_length,
59
- truncation=True,
60
- padding="max_length",
61
- return_tensors="pt"
62
- ).to(device)
63
-
64
- with torch.no_grad():
65
- outputs = model(**encodings)
66
-
67
- return torch.argmax(outputs.logits, dim=1).cpu().numpy()
68
 
69
- # Usage
70
- best_model, best_tokenizer = load_model("KenLumod/ML-Fake-Real-News-Detector-Final")
 
 
 
 
 
71
 
72
- new_texts = [
73
- "NASA announces alien life discovery",
74
- "Congress passes new infrastructure bill"
75
- ]
76
- predictions = predict(new_texts, best_model, best_tokenizer, device)
77
- print("Predictions:", ["FAKE" if p == 1 else "REAL" for p in predictions])
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoConfig, AutoModel
3
  import torch
 
 
4
 
5
+ model_id = "KenLumod/ML-Fake-Real-News-Detector-Final"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # Force reload model with updated config
8
+ config = AutoConfig.from_pretrained(model_id)
9
+ config.id2label = {1: "Fake News", 0: "Real News"} # Force override
10
+ config.label2id = {v: k for k, v in config.id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Load the model using AutoModel (ensure safetensors format is handled)
13
+ model = AutoModel.from_pretrained(
14
+ model_id,
15
+ config=config,
16
+ # Specify safe_tensors=True if using safetensors
17
+ safetensors=True
18
+ )
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Create the pipeline for classification
21
+ classifier = pipeline(
22
+ "text-classification",
23
+ model=model,
24
+ tokenizer=model_id,
25
+ return_all_scores=False
26
+ )
27
 
28
+ def classify_news(text):
29
+ result = classifier(text)[0]
30
+ return result['label']
31
+
32
+ demo = gr.Interface(
33
+ fn=classify_news,
34
+ inputs=gr.Textbox(lines=6, placeholder="Enter news article here..."),
35
+ outputs="text",
36
+ title="Fake News Detector",
37
+ description="Classifies news articles as Fake or Real",
38
+ examples=[
39
+ ["Breaking: Scientists discover chocolate prevents aging!"],
40
+ ["Parliament passes new climate change legislation"]
41
+ ]
42
+ )
43
+
44
+ demo.launch()