| | import time |
| | import gradio as gr |
| | from transformers import pipeline |
| | from huggingface_hub import InferenceClient |
| | from typing import List, Dict, Tuple, Any, Optional |
| | from diffusers import AutoPipelineForText2Image |
| | import torch |
| |
|
| | |
# --- Input limits ----------------------------------------------------------
MAX_CHAR = 8000        # default character cap applied by _normalize_text
NER_NUM_ROWS = 10      # default number of entity rows returned by run_ner_rows

# --- Hugging Face model identifiers ----------------------------------------
SUMM_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
SENTIMENT_MODEL_ID = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
FINCLS_MODEL_ID = "nickmuchi/distilroberta-finetuned-financial-text-classification"
NER_MODEL_ID = "dslim/bert-base-NER"
CHAT_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
IMAGE_MODEL_ID = "stabilityai/sd-turbo"

# --- Lazily created pipeline singletons (filled in by the _get_* helpers) ---
_summ_pipe = None
_sentiment_pipe = None
_fincls_pipe = None
_ner_pipe = None
_img_pipe_cpu = None

# --- Local text-to-image settings ------------------------------------------
IMG_STEPS = 2
IMG_GUIDANCE = 0.5
IMG_WIDTH = 512
IMG_HEIGHT = 512

# --- Chat generation settings ----------------------------------------------
CHAT_MAX_TOKENS = 512
CHAT_TEMPERATURE = 0.7
CHAT_TOP_P = 0.95
CHAT_SYSTEM_PROMPT = (
    "\nYou are assisting with analysis of a financial news article."
    "\nBe clear, cite facts from context, and avoid investment advice."
    "\nUse the provided ARTICLE as your primary context."
    "\nIf the user asks about something not in context, say what you do/don't know."
)

# Device index handed to transformers.pipeline(device=...) for CPU execution.
DEVICE_CPU = -1
| |
|
| | |
def _get_summ_pipe():
    """Return the shared summarization pipeline, creating it on first use.

    The pipeline is cached in the module-level ``_summ_pipe`` so the model is
    loaded only once per process; it is built for ``DEVICE_CPU``.
    """
    global _summ_pipe
    if _summ_pipe is not None:
        return _summ_pipe
    _summ_pipe = pipeline("summarization", model=SUMM_MODEL_ID, device=DEVICE_CPU)
    return _summ_pipe
| |
|
def _get_sentiment_pipe():
    """Return the shared FinancialBERT sentiment classifier, creating it lazily.

    Built once (cached in ``_sentiment_pipe``) for ``DEVICE_CPU`` with
    ``truncation=True`` so over-long inputs are truncated by the tokenizer.
    """
    global _sentiment_pipe
    if _sentiment_pipe is not None:
        return _sentiment_pipe
    _sentiment_pipe = pipeline(
        "text-classification",
        model=SENTIMENT_MODEL_ID,
        truncation=True,
        device=DEVICE_CPU,
    )
    return _sentiment_pipe
| |
|
def _get_fincls_pipe():
    """Return the shared financial-category classifier, creating it lazily.

    Cached in ``_fincls_pipe`` and built for ``DEVICE_CPU``. ``top_k=None``
    asks the text-classification pipeline for a score on every label (the
    documented replacement for the deprecated ``return_all_scores=True``);
    callers choose the best label themselves. ``truncation=True`` lets the
    tokenizer clip over-long inputs.
    """
    global _fincls_pipe
    if _fincls_pipe is None:
        _fincls_pipe = pipeline(
            "text-classification",
            model=FINCLS_MODEL_ID,
            truncation=True,
            top_k=None,  # all labels; replaces deprecated return_all_scores=True
            device=DEVICE_CPU,
        )
    return _fincls_pipe
| |
|
def _get_ner_pipe():
    """Return the shared named-entity-recognition pipeline, creating it lazily.

    Cached in ``_ner_pipe`` and built for ``DEVICE_CPU``. With
    ``aggregation_strategy="simple"`` each result carries an ``entity_group``
    key (which run_ner_rows reads) instead of per-token tags.
    """
    global _ner_pipe
    if _ner_pipe is not None:
        return _ner_pipe
    _ner_pipe = pipeline(
        "token-classification",
        model=NER_MODEL_ID,
        aggregation_strategy="simple",
        device=DEVICE_CPU,
    )
    return _ner_pipe
