File size: 3,499 Bytes
915cbc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi.routing import APIRouter
from pydantic import BaseModel
from Pipeline.HallucinationPipeline import HallucinationPipeline
from Pipeline.CorrectionLLMs import DeepseekAPI
from nltk.tokenize import sent_tokenize
from utils import detectionProcess
import os


# The DeepSeek API key is read from the environment rather than being committed
# to source control. The variable name matches the one the /api/test endpoint
# already inspects. Falls back to an empty string when the variable is unset.
# NOTE(review): the key that used to be embedded on this line has been exposed
# in version history and should be treated as compromised and rotated.
deepseek_apikey=os.getenv("deepseek_apikey","")


# Every route declared in this module is served under the "/api" prefix.
router=APIRouter(prefix="/api")
# Shared detection pipeline, built once at import time.
# NOTE(review): the second argument ("cpu") is presumably the inference device -- confirm.
pipeline=HallucinationPipeline("Razor2507/Roberta-Base-Finetuned","cpu")
# Shared client used by the /correct endpoint for the "deepseek" model.
deepseek=DeepseekAPI(api_key=deepseek_apikey)

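# Example result dict (presumably the raw output of pipeline.process for a
# single article/summary pair -- confirm against HallucinationPipeline):
# {'predictions': [0],
#  'corrected_summary': [],
#  'sent_predicted': [array([2, 0, 0, 0])],
#  'factual_score': [0.21776947937905788],
#  'contradiction_score': [0.7815729230642319]}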



# Detection Endpoint
class DetectionRequest(BaseModel):
    """Request body accepted by the POST /detect endpoint."""

    # Source text the summary is checked against.
    article: str
    # Summary text to be checked.
    summary: str
    # Arbiter toggle; the endpoint treats the exact string "on" as enabled.
    arbiter: str

@router.post("/detect")
def detect(data:DetectionRequest):
    """Run hallucination detection for an article/summary pair.

    The article and summary are trimmed and have their newlines and tabs
    replaced by single spaces before being passed to ``detectionProcess``
    together with the shared pipeline and the arbiter flag.

    Args:
        data: Request body; ``data.arbiter`` enables the arbiter only when it
            equals the string "on".

    Returns:
        The dict returned by ``detectionProcess`` with ``"status"`` set to
        200, or ``{"status": 404}`` if any exception was raised (the
        exception is printed).
    """
    def _single_line(text):
        # Strip the ends, then map each newline/tab to one space.
        return text.strip().translate(str.maketrans("\n\t", "  "))

    try:
        flat_article = _single_line(data.article)
        flat_summary = _single_line(data.summary)
        use_arbiter = data.arbiter == "on"
        print("Arbiter : ", use_arbiter)

        response = detectionProcess(
            article=flat_article,
            summary=flat_summary,
            pipeline=pipeline,
            arbiter=use_arbiter,
        )
        response["status"] = 200
    except Exception as err:
        # Any failure is reported to the client as a bare status payload.
        print(err)
        return {"status": 404}
    else:
        return response



#  Correction Endpoint
class correctionRequest(BaseModel):
    """Request body accepted by the POST /correct endpoint."""

    # Source text, passed to the correction model as the premise.
    article: str
    # Summary text handed to the correction model.
    # NOTE(review): presumably carries the <xx> hallucination tags -- confirm.
    tag_summary: str
    # Name of the correction backend, e.g. "deepseek".
    model: str

@router.post("/correct")
def correct(data:correctionRequest):
    """Correct a summary with the chosen model, then re-run detection on it.

    Only the "deepseek" backend is implemented. Any other model name
    (including the not-yet-wired "mistral" and "gemini") is rejected up
    front; previously such requests fell through to an unbound
    ``correction`` variable and failed with an UnboundLocalError.

    Args:
        data: Request body carrying the article, the summary to correct and
            the backend name.

    Returns:
        The dict returned by ``detectionProcess`` for the corrected summary,
        extended with ``"corrected_summary"`` and ``"status"`` 200. On an
        unsupported model or any exception, ``{"status": 404}``.
    """
    try:
        if data.model!="deepseek":
            # Same response the old code produced for these models, but with
            # a log line that says what actually went wrong.
            print(f"Unsupported correction model: {data.model!r}")
            return {"status":404}

        correction=deepseek.correct(premise=data.article, summary=data.tag_summary)
        print(correction)

        # Score the corrected summary against the original article.
        result=detectionProcess(article=data.article,summary=correction,pipeline=pipeline)
        result["corrected_summary"]=correction
        result["status"]=200
        return result
    except Exception as e:
        # Any failure is reported to the client as a bare status payload.
        print(e)
        return {"status":404}
    

@router.get("/test")
def keyTest():
    """Smoke-test endpoint that also logs whether the DeepSeek key is configured.

    Returns:
        The fixed payload ``{"msg": "testing_works"}``.
    """
    # Log only whether the variable is set; printing the value itself would
    # leak the secret into server logs.
    key_state="set" if os.getenv("deepseek_apikey") else "NOT set"
    print("Testing ",f"deepseek_apikey is {key_state}")
    return {"msg":"testing_works"}







# @router.post("/detect")
# def detect(data:DetectionRequest):
    # article=data.article.strip().replace("\n","").replace("\t"," ")
    # summary=data.summary.strip().replace("\n","").replace("\t"," ")

#     result=pipeline.process([[article,summary]],correct_the_summary=False)
#     all_sentences=sent_tokenize(summary)
#     print(result)
#     summary=pipeline.addTags(all_sentences,result["sent_predicted"][0],len(all_sentences))
#     score=str(result["factual_score"][0])
#     sentenceLabels=list(result["sent_predicted"][0])
#     labelCounts=[sentenceLabels.count(0),sentenceLabels.count(2)]
  
#     prompt=f"""
# Here is a summary with hallucinated parts marked using <xx> tags.

# Please correct only the text inside the <xx> tags to make it factually accurate based on the original article. Leave the rest of the summary unchanged and remove the <xx> tags after correction.

# Return the summary with hallucinated parts fixed and you can remove those <xx></xx> tags. Don't remove that entire sentence.
            
# Original Article:
#     {data.article}

# Summary:
#     {summary}

# """
    # return {"summary":summary,"score":score,"counts":labelCounts,"copy_prompt":prompt}