File size: 6,129 Bytes
b729f12
adfc68f
3931752
a6dbdd5
5fca79f
9c191f5
a4a8598
 
a6dbdd5
85bca67
c9523a5
 
a6dbdd5
c9523a5
a6dbdd5
 
c9523a5
 
b4bdb8a
 
a6dbdd5
c9523a5
a6dbdd5
c9523a5
5fca79f
 
 
3d7f973
 
 
 
 
 
 
 
 
 
 
 
 
c9523a5
5fca79f
 
c9523a5
5fca79f
c9523a5
 
5fca79f
 
b729f12
5fca79f
 
 
 
 
c9523a5
5fca79f
 
 
e5b1901
 
 
c9523a5
 
 
 
5fca79f
 
 
 
 
 
 
 
b729f12
c9523a5
 
 
b729f12
c9523a5
b729f12
 
 
5fca79f
 
 
 
 
b729f12
c9523a5
 
 
b729f12
 
 
 
 
 
 
e5b1901
 
a6dbdd5
3d7f973
 
b4bdb8a
e5b1901
b4bdb8a
 
 
 
 
 
 
c9523a5
b4bdb8a
 
 
e5b1901
b4bdb8a
 
 
 
e5b1901
b4bdb8a
 
e5b1901
b4bdb8a
c9523a5
 
 
 
 
 
 
 
 
 
 
 
 
3d7f973
 
c9523a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os

# Cache/config locations for the libraries imported below. They are set before
# those imports on the assumption that transformers (via huggingface_hub) and
# matplotlib read these variables when first imported — keep this ordering.
# /tmp is used as a writable location (NOTE(review): presumably the deployment
# filesystem is otherwise read-only — confirm).
# HF_HOME is set exactly once; "/tmp" is the value that previously took effect.
os.environ["HF_HOME"] = "/tmp"
os.environ["MPLCONFIGDIR"] = "/tmp/mplconfig"

# matplotlib needs its config directory to exist.
os.makedirs("/tmp/mplconfig", exist_ok=True)

import io
import logging
import random
import re
from urllib.parse import quote

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import httpx

# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Lower-case ticker symbol -> CoinGecko coin id. The chart code falls back to
# using the given value unchanged when a symbol is not listed here.
SYMBOL_TO_ID = dict(
    btc="bitcoin",
    eth="ethereum",
    xrp="ripple",
    ltc="litecoin",
    ada="cardano",
    doge="dogecoin",
    sol="solana",
    # Extend with further symbol="coingecko-id" pairs as needed.
)


# FastAPI app
app = FastAPI()

# Load models. The NER and sentiment pipelines are loaded independently, so a
# failure in one (e.g. a download error) no longer disables the other. Any
# model that fails to load stays None; the endpoints check for that and
# answer 503.
tokenizer = None
model = None
ner_model = None
sentiment_model = None

try:
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    # Build the pipeline from the objects loaded above. Passing the checkpoint
    # name again would load a second copy of the same weights into memory.
    ner_model = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
    logger.info("NER model loaded successfully.")
except Exception as e:
    logger.error(f"NER model loading failed: {e}")

try:
    sentiment_model = pipeline("sentiment-analysis", model="ProsusAI/finbert")
    logger.info("Sentiment model loaded successfully.")
except Exception as e:
    logger.error(f"Sentiment model loading failed: {e}")

# Schemas
# Request body shared by /sentiment and /ner.
class TextRequest(BaseModel):
    # Input text. The endpoints strip it, reject it when empty, and pass at
    # most its first 512 characters to the model.
    text: str

# Request body for /chart.
class CoinRequest(BaseModel):
    # Ticker symbol listed in SYMBOL_TO_ID (e.g. "btc") or a raw CoinGecko coin
    # id; lower-cased and stripped by the endpoint before lookup.
    coin_id: str

# Request body for /visual; which field is used depends on the randomly chosen visual.
class VisualRequest(BaseModel):
    coin_id: str  # used when a price chart is generated
    topic: str  # used as the caption when a news card is generated

@app.get("/")
def home():
    """Liveness check: confirm the API process is up and serving requests."""
    status = {"message": "Crypto News API is alive!"}
    return status