| |
|
| | |
| | |
| | def _hf_token_str(hf_token): |
| | if hf_token is None: |
| | return None |
| | if isinstance(hf_token, str): |
| | return hf_token or None |
| | |
| | if hasattr(hf_token, "token"): |
| | return hf_token.token |
| | |
| | if isinstance(hf_token, dict): |
| | return hf_token.get("token") |
| | return None |
| |
|
def _get_img_pipe_cpu():
    """Return the shared local text-to-image pipeline, loading it on first use.

    Loads IMAGE_MODEL_ID in float32 from safetensors weights, moves it to the
    CPU and caches it in ``_img_pipe_cpu``. The optional memory-saving methods
    ``enable_attention_slicing`` / ``enable_vae_slicing`` are called when the
    pipeline provides them; a failure there is logged but not fatal.
    """
    global _img_pipe_cpu
    if _img_pipe_cpu is None:
        pipe = AutoPipelineForText2Image.from_pretrained(
            IMAGE_MODEL_ID,
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
        pipe.to("cpu")
        for fn in ("enable_attention_slicing", "enable_vae_slicing"):
            method = getattr(pipe, fn, None)
            if method is None:
                # Not every pipeline class offers every optimisation.
                continue
            try:
                method()
            except Exception as e:
                # Best-effort only: keep the pipeline usable, but record why.
                print(f"Image pipeline: {fn} failed: {e}")
        _img_pipe_cpu = pipe
    return _img_pipe_cpu
| |
|
def _try_cloud_text2image(prompt: str, hf_token: Optional[gr.OAuthToken]):
    """Try to render *prompt* remotely via the Hugging Face Inference API.

    Args:
        prompt: Text prompt for IMAGE_MODEL_ID.
        hf_token: Object exposing a ``.token`` string (e.g. ``gr.OAuthToken``).
            Anything without a truthy ``.token`` attribute (None, plain
            strings) skips the remote call.

    Returns:
        The result of ``InferenceClient.text_to_image``, or None when no token
        is available or the request fails. Errors are not re-raised, so they
        are logged here; otherwise the reason for the fallback would be lost.
    """
    tok = getattr(hf_token, "token", None) if hf_token else None
    if not tok:
        return None
    try:
        client = InferenceClient(token=tok)
        return client.text_to_image(prompt, model=IMAGE_MODEL_ID)
    except Exception as e:
        print("Cloud image error:", e)
        return None
| |
|
| | |
def _normalize_text(text: str, max_len: int = MAX_CHAR) -> str:
    """Return *text* stripped of surrounding whitespace and cut to *max_len* chars.

    Falsy input (None, "") yields "". Stripping happens before the cut, so the
    limit applies to the trimmed text.
    """
    cleaned = text.strip() if text else ""
    return cleaned[:max_len]
| |
|
def run_summary(text: str, max_input_chars: int = 3000) -> str:
    """Summarise *text* with the shared summarization pipeline.

    Args:
        text: Article text; normalised with _normalize_text first.
        max_input_chars: Only this many leading characters are sent to the
            model (default 3000, the previously hard-coded cutoff).

    Returns:
        The stripped summary, or "" for blank input or on any error (logged).

    ``truncation=True`` makes the tokenizer clip the input at the model's
    maximum length. Without it, character-limited text that still tokenizes
    past that limit (e.g. number-heavy articles) makes the call fail, and the
    user silently got an empty summary.
    """
    try:
        txt = _normalize_text(text, MAX_CHAR)
        if not txt:
            return ""
        sp = _get_summ_pipe()
        out = sp(
            txt[:max_input_chars],
            max_length=160,
            min_length=48,
            do_sample=False,
            truncation=True,
        )
        return out[0]["summary_text"].strip() if out else ""
    except Exception as e:
        print("Summary error:", e)
        return ""
| |
|
| | def run_text_nlp(text: str) -> Tuple[str, float, str, float]: |
| | try: |
| | txt = _normalize_text(text) |
| | if not txt: |
| | return "", 0.0, "", 0.0 |
| | sp = _get_sentiment_pipe() |
| | fp = _get_fincls_pipe() |
| | s_pred = sp(txt)[0] |
| | dist = fp(txt)[0] |
| | top = max(dist, key=lambda d: d["score"]) if dist else {"label": "", "score": 0.0} |
| | return ( |
| | s_pred.get("label", ""), |
| | float(s_pred.get("score", 0.0)), |
| | top.get("label", ""), |
| | float(top.get("score", 0.0)), |
| | ) |
| | except Exception as e: |
| | print("Text NLP error:", e) |
| | return "Error", 0.0, "Error", 0.0 |
| |
|
def run_ner_rows(text: str, limit: int = NER_NUM_ROWS) -> List[List[str]]:
    """Run NER on *text* and return table rows for the NER tab.

    Each row is ``[entity_group, word, score]`` with the score formatted to two
    decimals; at most *limit* rows are returned, in pipeline order. Blank
    input gives ``[]``; on error a single ``["Error", message, "0.00"]`` row is
    returned after logging.
    """
    try:
        cleaned = _normalize_text(text, MAX_CHAR)
        if not cleaned:
            return []
        entities = _get_ner_pipe()(cleaned)
        table: List[List[str]] = []
        for ent in entities:
            group = ent.get("entity_group", "")
            word = ent.get("word", "")
            score = float(ent.get("score", 0.0))
            table.append([group, word, f"{score:.2f}"])
        return table[:limit]
    except Exception as e:
        print("NER error:", e)
        return [["Error", str(e), "0.00"]]
| |
|
| | |
def build_context_block(article: str, analysis: Dict[str, Any]) -> str:
    """Assemble the grounding text given to the chat model.

    Args:
        article: Raw article text; included truncated to MAX_CHAR characters.
        analysis: Dict from the Analyze step. Keys read: "sentiment",
            "sentiment_score", "category", "category_score", "summary" and
            "entities" (rows whose index 1 is the entity text). Missing keys
            are tolerated; missing/non-numeric scores render as "n/a".

    Returns:
        The sections joined by blank lines, or "" when both inputs are empty.
    """

    def _fmt_score(value: Any) -> str:
        try:
            return f"{float(value):.2f}"
        except (TypeError, ValueError):
            return "n/a"

    parts = []
    if article:
        parts.append(f"ARTICLE (truncated):\n{article[:MAX_CHAR]}")
    if analysis:
        s_score = _fmt_score(analysis.get("sentiment_score"))
        c_score = _fmt_score(analysis.get("category_score"))
        parts.append(
            "ANALYSIS SUMMARY:\n"
            f"- Sentiment: {analysis.get('sentiment')} ({s_score})\n"
            f"- Financial stance: {analysis.get('category')} ({c_score})"
        )
        if analysis.get("summary"):
            parts.append(f"- Auto Summary: {analysis['summary']}")
        ents = analysis.get("entities") or []
        # dict.fromkeys dedupes while keeping first-seen order; a set would
        # make the prompt's entity order vary between processes.
        names = dict.fromkeys(
            str(row[1]) for row in ents[:40] if len(row) > 1 and row[1]
        )
        if names:
            parts.append(f"- Top entities: {', '.join(names)}")
    return "\n\n".join(parts)
| |
|
def _warn_if_no_token(hf_token: Optional[gr.OAuthToken]) -> str:
    """Return a Markdown login reminder when no usable token is present, else "".

    A token counts as usable only if the object is truthy and its ``.token``
    attribute is truthy.
    """
    has_token = bool(hf_token) and bool(getattr(hf_token, "token", None))
    if has_token:
        return ""
    return "\nYou are not logged in to Hugging Face. Click **Login** (left sidebar) for better reliability.\n\n"
| |
|
def respond_chat(
    message: str,
    history: List[Dict[str, str]],
    article_text: str,
    analysis: Dict[str, Any],
    hf_token: Optional[gr.OAuthToken],
    _profile,
):
    """Stream an assistant reply grounded in the current article and analysis.

    Used as the ``fn`` of ``gr.ChatInterface(type="messages")``.

    Args:
        message: The new user message.
        history: Prior turns as ``{"role", "content"}`` dicts. Only
            user/assistant turns with string content are forwarded, and only
            their role/content keys.
        article_text: Raw article text (from ``article_state``).
        analysis: Analysis dict (from ``analysis_state``); may be empty.
        hf_token: OAuth token of the logged-in user. Annotated Optional so
            Gradio passes None for logged-out users (per its OAuth docs)
            instead of refusing the call; they then see the login notice.
        _profile: Unused; absorbs the remaining wired input value.

    Yields:
        The accumulated reply after each streamed chunk, prefixed with a login
        notice when no token is available. On failure the error text is
        appended to the reply instead of being raised.
    """
    tok = _hf_token_str(hf_token)
    login_notice = _warn_if_no_token(hf_token)
    client = InferenceClient(token=tok, model=CHAT_MODEL_ID)

    context_block = build_context_block(article_text or "", analysis or {})
    # Use a single leading system message: a second system turn may be
    # rejected or mishandled by some chat templates/providers, and an empty
    # context should not produce an empty system message.
    system_content = CHAT_SYSTEM_PROMPT
    if context_block:
        system_content = f"{CHAT_SYSTEM_PROMPT}\n\n{context_block}"

    messages: List[Dict[str, str]] = [{"role": "system", "content": system_content}]
    # Gradio message dicts can carry extra keys (e.g. "metadata") or
    # non-text content; send the API only plain role/content text turns.
    for turn in history or []:
        if not isinstance(turn, dict):
            continue
        role, content = turn.get("role"), turn.get("content")
        if role in ("user", "assistant") and isinstance(content, str):
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})

    response = login_notice
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=CHAT_MAX_TOKENS,
            stream=True,
            temperature=CHAT_TEMPERATURE,
            top_p=CHAT_TOP_P,
        )
        for chunk in stream:
            choices = getattr(chunk, "choices", None) or []
            delta = getattr(choices[0], "delta", None) if choices else None
            response += getattr(delta, "content", None) or ""
            yield response
    except Exception as e:
        response += (
            f"\nChat request failed for model `{CHAT_MODEL_ID}`.\n"
            f"Error: {e}\n"
        )
        yield response
