MaryahGreene commited on
Commit
b3a69c7
·
verified ·
1 Parent(s): 0b3d339

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -8,13 +8,15 @@ tokenizer = AutoTokenizer.from_pretrained("MaryahGreene/arch_flava_mod", trust_r
8
  id2label = model.config.id2label # make sure this is set during training!
9
 
10
  def predict(text):
11
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
12
- with torch.no_grad():
13
  outputs = model(**inputs)
14
- probs = torch.nn.functional.softmax(outputs.logits, dim=1)
15
- top_label = torch.argmax(probs, dim=1).item()
16
- label_name = id2label[str(top_label)]
17
- confidence = probs[0][top_label].item()
18
- return f"Prediction: {label_name} ({confidence:.2%} confidence)"
 
 
19
 
20
  gr.Interface(fn=predict, inputs="text", outputs="text", title="ArchFlava Predictor").launch()
 
8
  id2label = model.config.id2label # make sure this is set during training!
9
 
10
  def predict(text):
11
+ try:
12
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
13
  outputs = model(**inputs)
14
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
15
+ top_label = torch.argmax(probs, dim=1).item()
16
+ label_name = id2label[str(top_label)]
17
+ confidence = probs[0][top_label].item()
18
+ return f"Prediction: {label_name} ({confidence:.2%} confidence)"
19
+ except Exception as e:
20
+ return f"❌ Error: {str(e)}"
21
 
22
  gr.Interface(fn=predict, inputs="text", outputs="text", title="ArchFlava Predictor").launch()