@app.post("/sentiment")
def analyze_sentiment(req: TextRequest):
    """Classify the sentiment of ``req.text`` with the FinBERT pipeline.

    Returns:
        ``{"label": <model label>, "score": <confidence in percent, 2 decimals>}``.

    Raises:
        HTTPException(503): the sentiment model failed to load at startup.
        HTTPException(400): the text is empty or whitespace-only.
        HTTPException(500): the model raised while scoring the text.
    """
    if sentiment_model is None:
        raise HTTPException(status_code=503, detail="Sentiment model not available")
    text = req.text.strip()
    if not text:
        # Validated outside the try block: previously this 400 was caught by
        # the generic handler below and turned into a 500.
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    try:
        # text[:512] caps the input by characters. truncation=True also
        # guarantees the token sequence fits the model's maximum length, which
        # character slicing alone does not.
        result = sentiment_model(text[:512], truncation=True)[0]
        return {"label": result["label"], "score": round(result["score"] * 100, 2)}
    except Exception as e:
        logger.exception(f"Sentiment analysis error: {e}")
        raise HTTPException(status_code=500, detail="Sentiment analysis failed") from e

@app.post("/ner")
def analyze_ner(req: TextRequest):
    """Extract up to five distinct named entities from ``req.text``.

    Entities are returned in order of first appearance, de-duplicated, and
    limited to the entity groups listed in ``relevant_groups`` below.

    Raises:
        HTTPException(503): the NER model failed to load at startup.
        HTTPException(400): the text is empty or whitespace-only.
        HTTPException(500): the model raised while tagging the text.
    """
    if ner_model is None:
        raise HTTPException(status_code=503, detail="NER model not available")
    text = req.text.strip()
    if not text:
        # Validated outside the try block: previously this 400 was caught by
        # the generic handler below and turned into a 500.
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    # dslim/bert-base-NER emits the CoNLL-2003 groups PER, ORG, LOC and MISC.
    # PERSON/GPE/PRODUCT (OntoNotes-style names) never occur with this model,
    # so PER and LOC are included to actually keep people and places. The
    # OntoNotes names are kept so a future model swap still works.
    relevant_groups = {"ORG", "MISC", "PER", "LOC", "PERSON", "GPE", "PRODUCT"}
    try:
        entities = ner_model(text[:512])
        relevant = [ent["word"] for ent in entities if ent.get("entity_group") in relevant_groups]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        unique_entities = list(dict.fromkeys(relevant))[:5]
        return {"entities": unique_entities}
    except Exception as e:
        logger.exception(f"NER analysis error: {e}")
        raise HTTPException(status_code=500, detail="NER analysis failed") from e

@app.post("/chart")
def generate_chart(req: CoinRequest):
    """Stream a PNG line chart of a coin's USD price over the last 7 days.

    ``req.coin_id`` may be a ticker listed in ``SYMBOL_TO_ID`` (e.g. ``"btc"``)
    or a raw CoinGecko coin id. It is stripped and lower-cased before lookup.

    Raises:
        HTTPException(400): ``coin_id`` is empty after stripping.
        HTTPException(502): CoinGecko answered with a non-200 status or with no prices.
        HTTPException(500): any other failure while fetching or plotting.
    """
    coin_symbol = req.coin_id.strip().lower()
    if not coin_symbol:
        raise HTTPException(status_code=400, detail="coin_id cannot be empty")
    coin_id = SYMBOL_TO_ID.get(coin_symbol, coin_symbol)
    logger.info(f"Generating chart for coin: {coin_id}")
    # coin_id is user input: percent-encode it (including '/') so it stays a
    # single path segment of the CoinGecko URL.
    url = f"https://api.coingecko.com/api/v3/coins/{quote(coin_id, safe='')}/market_chart"
    params = {"vs_currency": "usd", "days": "7"}
    try:
        response = httpx.get(url, params=params)
        if response.status_code != 200:
            logger.error(f"CoinGecko API error: {response.text}")
            raise HTTPException(status_code=502, detail="Failed to fetch coin data from CoinGecko")
        prices = response.json()["prices"]
        if not prices:
            raise HTTPException(status_code=502, detail="CoinGecko returned no price data")
        # Each entry is a pair; the first element (presumably a timestamp) is
        # discarded and only the second (the price) is plotted.
        values = [price for _, price in prices]
        # A standalone Figure avoids pyplot's global "current figure" state,
        # which is shared across threads. It also needs no plt.close(), so a
        # failure below cannot leak an open figure.
        fig = Figure(figsize=(6, 3))
        ax = fig.subplots()
        ax.plot(values, color="blue")
        ax.set_title(f"{coin_id.capitalize()} - Last 7 Days")
        ax.set_xlabel("Time")
        ax.set_ylabel("Price (USD)")
        ax.grid(True)
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being rewritten into a generic 500 below.
        raise
    except Exception as e:
        logger.exception(f"Chart generation error: {e}")
        raise HTTPException(status_code=500, detail="Chart generation failed") from e
    return StreamingResponse(buffer, media_type="image/png")