| |
|
| | |
def generate_image(prompt, width, height, hf_token: Optional[gr.OAuthToken], *args):
    """Generate an image for *prompt*, trying the HF Inference API first.

    Args:
        prompt: Text prompt; blank prompts are rejected with a message.
        width, height: Output size for the local CPU fallback (cast to int).
            NOTE(review): the cloud path does not forward these, so a remote
            image uses whatever size the endpoint defaults to.
        hf_token: Filled in by Gradio from the login session because of the
            ``gr.OAuthToken`` annotation (None when logged out). Previously
            unannotated, so it received the LoginButton's label string, and
            the cloud path never got a real token.
        *args: Absorbs the extra wired input value (the LoginButton's value
            now that the token is injected separately); unused.

    Returns:
        ``(image, status)`` on success, ``(None, message)`` on failure.
    """
    import traceback
    t0 = time.time()
    prompt = (prompt or "").strip()
    if not prompt:
        return None, "Provide a prompt."

    # 1) Remote generation (requires a token; returns None on failure).
    try:
        img = _try_cloud_text2image(prompt, hf_token)
        if img is not None:
            return img, f"{time.time()-t0:.2f}s"
    except Exception as e:
        print("Cloud image error:", e)
        traceback.print_exc()

    # 2) Local CPU fallback.
    try:
        pipe = _get_img_pipe_cpu()
        width, height = int(width), int(height)
        out = pipe(
            prompt=prompt,
            num_inference_steps=IMG_STEPS,
            guidance_scale=IMG_GUIDANCE,
            width=width,
            height=height,
        )
        img = out.images[0]
        return img, f"{time.time()-t0:.2f}s | steps={IMG_STEPS}, g={IMG_GUIDANCE}"
    except Exception as e:
        print("CPU image error:", e)
        traceback.print_exc()
        return None, f"Generation failed: {e}"
| |
|
| | |
# ---------------------------------------------------------------------------
# Gradio UI. Tabs: Input -> Text Analysis -> NER -> Chat -> Image.
# The Analyze button fills the analysis tabs and stores the article text and
# the analysis dict in per-session gr.State, which the Chat tab reads.
# ---------------------------------------------------------------------------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("**ARIN 460 Final — Financial News Multi-Model**")

    # Session state written by _analyze_all and passed to respond_chat.
    article_state = gr.State("")
    analysis_state = gr.State({})

    with gr.Sidebar():
        # Hugging Face login; also wired as an input to the chat and image events.
        login_btn = gr.LoginButton()
        gr.Markdown("**Workflow**\n1) Input\n2) Analysis (Assignment 4)\n3) Chat\n4) Image")

    with gr.Tabs():
        with gr.Tab("Input"):
            txt_in = gr.Textbox(lines=12, label="Article text")
            analyze_btn = gr.Button("Analyze", variant="primary")
            run_status = gr.Markdown()

        # Read-only outputs for run_summary / run_text_nlp.
        with gr.Tab("Text Analysis"):
            summary_box = gr.Textbox(label="Summary", lines=4, interactive=False)
            sent_lbl = gr.Textbox(label="Sentiment", interactive=False)
            sent_score = gr.Number(label="Sentiment score", precision=3, interactive=False)
            fin_lbl = gr.Textbox(label="Financial Category", interactive=False)
            fin_score = gr.Number(label="Category score", precision=3, interactive=False)
            ta_status = gr.Markdown()

        # Entity table; rows come from run_ner_rows as [entity, text, score].
        with gr.Tab("NER"):
            ner_out = gr.Dataframe(headers=["entity", "text", "score"],
                                   datatype=["str", "str", "str"], interactive=False)
            ner_status = gr.Markdown()

        with gr.Tab("Chat"):
            # additional_inputs follow (message, history) in respond_chat's
            # argument list. NOTE(review): respond_chat declares six parameters
            # for five wired values; presumably Gradio injects `hf_token` from
            # the OAuth session based on its annotation and the LoginButton
            # value lands in `_profile` — confirm.
            chat = gr.ChatInterface(
                respond_chat,
                type="messages",
                additional_inputs=[
                    article_state, analysis_state, login_btn
                ],
            )
            # Adjust the height of the Chatbot component the ChatInterface created.
            chat.chatbot.height = 400

        with gr.Tab("Image"):
            img_prompt = gr.Textbox(label="Prompt", lines=3)
            # Slider bounds/step restrict sizes to multiples of 64 in [256, 768].
            width_slider = gr.Slider(256, 768, value=IMG_WIDTH, step=64, label="Width")
            height_slider = gr.Slider(256, 768, value=IMG_HEIGHT, step=64, label="Height")
            gen_btn = gr.Button("Generate Image", variant="primary")
            image_out = gr.Image(label="Result", type="pil")
            gen_status = gr.Markdown()
            # NOTE(review): the 4th input is the LoginButton component; what
            # generate_image receives for it depends on how its `hf_token`
            # parameter is declared (Gradio injects gr.OAuthToken-annotated
            # parameters separately) — confirm a real token reaches the cloud path.
            gen_btn.click(
                generate_image,
                inputs=[img_prompt, width_slider, height_slider, login_btn],
                outputs=[image_out, gen_status]
            )

    def _analyze_all(text):
        """Run summary, sentiment/category and NER on *text*, sequentially.

        Returns an 11-tuple in the order of the `outputs=` list wired to
        analyze_btn below: run status, summary, sentiment label/score,
        category label/score, text-analysis status, NER rows, NER status,
        and the new values for article_state and analysis_state.
        """
        t0 = time.time()
        summ = run_summary(text)
        s_lbl, s_score, c_lbl, c_score = run_text_nlp(text)
        ner_rows = run_ner_rows(text)
        dt = time.time() - t0
        # Stored in analysis_state; build_context_block reads these keys.
        analysis = {
            "summary": summ,
            "sentiment": s_lbl,
            "sentiment_score": s_score,
            "category": c_lbl,
            "category_score": c_score,
            "entities": ner_rows,
        }
        # NOTE(review): if NER fails, run_ner_rows returns one "Error" row, so
        # the status below would report "Extracted 1 entities."
        return (
            f"Processed in **{dt:.2f}s**.",
            summ, s_lbl, s_score, c_lbl, c_score, f"Updated at {time.strftime('%H:%M:%S')}",
            ner_rows, f"Extracted {len(ner_rows)} entities.",
            text, analysis
        )

    # Disable the button while analysing, run all models, then restore the
    # button. The 11 outputs match, in order, the tuple returned by _analyze_all.
    analyze_btn.click(lambda: gr.update(value="Analyzing...", interactive=False), [], [analyze_btn]) \
        .then(_analyze_all, inputs=[txt_in],
              outputs=[run_status, summary_box, sent_lbl, sent_score, fin_lbl, fin_score,
                       ta_status, ner_out, ner_status, article_state, analysis_state]) \
        .then(lambda: gr.update(value="Analyze", interactive=True), [], [analyze_btn])

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()