# ✅ News image generator
def generate_news_image(topic: str) -> str:
    """Render ``topic`` as a simple text card and save it as a PNG under /tmp.

    Args:
        topic: caption to draw. ``/visual`` passes the request's ``topic`` here,
            so it must be treated as untrusted input.

    Returns:
        Path of the written PNG. The file name is derived from ``topic``, so
        calls with the same topic overwrite the same file.
    """
    # Keep only word characters and '-' in the file name, so '/' or '..' in the
    # topic cannot place the file outside /tmp. Cap the length so the name
    # stays under filesystem limits even for multi-byte characters.
    safe_stem = re.sub(r"[^\w-]", "_", topic)[:50]
    file_path = f"/tmp/{safe_stem}_news.png"
    # A standalone Figure avoids pyplot's global "current figure" state, which
    # is shared across threads, and cannot be left open if drawing fails.
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.text(0.5, 0.5, f"📰 {topic}", fontsize=18, ha='center')
    ax.axis("off")
    fig.savefig(file_path)
    return file_path

# ✅ Chart image generator for visual endpoint (reuse)
def generate_chart_image(coin_symbol: str) -> str:
    """Fetch 7 days of USD prices from CoinGecko and save a PNG line chart under /tmp.

    Args:
        coin_symbol: a ticker listed in ``SYMBOL_TO_ID`` (e.g. "btc") or a raw
            CoinGecko coin id. It is stripped and lower-cased, matching /chart.

    Returns:
        Path of the written PNG file.

    Raises:
        RuntimeError: CoinGecko answered with a non-200 status or returned no prices.
        Exception: any network, parsing or plotting error is logged and re-raised.
    """
    normalized = coin_symbol.strip().lower()
    coin_id = SYMBOL_TO_ID.get(normalized, normalized)
    try:
        # coin_id is user input: percent-encode it (including '/') so it stays
        # a single path segment of the CoinGecko URL.
        url = f"https://api.coingecko.com/api/v3/coins/{quote(coin_id, safe='')}/market_chart"
        params = {"vs_currency": "usd", "days": "7"}
        response = httpx.get(url, params=params)
        if response.status_code != 200:
            raise RuntimeError("CoinGecko data fetch failed")
        prices = response.json()["prices"]
        if not prices:
            raise RuntimeError("CoinGecko returned no price data")
        # Each entry is a pair; only the second element (the price) is plotted.
        values = [price for _, price in prices]
        # Keep only word characters and '-' in the file name, so a coin_id such
        # as "../x" cannot place the file outside /tmp.
        safe_stem = re.sub(r"[^\w-]", "_", coin_id)[:50]
        file_path = f"/tmp/{safe_stem}_chart.png"
        # A standalone Figure avoids pyplot's global "current figure" state,
        # which is shared across threads, and cannot be left open on failure.
        fig = Figure(figsize=(6, 3))
        ax = fig.subplots()
        ax.plot(values, color="green")
        ax.set_title(f"{coin_id.capitalize()} Chart")
        ax.grid(True)
        fig.savefig(file_path)
        return file_path
    except Exception as e:
        logger.error(f"Chart image generation error: {e}")
        raise

# ✅ Random visual endpoint: returns either a price chart or a news card.
@app.post("/visual")
def generate_visual(req: VisualRequest):
    """Serve a randomly chosen PNG: a coin price chart or a topic news card.

    Raises:
        HTTPException(500): the selected image generator failed.
    """
    renderers = {
        "chart": lambda: generate_chart_image(req.coin_id),
        "news": lambda: generate_news_image(req.topic),
    }
    picked = random.choice(list(renderers))
    logger.info(f"Generating visual: {picked}")
    try:
        image_path = renderers[picked]()
        return FileResponse(image_path, media_type="image/png")
    except Exception as err:
        logger.error(f"Visual generation failed: {err}")
        raise HTTPException(status_code=500, detail="Visual generation failed")